nixie1981 commited on
Commit
735d3cb
·
verified ·
1 Parent(s): 850ff6f

Upload modeling_conceptframemet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_conceptframemet.py +13 -1
modeling_conceptframemet.py CHANGED
@@ -302,7 +302,19 @@ class ConceptFrameMetForMetaphorDetection(nn.Module):
302
  weights_path = os.path.join(model_path, "pytorch_model.bin")
303
  if os.path.exists(weights_path):
304
  state_dict = torch.load(weights_path, map_location='cpu')
305
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  return model
308
 
 
302
  weights_path = os.path.join(model_path, "pytorch_model.bin")
303
  if os.path.exists(weights_path):
304
  state_dict = torch.load(weights_path, map_location='cpu')
305
+
306
+ # Filter out source_qa_model and frame_qa_model weights if model doesn't have them
307
+ filtered_state_dict = {}
308
+ for key, value in state_dict.items():
309
+ # Skip source_qa_model weights if we don't have it
310
+ if key.startswith('source_qa_model.') and not model.has_source_predictor:
311
+ continue
312
+ # Skip frame_qa_model weights if we don't have it
313
+ if key.startswith('frame_qa_model.') and not model.has_frame_predictor:
314
+ continue
315
+ filtered_state_dict[key] = value
316
+
317
+ model.load_state_dict(filtered_state_dict, strict=False)
318
 
319
  return model
320