# MMRM demo — app.py (Hugging Face Space: rexera/MMRM)
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer
from PIL import Image
import numpy as np
import os
# Local imports (demo directory structure)
from config import Config
from models.mmrm import MMRM, BaselineImageModel, BaselineLanguageModel
from evaluation.evaluate_real import RealWorldDataset
# --- Global Setup ---
config = Config()
device = torch.device(config.device if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained(config.roberta_model)
# --- Model Loading ---
def load_models():
"""Load models. Textual baseline is now loaded from HF Hub."""
models = {}
from transformers import AutoModelForMaskedLM
# 1. MMRM - Leave to None for now per user request
# print("Loading MMRM...")
# try:
# mmrm = MMRM(config).to(device)
# ckpt_path = config.get_phase2_checkpoint_path() # Defaults to phase2_mmrm_best.pt
# if os.path.exists(ckpt_path):
# checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
# mmrm.load_state_dict(checkpoint['model_state_dict'])
# mmrm.eval()
# models['mmrm'] = mmrm
# print(f"MMRM loaded from {ckpt_path}")
# else:
# print(f"MMRM checkpoint not found at {ckpt_path}")
# models['mmrm'] = None
# except Exception as e:
# print(f"Error loading MMRM: {e}")
# models['mmrm'] = None
models['mmrm'] = None
# 2. Textual Baseline (Fine-tuned RoBERTa) - MIGRATED TO HF HUB
print("Loading Textual Baseline from HF Hub (rexera/mmrm-roberta)...")
try:
# Since this is now in standard HF format (RobertaForMaskedLM)
repo_id = "rexera/mmrm-roberta"
lm_model = AutoModelForMaskedLM.from_pretrained(repo_id).to(device)
lm_model.eval()
models['text_baseline'] = lm_model
print(f"Textual Baseline loaded from {repo_id}")
except Exception as e:
print(f"Error loading Textual Baseline from HF: {e}")
models['text_baseline'] = None
# 3. Visual Baseline (ResNet) - Leave to None for now
# print("Loading Visual Baseline...")
# try:
# img_model = BaselineImageModel(config).to(device)
# ckpt_path = config.get_baseline_checkpoint_path('img')
# if os.path.exists(ckpt_path):
# checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
# img_model.load_state_dict(checkpoint['model_state_dict'])
# img_model.eval()
# models['visual_baseline'] = img_model
# print(f"Visual Baseline loaded from {ckpt_path}")
# else:
# print(f"Visual Baseline checkpoint not found at {ckpt_path}")
# models['visual_baseline'] = None
# except Exception as e:
# print(f"Error loading Visual Baseline: {e}")
# models['visual_baseline'] = None
models['visual_baseline'] = None
# 4. Zero-shot Baseline (Pre-trained GuwenBERT) - DIRECT FROM HF
print("Loading Zero-shot Baseline (GuwenBERT)...")
try:
# fine_tuned=False loads the standard MLM head from huggingface
zs_model = BaselineLanguageModel(config, fine_tuned=False).to(device)
zs_model.eval()
models['zero_shot'] = zs_model
print(f"Zero-shot Baseline loaded (Pre-trained {config.roberta_model})")
except Exception as e:
print(f"Error loading Zero-shot Baseline: {e}")
models['zero_shot'] = None
return models
# Load models globally (or lazily if needed, but global is fine for demo script)
MODELS = load_models()
# --- Data Loading ---
print("Loading Real World Data...")
try:
dataset = RealWorldDataset(config, tokenizer)
# Create a mapping for dropdown: "Index: Context preview..."
sample_options = []
for i in range(len(dataset)):
# Provide a snippet of context for the dropdown
context = dataset.samples[i]['context']
# Show first 30 chars
label = f"Sample {i}: {context[:40]}..."
sample_options.append((label, i))
except Exception as e:
print(f"Error loading dataset: {e}")
dataset = None
sample_options = []
# --- Helper Functions ---
def tensor_to_pil(tensor):
"""Convert tensor (1, 64, 64) float [0,1] to PIL Image."""
# tensor: [1, H, W]
if tensor.dim() == 4:
tensor = tensor.squeeze(0) # Remove batch dim if present
if tensor.dim() == 3:
tensor = tensor.squeeze(0) # Remove channel dim if 1 channel
arr = tensor.cpu().detach().numpy()
arr = (arr * 255).astype(np.uint8)
return Image.fromarray(arr, mode='L')
def format_top_k(logits, top_k=20):
"""Return list of (token, probability) tuples."""
probs = F.softmax(logits, dim=-1)
top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
results = []
for p, idx in zip(top_probs[0], top_indices[0]):
token = tokenizer.decode([idx.item()])
results.append((token, float(p.item())))
return results
def run_inference(sample_idx):
if dataset is None:
return None, "Dataset not loaded", None, {}
# Load sample
sample_idx = int(sample_idx) # ensure int
batch = dataset[sample_idx]
# Prepare inputs
input_ids = batch['input_ids'].unsqueeze(0).to(device)
attention_mask = batch['attention_mask'].unsqueeze(0).to(device)
mask_positions = batch['mask_positions'].unsqueeze(0).to(device)
damaged_images = batch['damaged_images'].unsqueeze(0).to(device)
# damaged_images shape: [1, num_masks, 1, 64, 64]
# Get original raw images for display
first_image_tensor = damaged_images[0, 0] # [1, 64, 64]
input_display_image = tensor_to_pil(first_image_tensor)
context_text = dataset.samples[sample_idx]['context']
ground_truth_labels = dataset.samples[sample_idx]['labels']
ground_truth_text = " ".join(ground_truth_labels)
# --- Inference ---
# 0. Zero-shot Baseline
zs_res = []
if MODELS['zero_shot']:
with torch.no_grad():
logits = MODELS['zero_shot'](input_ids, attention_mask, mask_positions)
# logits: [1, num_masks, vocab_size]
mask_logits = logits[:, 0, :]
zs_res = format_top_k(mask_logits)
else:
zs_res = [("Model not loaded", 0.0)]
# 1. Textual Baseline (Fine-tuned HF Model)
text_res = []
if MODELS['text_baseline']:
with torch.no_grad():
# Standard HF model returns MaskedLMOutput
outputs = MODELS['text_baseline'](input_ids=input_ids, attention_mask=attention_mask)
all_logits = outputs.logits # [batch, seq_len, vocab_size]
# Extract logits at mask positions
# input_ids/mask_positions: [1, num_masks]
batch_size, num_masks = mask_positions.shape
mask_logits = torch.gather(
all_logits, 1, mask_positions.unsqueeze(-1).expand(-1, -1, all_logits.size(-1))
) # [batch, num_masks, vocab_size]
# Take first mask for display
first_mask_logits = mask_logits[:, 0, :]
text_res = format_top_k(first_mask_logits)
else:
text_res = [("Model not loaded (HF Migration)", 0.0)]
# 2. Visual Baseline
visual_res = []
if MODELS['visual_baseline']:
with torch.no_grad():
logits = MODELS['visual_baseline'](damaged_images)
mask_logits = logits[:, 0, :]
visual_res = format_top_k(mask_logits)
else:
visual_res = [("Model not loaded (custom weight specific)", 0.0)]
# 3. MMRM
mmrm_res = []
restored_pil = None
if MODELS['mmrm']:
with torch.no_grad():
logits, restored_imgs = MODELS['mmrm'](
input_ids, attention_mask, mask_positions, damaged_images
)
# logits: [1, num_masks, vocab_size]
mask_logits = logits[:, 0, :]
mmrm_res = format_top_k(mask_logits)
# Restored Image (Intermediate)
# restored_imgs: [1, num_masks, 1, 64, 64]
restored_tensor = restored_imgs[0, 0]
restored_pil = tensor_to_pil(restored_tensor)
else:
mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]
# Format raw results into a dictionary for State
raw_results = {
'zs': zs_res,
'text': text_res,
'visual': visual_res,
'mmrm': mmrm_res
}
return (
input_display_image,
f"Context: {context_text}\nGround Truth: {ground_truth_text}",
restored_pil,
raw_results
)
# --- Gradio UI ---
with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="sm")) as demo:
gr.Markdown("# MMRM: Multimodal Multitask Restoring Model")
gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")
with gr.Row():
# --- Left Column: Inputs ---
with gr.Column(scale=1):
gr.Markdown("### Input Selection")
with gr.Row():
sample_dropdown = gr.Dropdown(
choices=[x[1] for x in sample_options],
type="value",
label="Select Sample",
container=False,
scale=3
)
sample_dropdown.choices = sample_options
run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
with gr.Row():
top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Predictions")
with gr.Row():
input_image = gr.Image(label="Damaged Input", type="pil", height=250)
with gr.Row():
input_text = gr.Textbox(label="Context Info", lines=5)
# --- Right Column: Outputs ---
with gr.Column(scale=2):
gr.Markdown("### Model Predictions")
with gr.Row():
with gr.Column(min_width=80):
zs_output = gr.Label(num_top_classes=20, label="Zero-shot")
with gr.Column(min_width=80):
text_output = gr.Label(num_top_classes=20, label="Textual")
with gr.Column(min_width=80):
visual_output = gr.Label(num_top_classes=20, label="Visual")
with gr.Column(min_width=80):
mmrm_output = gr.Label(num_top_classes=20, label="MMRM")
with gr.Row():
with gr.Column():
gr.Markdown("### Visual Restoration")
restored_output = gr.Image(label="MMRM Output", type="pil", height=250)
# State to hold raw top-20 results for all models
# Structure: {"zs": [...], "text": [...], "visual": [...], "mmrm": [...]}
raw_results_state = gr.State()
def update_views(raw_results, k):
if not raw_results:
return {}, {}, {}, {}
k = int(k)
def slice_res(key):
# Take top k from list of tuples
full_list = raw_results.get(key, [])
return {term: score for term, score in full_list[:k]}
return (
slice_res('zs'),
slice_res('text'),
slice_res('visual'),
slice_res('mmrm')
)
# Event Chain
# 1. Run inference -> updates State and Images/Text
run_event = run_btn.click(
fn=run_inference,
inputs=[sample_dropdown],
outputs=[input_image, input_text, restored_output, raw_results_state]
)
# 2. Update Labels based on State and Slider (triggered by Run success OR Slider change)
run_event.success(
fn=update_views,
inputs=[raw_results_state, top_k_slider],
outputs=[zs_output, text_output, visual_output, mmrm_output]
)
top_k_slider.change(
fn=update_views,
inputs=[raw_results_state, top_k_slider],
outputs=[zs_output, text_output, visual_output, mmrm_output]
)
if __name__ == "__main__":
demo.launch()