Timo commited on
Commit
cc0b700
·
1 Parent(s): 8c118f9
Files changed (1) hide show
  1. src/draft_model.py +5 -1
src/draft_model.py CHANGED
@@ -30,8 +30,11 @@ class DraftModel:
30
  repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
31
  )
32
 
 
 
 
33
  # ---- load network ---------------------------------------------------
34
- cfg = open(cfg_path, "r")
35
 
36
  self.net = MLP_CrossAttention(**cfg).to(self.device)
37
  self.net.load_state_dict(weight_path, map_location=self.device)
@@ -42,6 +45,7 @@ class DraftModel:
42
  hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
43
  add_nontransformed=True
44
  )
 
45
 
46
  # --------------------------------------------------------------------- #
47
  # Public API expected by streamlit_app.py #
 
30
  repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
31
  )
32
 
33
+ with open(cfg_path, "r") as f:
34
+ cfg = json.load(f)
35
+
36
  # ---- load network ---------------------------------------------------
37
+
38
 
39
  self.net = MLP_CrossAttention(**cfg).to(self.device)
40
  self.net.load_state_dict(weight_path, map_location=self.device)
 
45
  hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
46
  add_nontransformed=True
47
  )
48
+ self.emb_size = next(iter(self.embed_dict.values())).shape[0]
49
 
50
  # --------------------------------------------------------------------- #
51
  # Public API expected by streamlit_app.py #