Spaces:
Sleeping
Sleeping
Timo commited on
Commit ·
cc0b700
1
Parent(s): 8c118f9
OK
Browse files- src/draft_model.py +5 -1
src/draft_model.py
CHANGED
|
@@ -30,8 +30,11 @@ class DraftModel:
|
|
| 30 |
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
| 31 |
)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
# ---- load network ---------------------------------------------------
|
| 34 |
-
|
| 35 |
|
| 36 |
self.net = MLP_CrossAttention(**cfg).to(self.device)
|
| 37 |
self.net.load_state_dict(weight_path, map_location=self.device)
|
|
@@ -42,6 +45,7 @@ class DraftModel:
|
|
| 42 |
hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
|
| 43 |
add_nontransformed=True
|
| 44 |
)
|
|
|
|
| 45 |
|
| 46 |
# --------------------------------------------------------------------- #
|
| 47 |
# Public API expected by streamlit_app.py #
|
|
|
|
| 30 |
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
| 31 |
)
|
| 32 |
|
| 33 |
+
with open(cfg_path, "r") as f:
|
| 34 |
+
cfg = json.load(f)
|
| 35 |
+
|
| 36 |
# ---- load network ---------------------------------------------------
|
| 37 |
+
|
| 38 |
|
| 39 |
self.net = MLP_CrossAttention(**cfg).to(self.device)
|
| 40 |
self.net.load_state_dict(weight_path, map_location=self.device)
|
|
|
|
| 45 |
hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
|
| 46 |
add_nontransformed=True
|
| 47 |
)
|
| 48 |
+
self.emb_size = next(iter(self.embed_dict.values())).shape[0]
|
| 49 |
|
| 50 |
# --------------------------------------------------------------------- #
|
| 51 |
# Public API expected by streamlit_app.py #
|