Spaces:
Running
Running
ynuozhang
commited on
Commit
·
5ebc1ff
1
Parent(s):
6c7e97c
update models
Browse files- inference.py +24 -2
inference.py
CHANGED
|
@@ -238,8 +238,30 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.devic
|
|
| 238 |
model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
|
| 239 |
layers=int(params["layers"]), dropout=dropout)
|
| 240 |
elif model_name == "transformer":
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
else:
|
| 244 |
raise ValueError(f"Unknown NN model_name={model_name}")
|
| 245 |
|
|
|
|
| 238 |
model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
|
| 239 |
layers=int(params["layers"]), dropout=dropout)
|
| 240 |
elif model_name == "transformer":
|
| 241 |
+
print(
|
| 242 |
+
f"[LOAD] {prop_key}/{mode} transformer params keys:",
|
| 243 |
+
list(params.keys())
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
d_model = (
|
| 247 |
+
params.get("d_model")
|
| 248 |
+
or params.get("hidden")
|
| 249 |
+
or params.get("hidden_dim")
|
| 250 |
+
)
|
| 251 |
+
if d_model is None:
|
| 252 |
+
raise KeyError(
|
| 253 |
+
f"Transformer checkpoint missing d_model/hidden. "
|
| 254 |
+
f"Available keys: {list(params.keys())}"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
model = TransformerHead(
|
| 258 |
+
in_dim=in_dim,
|
| 259 |
+
d_model=int(d_model),
|
| 260 |
+
nhead=int(params["nhead"]),
|
| 261 |
+
layers=int(params["layers"]),
|
| 262 |
+
ff=int(params.get("ff", 4 * int(d_model))),
|
| 263 |
+
dropout=dropout
|
| 264 |
+
)
|
| 265 |
else:
|
| 266 |
raise ValueError(f"Unknown NN model_name={model_name}")
|
| 267 |
|