nixie1981 committed on
Commit
7f84c5d
·
verified ·
1 Parent(s): 00272b3

Upload modeling_conceptframemet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_conceptframemet.py +21 -25
modeling_conceptframemet.py CHANGED
@@ -288,39 +288,35 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
288
  }
289
 
290
  @classmethod
291
- def from_pretrained(cls, model_path, load_source_from_checkpoint=True, **kwargs):
292
  """Load model from pretrained checkpoint"""
293
- # Load config
294
- config_path = os.path.join(model_path, "config.json")
295
- with open(config_path, 'r') as f:
296
- config = json.load(f)
297
-
298
  # Load weights first to check what's in checkpoint
299
  weights_path = os.path.join(model_path, "pytorch_model.bin")
300
  state_dict = torch.load(weights_path, map_location='cpu')
301
 
302
- # Check if checkpoint has source_qa_model
303
  has_source_in_checkpoint = any(k.startswith('source_qa_model.') for k in state_dict.keys())
 
 
 
 
 
 
 
 
 
 
304
 
305
- # Initialize model WITHOUT loading external models if they're in checkpoint
306
- if has_source_in_checkpoint and load_source_from_checkpoint:
307
- # Don't load source from HF, we'll load from checkpoint
308
- model = cls(source_qa_model_name=None, **kwargs)
309
- # Manually set has_source_predictor since weights are in checkpoint
310
  model.has_source_predictor = True
311
- else:
312
- model = cls(**kwargs)
313
-
314
- # Load ALL weights from checkpoint (including source_qa_model if present)
315
- # Only filter out frame_qa_model if we don't have it
316
- filtered_state_dict = {}
317
- for key, value in state_dict.items():
318
- # Only skip frame_qa_model weights if we don't have it
319
- if key.startswith('frame_qa_model.') and not model.has_frame_predictor:
320
- continue
321
- filtered_state_dict[key] = value
322
-
323
- model.load_state_dict(filtered_state_dict, strict=False)
324
 
325
  return model
326
 
 
288
  }
289
 
290
  @classmethod
291
+ def from_pretrained(cls, model_path, **kwargs):
292
  """Load model from pretrained checkpoint"""
 
 
 
 
 
293
  # Load weights first to check what's in checkpoint
294
  weights_path = os.path.join(model_path, "pytorch_model.bin")
295
  state_dict = torch.load(weights_path, map_location='cpu')
296
 
297
+ # Check what's in the checkpoint
298
  has_source_in_checkpoint = any(k.startswith('source_qa_model.') for k in state_dict.keys())
299
+ has_frame_in_checkpoint = any(k.startswith('frame_qa_model.') for k in state_dict.keys())
300
+
301
+ # Initialize model:
302
+ # - Download frame_qa_model (nixie1981/sem_frames) - NOT in checkpoint
303
+ # - Don't download source_qa_model - IS in checkpoint
304
+ model = cls(
305
+ frame_qa_model_name="nixie1981/sem_frames", # Download - needed for frames!
306
+ source_qa_model_name=None, # Don't download - in checkpoint
307
+ **kwargs
308
+ )
309
 
310
+ # Manually set source flag since weights are in checkpoint
311
+ if has_source_in_checkpoint:
 
 
 
312
  model.has_source_predictor = True
313
+
314
+ # Load ALL weights from checkpoint (including source_qa_model)
315
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
316
+
317
+ print(f"Loaded {len(state_dict)} weights from checkpoint")
318
+ if missing:
319
+ print(f"Missing {len(missing)} keys")
 
 
 
 
 
 
320
 
321
  return model
322