nixie1981 committed on
Commit
00272b3
·
verified ·
1 Parent(s): 735d3cb

Upload modeling_conceptframemet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_conceptframemet.py +26 -20
modeling_conceptframemet.py CHANGED
@@ -288,33 +288,39 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
288
  }
289
 
290
  @classmethod
291
- def from_pretrained(cls, model_path, **kwargs):
292
  """Load model from pretrained checkpoint"""
293
  # Load config
294
  config_path = os.path.join(model_path, "config.json")
295
  with open(config_path, 'r') as f:
296
  config = json.load(f)
297
 
298
- # Initialize model
299
- model = cls(**kwargs)
300
-
301
- # Load weights
302
  weights_path = os.path.join(model_path, "pytorch_model.bin")
303
- if os.path.exists(weights_path):
304
- state_dict = torch.load(weights_path, map_location='cpu')
305
-
306
- # Filter out source_qa_model and frame_qa_model weights if model doesn't have them
307
- filtered_state_dict = {}
308
- for key, value in state_dict.items():
309
- # Skip source_qa_model weights if we don't have it
310
- if key.startswith('source_qa_model.') and not model.has_source_predictor:
311
- continue
312
- # Skip frame_qa_model weights if we don't have it
313
- if key.startswith('frame_qa_model.') and not model.has_frame_predictor:
314
- continue
315
- filtered_state_dict[key] = value
316
-
317
- model.load_state_dict(filtered_state_dict, strict=False)
 
 
 
 
 
 
 
 
 
318
 
319
  return model
320
 
 
288
  }
289
 
290
  @classmethod
291
+ def from_pretrained(cls, model_path, load_source_from_checkpoint=True, **kwargs):
292
  """Load model from pretrained checkpoint"""
293
  # Load config
294
  config_path = os.path.join(model_path, "config.json")
295
  with open(config_path, 'r') as f:
296
  config = json.load(f)
297
 
298
+ # Load weights first to check what's in checkpoint
 
 
 
299
  weights_path = os.path.join(model_path, "pytorch_model.bin")
300
+ state_dict = torch.load(weights_path, map_location='cpu')
301
+
302
+ # Check if checkpoint has source_qa_model
303
+ has_source_in_checkpoint = any(k.startswith('source_qa_model.') for k in state_dict.keys())
304
+
305
+ # Initialize model WITHOUT loading external models if they're in checkpoint
306
+ if has_source_in_checkpoint and load_source_from_checkpoint:
307
+ # Don't load source from HF, we'll load from checkpoint
308
+ model = cls(source_qa_model_name=None, **kwargs)
309
+ # Manually set has_source_predictor since weights are in checkpoint
310
+ model.has_source_predictor = True
311
+ else:
312
+ model = cls(**kwargs)
313
+
314
+ # Load ALL weights from checkpoint (including source_qa_model if present)
315
+ # Only filter out frame_qa_model if we don't have it
316
+ filtered_state_dict = {}
317
+ for key, value in state_dict.items():
318
+ # Only skip frame_qa_model weights if we don't have it
319
+ if key.startswith('frame_qa_model.') and not model.has_frame_predictor:
320
+ continue
321
+ filtered_state_dict[key] = value
322
+
323
+ model.load_state_dict(filtered_state_dict, strict=False)
324
 
325
  return model
326