rexera committed on
Commit
4796948
·
1 Parent(s): 7e84e35

quick patch

Browse files
Files changed (1) hide show
  1. app.py +31 -44
app.py CHANGED
@@ -17,10 +17,11 @@ tokenizer = BertTokenizer.from_pretrained(config.roberta_model)
17
 
18
  # --- Model Loading ---
19
  def load_models():
20
- """Load all three models: MMRM, Text Baseline, Image Baseline."""
21
  models = {}
 
22
 
23
- # 1. MMRM
24
  # print("Loading MMRM...")
25
  # try:
26
  # mmrm = MMRM(config).to(device)
@@ -38,44 +39,21 @@ def load_models():
38
  # print(f"Error loading MMRM: {e}")
39
  # models['mmrm'] = None
40
  models['mmrm'] = None
41
-
42
- # 2. Textual Baseline (Fine-tuned RoBERTa)
43
- print("Loading Textual Baseline...")
44
  try:
45
- # Phase 1 uses fine_tuned=True structure
46
- lm_model = BaselineLanguageModel(config, fine_tuned=True).to(device)
47
- # ckpt_path = config.get_phase1_checkpoint_path()
48
- ckpt_path = 'rexera/mmrm-roberta'
49
- if os.path.exists(ckpt_path):
50
- checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
51
-
52
- # Phase 1 saves 'model_state_dict' (encoder) and 'decoder_state_dict' (decoder) separately
53
- # We need to map them to BaselineLanguageModel's structure: 'context_encoder' and 'classifier'
54
- new_state_dict = {}
55
-
56
- # Map Context Encoder
57
- if 'model_state_dict' in checkpoint:
58
- for k, v in checkpoint['model_state_dict'].items():
59
- new_state_dict[f'context_encoder.{k}'] = v
60
-
61
- # Map Decoder (Classifier)
62
- if 'decoder_state_dict' in checkpoint:
63
- for k, v in checkpoint['decoder_state_dict'].items():
64
- new_state_dict[f'classifier.{k}'] = v
65
-
66
- lm_model.load_state_dict(new_state_dict)
67
- lm_model.eval()
68
- models['text_baseline'] = lm_model
69
- print(f"Text Baseline loaded from {ckpt_path}")
70
- else:
71
- print(f"Text Baseline checkpoint not found at {ckpt_path}")
72
- models['text_baseline'] = None
73
  except Exception as e:
74
- print(f"Error loading Text Baseline: {e}")
75
  models['text_baseline'] = None
76
- models['text_baseline'] = None
77
 
78
- # 3. Visual Baseline (ResNet)
79
  # print("Loading Visual Baseline...")
80
  # try:
81
  # img_model = BaselineImageModel(config).to(device)
@@ -87,7 +65,7 @@ def load_models():
87
  # models['visual_baseline'] = img_model
88
  # print(f"Visual Baseline loaded from {ckpt_path}")
89
  # else:
90
- # print(f"Visual Baseline checkpoint not found at {ckpt_path} (Expected if not deployed yet)")
91
  # models['visual_baseline'] = None
92
  # except Exception as e:
93
  # print(f"Error loading Visual Baseline: {e}")
@@ -189,17 +167,26 @@ def run_inference(sample_idx):
189
  else:
190
  zs_res = [("Model not loaded", 0.0)]
191
 
192
- # 1. Text Baseline
193
  text_res = []
194
  if MODELS['text_baseline']:
195
  with torch.no_grad():
196
- logits = MODELS['text_baseline'](input_ids, attention_mask, mask_positions)
197
- # logits: [1, num_masks, vocab_size]
198
- # Take first mask
199
- mask_logits = logits[:, 0, :]
200
- text_res = format_top_k(mask_logits)
 
 
 
 
 
 
 
 
 
201
  else:
202
- text_res = [("Model not loaded (custom weight specific)", 0.0)]
203
 
204
  # 2. Visual Baseline
205
  visual_res = []
 
17
 
18
  # --- Model Loading ---
19
  def load_models():
20
+ """Load models. Textual baseline is now loaded from HF Hub."""
21
  models = {}
22
+ from transformers import AutoModelForMaskedLM
23
 
24
+ # 1. MMRM - Leave to None for now per user request
25
  # print("Loading MMRM...")
26
  # try:
27
  # mmrm = MMRM(config).to(device)
 
39
  # print(f"Error loading MMRM: {e}")
40
  # models['mmrm'] = None
41
  models['mmrm'] = None
42
+
43
+ # 2. Textual Baseline (Fine-tuned RoBERTa) - MIGRATED TO HF HUB
44
+ print("Loading Textual Baseline from HF Hub (rexera/mmrm-roberta)...")
45
  try:
46
+ # Since this is now in standard HF format (RobertaForMaskedLM)
47
+ repo_id = "rexera/mmrm-roberta"
48
+ lm_model = AutoModelForMaskedLM.from_pretrained(repo_id).to(device)
49
+ lm_model.eval()
50
+ models['text_baseline'] = lm_model
51
+ print(f"Textual Baseline loaded from {repo_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
+ print(f"Error loading Textual Baseline from HF: {e}")
54
  models['text_baseline'] = None
 
55
 
56
+ # 3. Visual Baseline (ResNet) - Leave to None for now
57
  # print("Loading Visual Baseline...")
58
  # try:
59
  # img_model = BaselineImageModel(config).to(device)
 
65
  # models['visual_baseline'] = img_model
66
  # print(f"Visual Baseline loaded from {ckpt_path}")
67
  # else:
68
+ # print(f"Visual Baseline checkpoint not found at {ckpt_path}")
69
  # models['visual_baseline'] = None
70
  # except Exception as e:
71
  # print(f"Error loading Visual Baseline: {e}")
 
167
  else:
168
  zs_res = [("Model not loaded", 0.0)]
169
 
170
+ # 1. Textual Baseline (Fine-tuned HF Model)
171
  text_res = []
172
  if MODELS['text_baseline']:
173
  with torch.no_grad():
174
+ # Standard HF model returns MaskedLMOutput
175
+ outputs = MODELS['text_baseline'](input_ids=input_ids, attention_mask=attention_mask)
176
+ all_logits = outputs.logits # [batch, seq_len, vocab_size]
177
+
178
+ # Extract logits at mask positions
179
+ # input_ids/mask_positions: [1, num_masks]
180
+ batch_size, num_masks = mask_positions.shape
181
+ mask_logits = torch.gather(
182
+ all_logits, 1, mask_positions.unsqueeze(-1).expand(-1, -1, all_logits.size(-1))
183
+ ) # [batch, num_masks, vocab_size]
184
+
185
+ # Take first mask for display
186
+ first_mask_logits = mask_logits[:, 0, :]
187
+ text_res = format_top_k(first_mask_logits)
188
  else:
189
+ text_res = [("Model not loaded (HF Migration)", 0.0)]
190
 
191
  # 2. Visual Baseline
192
  visual_res = []