Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +3 -0
inference.py
CHANGED
|
@@ -420,6 +420,9 @@ class GenerativeInferenceModel:
|
|
| 420 |
for source_key, value in state_dict.items():
|
| 421 |
if source_key.startswith('module.model.'):
|
| 422 |
target_key = source_key[len('module.model.'):]
|
|
|
|
|
|
|
|
|
|
| 423 |
resnet_state_dict[target_key] = value
|
| 424 |
|
| 425 |
print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
|
|
|
|
| 420 |
for source_key, value in state_dict.items():
|
| 421 |
if source_key.startswith('module.model.'):
|
| 422 |
target_key = source_key[len('module.model.'):]
|
| 423 |
+
# Some ckpts have 'module.model.model.<...>'; remove the extra 'model.' too
|
| 424 |
+
if target_key.startswith('model.'):
|
| 425 |
+
target_key = target_key[len('model.'):]
|
| 426 |
resnet_state_dict[target_key] = value
|
| 427 |
|
| 428 |
print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
|