AlgoX commited on
Commit
a7dd184
·
1 Parent(s): ead1f5d
Files changed (2) hide show
  1. app.py +3 -3
  2. requirements.txt +2 -0
app.py CHANGED
@@ -688,7 +688,7 @@ hawk_model = HawkPredictor(
688
  conv_kernel_size=hawk_config["conv_kernel_size"],
689
  dropout=hawk_config["dropout"]
690
  )
691
- hawk_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "hawk_best_model.pt"), map_location=device)['model_state_dict'])
692
  hawk_model.to(device)
693
  hawk_model.eval()
694
  models["hawk"] = hawk_model
@@ -706,7 +706,7 @@ mamba_model = Mamba2Predictor(
706
  conv_kernel_size=mamba_config["conv_kernel_size"],
707
  dropout=mamba_config["dropout"]
708
  )
709
- mamba_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "mamba_best_model.pt"), map_location=device)['model_state_dict'])
710
  mamba_model.to(device)
711
  mamba_model.eval()
712
  models["mamba"] = mamba_model
@@ -724,7 +724,7 @@ xlstm_model = xLSTMPredictor(
724
  dropout=xlstm_config["dropout"],
725
  expand_factor=xlstm_config["expand_factor"]
726
  )
727
- xlstm_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "xlstm_best_model.pt"), map_location=device)['model_state_dict'])
728
  xlstm_model.to(device)
729
  xlstm_model.eval()
730
  models["xlstm"] = xlstm_model
 
688
  conv_kernel_size=hawk_config["conv_kernel_size"],
689
  dropout=hawk_config["dropout"]
690
  )
691
+ hawk_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "hawk_best_model.pt"), map_location=device, weights_only=False)['model_state_dict'])
692
  hawk_model.to(device)
693
  hawk_model.eval()
694
  models["hawk"] = hawk_model
 
706
  conv_kernel_size=mamba_config["conv_kernel_size"],
707
  dropout=mamba_config["dropout"]
708
  )
709
+ mamba_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "mamba_best_model.pt"), map_location=device, weights_only=False)['model_state_dict'])
710
  mamba_model.to(device)
711
  mamba_model.eval()
712
  models["mamba"] = mamba_model
 
724
  dropout=xlstm_config["dropout"],
725
  expand_factor=xlstm_config["expand_factor"]
726
  )
727
+ xlstm_model.load_state_dict(torch.load(os.path.join(MODELS_DIR, "xlstm_best_model.pt"), map_location=device, weights_only=False)['model_state_dict'])
728
  xlstm_model.to(device)
729
  xlstm_model.eval()
730
  models["xlstm"] = xlstm_model
requirements.txt CHANGED
@@ -12,3 +12,5 @@ lightgbm
12
  gradio==5.43.1
13
  huggingface-hub==0.36.0
14
  transformers==4.57.1
 
 
 
12
  gradio==5.43.1
13
  huggingface-hub==0.36.0
14
  transformers==4.57.1
15
+ matplotlib
16
+ safetensors