rexera committed on
Commit
7e84e35
·
1 Parent(s): 87224ba
Files changed (1) hide show
  1. app.py +33 -32
app.py CHANGED
@@ -40,38 +40,39 @@ def load_models():
40
  models['mmrm'] = None
41
 
42
  # 2. Textual Baseline (Fine-tuned RoBERTa)
43
- # print("Loading Textual Baseline...")
44
- # try:
45
- # # Phase 1 uses fine_tuned=True structure
46
- # lm_model = BaselineLanguageModel(config, fine_tuned=True).to(device)
47
- # ckpt_path = config.get_phase1_checkpoint_path()
48
- # if os.path.exists(ckpt_path):
49
- # checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
50
- #
51
- # # Phase 1 saves 'model_state_dict' (encoder) and 'decoder_state_dict' (decoder) separately
52
- # # We need to map them to BaselineLanguageModel's structure: 'context_encoder' and 'classifier'
53
- # new_state_dict = {}
54
- #
55
- # # Map Context Encoder
56
- # if 'model_state_dict' in checkpoint:
57
- # for k, v in checkpoint['model_state_dict'].items():
58
- # new_state_dict[f'context_encoder.{k}'] = v
59
- #
60
- # # Map Decoder (Classifier)
61
- # if 'decoder_state_dict' in checkpoint:
62
- # for k, v in checkpoint['decoder_state_dict'].items():
63
- # new_state_dict[f'classifier.{k}'] = v
64
- #
65
- # lm_model.load_state_dict(new_state_dict)
66
- # lm_model.eval()
67
- # models['text_baseline'] = lm_model
68
- # print(f"Text Baseline loaded from {ckpt_path}")
69
- # else:
70
- # print(f"Text Baseline checkpoint not found at {ckpt_path}")
71
- # models['text_baseline'] = None
72
- # except Exception as e:
73
- # print(f"Error loading Text Baseline: {e}")
74
- # models['text_baseline'] = None
 
75
  models['text_baseline'] = None
76
 
77
  # 3. Visual Baseline (ResNet)
 
40
  models['mmrm'] = None
41
 
42
  # 2. Textual Baseline (Fine-tuned RoBERTa)
43
+ print("Loading Textual Baseline...")
44
+ try:
45
+ # Phase 1 uses fine_tuned=True structure
46
+ lm_model = BaselineLanguageModel(config, fine_tuned=True).to(device)
47
+ # ckpt_path = config.get_phase1_checkpoint_path()
48
+ ckpt_path = 'rexera/mmrm-roberta'
49
+ if os.path.exists(ckpt_path):
50
+ checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
51
+
52
+ # Phase 1 saves 'model_state_dict' (encoder) and 'decoder_state_dict' (decoder) separately
53
+ # We need to map them to BaselineLanguageModel's structure: 'context_encoder' and 'classifier'
54
+ new_state_dict = {}
55
+
56
+ # Map Context Encoder
57
+ if 'model_state_dict' in checkpoint:
58
+ for k, v in checkpoint['model_state_dict'].items():
59
+ new_state_dict[f'context_encoder.{k}'] = v
60
+
61
+ # Map Decoder (Classifier)
62
+ if 'decoder_state_dict' in checkpoint:
63
+ for k, v in checkpoint['decoder_state_dict'].items():
64
+ new_state_dict[f'classifier.{k}'] = v
65
+
66
+ lm_model.load_state_dict(new_state_dict)
67
+ lm_model.eval()
68
+ models['text_baseline'] = lm_model
69
+ print(f"Text Baseline loaded from {ckpt_path}")
70
+ else:
71
+ print(f"Text Baseline checkpoint not found at {ckpt_path}")
72
+ models['text_baseline'] = None
73
+ except Exception as e:
74
+ print(f"Error loading Text Baseline: {e}")
75
+ models['text_baseline'] = None
76
  models['text_baseline'] = None
77
 
78
  # 3. Visual Baseline (ResNet)