Upload modeling_conceptframemet.py with huggingface_hub
Browse files- modeling_conceptframemet.py +26 -20
modeling_conceptframemet.py
CHANGED
|
@@ -288,33 +288,39 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
|
|
| 288 |
}
|
| 289 |
|
| 290 |
@classmethod
|
| 291 |
-
def from_pretrained(cls, model_path, **kwargs):
|
| 292 |
"""Load model from pretrained checkpoint"""
|
| 293 |
# Load config
|
| 294 |
config_path = os.path.join(model_path, "config.json")
|
| 295 |
with open(config_path, 'r') as f:
|
| 296 |
config = json.load(f)
|
| 297 |
|
| 298 |
-
#
|
| 299 |
-
model = cls(**kwargs)
|
| 300 |
-
|
| 301 |
-
# Load weights
|
| 302 |
weights_path = os.path.join(model_path, "pytorch_model.bin")
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
return model
|
| 320 |
|
|
|
|
| 288 |
}
|
| 289 |
|
| 290 |
@classmethod
|
| 291 |
+
def from_pretrained(cls, model_path, load_source_from_checkpoint=True, **kwargs):
|
| 292 |
"""Load model from pretrained checkpoint"""
|
| 293 |
# Load config
|
| 294 |
config_path = os.path.join(model_path, "config.json")
|
| 295 |
with open(config_path, 'r') as f:
|
| 296 |
config = json.load(f)
|
| 297 |
|
| 298 |
+
# Load weights first to check what's in checkpoint
|
|
|
|
|
|
|
|
|
|
| 299 |
weights_path = os.path.join(model_path, "pytorch_model.bin")
|
| 300 |
+
state_dict = torch.load(weights_path, map_location='cpu')
|
| 301 |
+
|
| 302 |
+
# Check if checkpoint has source_qa_model
|
| 303 |
+
has_source_in_checkpoint = any(k.startswith('source_qa_model.') for k in state_dict.keys())
|
| 304 |
+
|
| 305 |
+
# Initialize model WITHOUT loading external models if they're in checkpoint
|
| 306 |
+
if has_source_in_checkpoint and load_source_from_checkpoint:
|
| 307 |
+
# Don't load source from HF, we'll load from checkpoint
|
| 308 |
+
model = cls(source_qa_model_name=None, **kwargs)
|
| 309 |
+
# Manually set has_source_predictor since weights are in checkpoint
|
| 310 |
+
model.has_source_predictor = True
|
| 311 |
+
else:
|
| 312 |
+
model = cls(**kwargs)
|
| 313 |
+
|
| 314 |
+
# Load ALL weights from checkpoint (including source_qa_model if present)
|
| 315 |
+
# Only filter out frame_qa_model if we don't have it
|
| 316 |
+
filtered_state_dict = {}
|
| 317 |
+
for key, value in state_dict.items():
|
| 318 |
+
# Only skip frame_qa_model weights if we don't have it
|
| 319 |
+
if key.startswith('frame_qa_model.') and not model.has_frame_predictor:
|
| 320 |
+
continue
|
| 321 |
+
filtered_state_dict[key] = value
|
| 322 |
+
|
| 323 |
+
model.load_state_dict(filtered_state_dict, strict=False)
|
| 324 |
|
| 325 |
return model
|
| 326 |
|