Upload modeling_conceptframemet.py with huggingface_hub
Browse files- modeling_conceptframemet.py +13 -1
modeling_conceptframemet.py
CHANGED
|
@@ -302,7 +302,19 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
|
|
| 302 |
weights_path = os.path.join(model_path, "pytorch_model.bin")
|
| 303 |
if os.path.exists(weights_path):
|
| 304 |
state_dict = torch.load(weights_path, map_location='cpu')
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
return model
|
| 308 |
|
|
|
|
| 302 |
weights_path = os.path.join(model_path, "pytorch_model.bin")
|
| 303 |
if os.path.exists(weights_path):
|
| 304 |
state_dict = torch.load(weights_path, map_location='cpu')
|
| 305 |
+
|
| 306 |
+
# Filter out source_qa_model and frame_qa_model weights if model doesn't have them
|
| 307 |
+
filtered_state_dict = {}
|
| 308 |
+
for key, value in state_dict.items():
|
| 309 |
+
# Skip source_qa_model weights if we don't have it
|
| 310 |
+
if key.startswith('source_qa_model.') and not model.has_source_predictor:
|
| 311 |
+
continue
|
| 312 |
+
# Skip frame_qa_model weights if we don't have it
|
| 313 |
+
if key.startswith('frame_qa_model.') and not model.has_frame_predictor:
|
| 314 |
+
continue
|
| 315 |
+
filtered_state_dict[key] = value
|
| 316 |
+
|
| 317 |
+
model.load_state_dict(filtered_state_dict, strict=False)
|
| 318 |
|
| 319 |
return model
|
| 320 |
|