quick patch
Browse files
app.py
CHANGED
|
@@ -17,10 +17,11 @@ tokenizer = BertTokenizer.from_pretrained(config.roberta_model)
|
|
| 17 |
|
| 18 |
# --- Model Loading ---
|
| 19 |
def load_models():
|
| 20 |
-
"""Load
|
| 21 |
models = {}
|
|
|
|
| 22 |
|
| 23 |
-
# 1. MMRM
|
| 24 |
# print("Loading MMRM...")
|
| 25 |
# try:
|
| 26 |
# mmrm = MMRM(config).to(device)
|
|
@@ -38,44 +39,21 @@ def load_models():
|
|
| 38 |
# print(f"Error loading MMRM: {e}")
|
| 39 |
# models['mmrm'] = None
|
| 40 |
models['mmrm'] = None
|
| 41 |
-
|
| 42 |
-
# 2. Textual Baseline (Fine-tuned RoBERTa)
|
| 43 |
-
print("Loading Textual Baseline...")
|
| 44 |
try:
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
# Phase 1 saves 'model_state_dict' (encoder) and 'decoder_state_dict' (decoder) separately
|
| 53 |
-
# We need to map them to BaselineLanguageModel's structure: 'context_encoder' and 'classifier'
|
| 54 |
-
new_state_dict = {}
|
| 55 |
-
|
| 56 |
-
# Map Context Encoder
|
| 57 |
-
if 'model_state_dict' in checkpoint:
|
| 58 |
-
for k, v in checkpoint['model_state_dict'].items():
|
| 59 |
-
new_state_dict[f'context_encoder.{k}'] = v
|
| 60 |
-
|
| 61 |
-
# Map Decoder (Classifier)
|
| 62 |
-
if 'decoder_state_dict' in checkpoint:
|
| 63 |
-
for k, v in checkpoint['decoder_state_dict'].items():
|
| 64 |
-
new_state_dict[f'classifier.{k}'] = v
|
| 65 |
-
|
| 66 |
-
lm_model.load_state_dict(new_state_dict)
|
| 67 |
-
lm_model.eval()
|
| 68 |
-
models['text_baseline'] = lm_model
|
| 69 |
-
print(f"Text Baseline loaded from {ckpt_path}")
|
| 70 |
-
else:
|
| 71 |
-
print(f"Text Baseline checkpoint not found at {ckpt_path}")
|
| 72 |
-
models['text_baseline'] = None
|
| 73 |
except Exception as e:
|
| 74 |
-
print(f"Error loading
|
| 75 |
models['text_baseline'] = None
|
| 76 |
-
models['text_baseline'] = None
|
| 77 |
|
| 78 |
-
# 3. Visual Baseline (ResNet)
|
| 79 |
# print("Loading Visual Baseline...")
|
| 80 |
# try:
|
| 81 |
# img_model = BaselineImageModel(config).to(device)
|
|
@@ -87,7 +65,7 @@ def load_models():
|
|
| 87 |
# models['visual_baseline'] = img_model
|
| 88 |
# print(f"Visual Baseline loaded from {ckpt_path}")
|
| 89 |
# else:
|
| 90 |
-
# print(f"Visual Baseline checkpoint not found at {ckpt_path}
|
| 91 |
# models['visual_baseline'] = None
|
| 92 |
# except Exception as e:
|
| 93 |
# print(f"Error loading Visual Baseline: {e}")
|
|
@@ -189,17 +167,26 @@ def run_inference(sample_idx):
|
|
| 189 |
else:
|
| 190 |
zs_res = [("Model not loaded", 0.0)]
|
| 191 |
|
| 192 |
-
# 1.
|
| 193 |
text_res = []
|
| 194 |
if MODELS['text_baseline']:
|
| 195 |
with torch.no_grad():
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
#
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
else:
|
| 202 |
-
text_res = [("Model not loaded (
|
| 203 |
|
| 204 |
# 2. Visual Baseline
|
| 205 |
visual_res = []
|
|
|
|
| 17 |
|
| 18 |
# --- Model Loading ---
|
| 19 |
def load_models():
|
| 20 |
+
"""Load models. Textual baseline is now loaded from HF Hub."""
|
| 21 |
models = {}
|
| 22 |
+
from transformers import AutoModelForMaskedLM
|
| 23 |
|
| 24 |
+
# 1. MMRM - Leave to None for now per user request
|
| 25 |
# print("Loading MMRM...")
|
| 26 |
# try:
|
| 27 |
# mmrm = MMRM(config).to(device)
|
|
|
|
| 39 |
# print(f"Error loading MMRM: {e}")
|
| 40 |
# models['mmrm'] = None
|
| 41 |
models['mmrm'] = None
|
| 42 |
+
|
| 43 |
+
# 2. Textual Baseline (Fine-tuned RoBERTa) - MIGRATED TO HF HUB
|
| 44 |
+
print("Loading Textual Baseline from HF Hub (rexera/mmrm-roberta)...")
|
| 45 |
try:
|
| 46 |
+
# Since this is now in standard HF format (RobertaForMaskedLM)
|
| 47 |
+
repo_id = "rexera/mmrm-roberta"
|
| 48 |
+
lm_model = AutoModelForMaskedLM.from_pretrained(repo_id).to(device)
|
| 49 |
+
lm_model.eval()
|
| 50 |
+
models['text_baseline'] = lm_model
|
| 51 |
+
print(f"Textual Baseline loaded from {repo_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
except Exception as e:
|
| 53 |
+
print(f"Error loading Textual Baseline from HF: {e}")
|
| 54 |
models['text_baseline'] = None
|
|
|
|
| 55 |
|
| 56 |
+
# 3. Visual Baseline (ResNet) - Leave to None for now
|
| 57 |
# print("Loading Visual Baseline...")
|
| 58 |
# try:
|
| 59 |
# img_model = BaselineImageModel(config).to(device)
|
|
|
|
| 65 |
# models['visual_baseline'] = img_model
|
| 66 |
# print(f"Visual Baseline loaded from {ckpt_path}")
|
| 67 |
# else:
|
| 68 |
+
# print(f"Visual Baseline checkpoint not found at {ckpt_path}")
|
| 69 |
# models['visual_baseline'] = None
|
| 70 |
# except Exception as e:
|
| 71 |
# print(f"Error loading Visual Baseline: {e}")
|
|
|
|
| 167 |
else:
|
| 168 |
zs_res = [("Model not loaded", 0.0)]
|
| 169 |
|
| 170 |
+
# 1. Textual Baseline (Fine-tuned HF Model)
|
| 171 |
text_res = []
|
| 172 |
if MODELS['text_baseline']:
|
| 173 |
with torch.no_grad():
|
| 174 |
+
# Standard HF model returns MaskedLMOutput
|
| 175 |
+
outputs = MODELS['text_baseline'](input_ids=input_ids, attention_mask=attention_mask)
|
| 176 |
+
all_logits = outputs.logits # [batch, seq_len, vocab_size]
|
| 177 |
+
|
| 178 |
+
# Extract logits at mask positions
|
| 179 |
+
# input_ids/mask_positions: [1, num_masks]
|
| 180 |
+
batch_size, num_masks = mask_positions.shape
|
| 181 |
+
mask_logits = torch.gather(
|
| 182 |
+
all_logits, 1, mask_positions.unsqueeze(-1).expand(-1, -1, all_logits.size(-1))
|
| 183 |
+
) # [batch, num_masks, vocab_size]
|
| 184 |
+
|
| 185 |
+
# Take first mask for display
|
| 186 |
+
first_mask_logits = mask_logits[:, 0, :]
|
| 187 |
+
text_res = format_top_k(first_mask_logits)
|
| 188 |
else:
|
| 189 |
+
text_res = [("Model not loaded (HF Migration)", 0.0)]
|
| 190 |
|
| 191 |
# 2. Visual Baseline
|
| 192 |
visual_res = []
|