ttoosi commited on
Commit
0b968b4
·
verified ·
1 Parent(s): f56511f

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -0
inference.py CHANGED
@@ -420,6 +420,9 @@ class GenerativeInferenceModel:
420
  for source_key, value in state_dict.items():
421
  if source_key.startswith('module.model.'):
422
  target_key = source_key[len('module.model.'):]
 
 
 
423
  resnet_state_dict[target_key] = value
424
 
425
  print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
 
420
  for source_key, value in state_dict.items():
421
  if source_key.startswith('module.model.'):
422
  target_key = source_key[len('module.model.'):]
423
+ # Some ckpts have 'module.model.model.<...>'; remove the extra 'model.' too
424
+ if target_key.startswith('model.'):
425
+ target_key = target_key[len('model.'):]
426
  resnet_state_dict[target_key] = value
427
 
428
  print(f"Extracted {len(resnet_state_dict)} parameters from module.model")