ynuozhang committed on
Commit
5ebc1ff
·
1 Parent(s): 6c7e97c

update models

Browse files
Files changed (1) hide show
  1. inference.py +24 -2
inference.py CHANGED
@@ -238,8 +238,30 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.devic
238
  model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
239
  layers=int(params["layers"]), dropout=dropout)
240
  elif model_name == "transformer":
241
- model = TransformerHead(in_dim=in_dim, d_model=int(params["d_model"]), nhead=int(params["nhead"]),
242
- layers=int(params["layers"]), ff=int(params["ff"]), dropout=dropout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  else:
244
  raise ValueError(f"Unknown NN model_name={model_name}")
245
 
 
238
  model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
239
  layers=int(params["layers"]), dropout=dropout)
240
  elif model_name == "transformer":
241
+ print(
242
+ f"[LOAD] {prop_key}/{mode} transformer params keys:",
243
+ list(params.keys())
244
+ )
245
+
246
+ d_model = (
247
+ params.get("d_model")
248
+ or params.get("hidden")
249
+ or params.get("hidden_dim")
250
+ )
251
+ if d_model is None:
252
+ raise KeyError(
253
+ f"Transformer checkpoint missing d_model/hidden. "
254
+ f"Available keys: {list(params.keys())}"
255
+ )
256
+
257
+ model = TransformerHead(
258
+ in_dim=in_dim,
259
+ d_model=int(d_model),
260
+ nhead=int(params["nhead"]),
261
+ layers=int(params["layers"]),
262
+ ff=int(params.get("ff", 4 * int(d_model))),
263
+ dropout=dropout
264
+ )
265
  else:
266
  raise ValueError(f"Unknown NN model_name={model_name}")
267