import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer
from PIL import Image
import numpy as np
import os

# Local imports (demo directory structure)
from config import Config
from models.mmrm import MMRM, BaselineImageModel, BaselineLanguageModel
from evaluation.evaluate_real import RealWorldDataset

# --- Global Setup ---
config = Config()
device = torch.device(config.device if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained(config.roberta_model)


# --- Model Loading ---
def load_models():
    """Load all demo models into a dict keyed by model role.

    Keys: 'mmrm', 'text_baseline', 'visual_baseline', 'zero_shot'.
    A value of None means the model is unavailable — either intentionally
    disabled for this demo build (mmrm, visual_baseline require custom
    local checkpoints) or loading failed. Downstream code must handle
    None entries gracefully.
    """
    models = {}
    from transformers import AutoModelForMaskedLM

    # 1. MMRM — intentionally disabled for this demo build (needs a custom
    #    local checkpoint). Restore the loading code from git history to
    #    re-enable.
    models['mmrm'] = None

    # 2. Textual Baseline (fine-tuned RoBERTa), migrated to the HF Hub in
    #    standard RobertaForMaskedLM format.
    print("Loading Textual Baseline from HF Hub (rexera/mmrm-roberta)...")
    try:
        repo_id = "rexera/mmrm-roberta"
        lm_model = AutoModelForMaskedLM.from_pretrained(repo_id).to(device)
        lm_model.eval()
        models['text_baseline'] = lm_model
        print(f"Textual Baseline loaded from {repo_id}")
    except Exception as e:
        # Best-effort: the demo still runs with this model marked unavailable.
        print(f"Error loading Textual Baseline from HF: {e}")
        models['text_baseline'] = None

    # 3. Visual Baseline (ResNet) — intentionally disabled for this demo
    #    build (needs a custom local checkpoint); see git history to re-enable.
    models['visual_baseline'] = None

    # 4. Zero-shot Baseline: pre-trained GuwenBERT, pulled directly from HF.
    print("Loading Zero-shot Baseline (GuwenBERT)...")
    try:
        # fine_tuned=False loads the standard MLM head from huggingface
        zs_model = BaselineLanguageModel(config, fine_tuned=False).to(device)
        zs_model.eval()
        models['zero_shot'] = zs_model
        print(f"Zero-shot Baseline loaded (Pre-trained {config.roberta_model})")
    except Exception as e:
        print(f"Error loading Zero-shot Baseline: {e}")
        models['zero_shot'] = None

    return models


# Load models globally (or lazily if needed, but global is fine for demo script)
MODELS = load_models()

# --- Data Loading ---
print("Loading Real World Data...")
try:
    dataset = RealWorldDataset(config, tokenizer)
    # Dropdown entries as (label, value) pairs: the label previews the first
    # 40 chars of the sample context, the value is the sample index.
    sample_options = [
        (f"Sample {i}: {dataset.samples[i]['context'][:40]}...", i)
        for i in range(len(dataset))
    ]
except Exception as e:
    print(f"Error loading dataset: {e}")
    dataset = None
    sample_options = []
# --- Helper Functions ---
def tensor_to_pil(tensor):
    """Convert an image tensor of floats in [0, 1] to a grayscale PIL Image.

    Accepts [1, 1, H, W], [1, H, W] or [H, W]; singleton batch/channel
    dims are squeezed away before conversion.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)  # drop batch dim if present
    if tensor.dim() == 3:
        tensor = tensor.squeeze(0)  # drop channel dim if single channel
    arr = tensor.cpu().detach().numpy()
    arr = (arr * 255).astype(np.uint8)
    return Image.fromarray(arr, mode='L')


def format_top_k(logits, top_k=20):
    """Return the top_k predictions as a list of (token, probability) tuples.

    logits: [1, vocab_size] tensor; tokens are decoded with the
    module-level tokenizer.
    """
    probs = F.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
    return [
        (tokenizer.decode([idx.item()]), float(p.item()))
        for p, idx in zip(top_probs[0], top_indices[0])
    ]


def run_inference(sample_idx):
    """Run every available model on the selected dataset sample.

    Returns a 4-tuple:
      (damaged-input PIL image,
       context + ground-truth text,
       restored PIL image or None,
       raw_results dict of per-model top-20 (token, prob) lists).
    """
    if dataset is None:
        return None, "Dataset not loaded", None, {}

    # Load sample
    sample_idx = int(sample_idx)  # dropdown value may arrive as a string
    batch = dataset[sample_idx]

    # Prepare batched inputs (unsqueeze adds a batch dim of 1)
    input_ids = batch['input_ids'].unsqueeze(0).to(device)
    attention_mask = batch['attention_mask'].unsqueeze(0).to(device)
    mask_positions = batch['mask_positions'].unsqueeze(0).to(device)
    # damaged_images shape: [1, num_masks, 1, 64, 64]
    damaged_images = batch['damaged_images'].unsqueeze(0).to(device)

    # First damaged glyph, for display
    input_display_image = tensor_to_pil(damaged_images[0, 0])

    context_text = dataset.samples[sample_idx]['context']
    ground_truth_text = " ".join(dataset.samples[sample_idx]['labels'])

    # --- Inference ---
    # 0. Zero-shot Baseline
    if MODELS['zero_shot'] is not None:
        with torch.no_grad():
            logits = MODELS['zero_shot'](input_ids, attention_mask, mask_positions)
            # logits: [1, num_masks, vocab_size]; display the first mask only
            zs_res = format_top_k(logits[:, 0, :])
    else:
        zs_res = [("Model not loaded", 0.0)]

    # 1. Textual Baseline (fine-tuned HF model)
    if MODELS['text_baseline'] is not None:
        with torch.no_grad():
            # Standard HF model returns MaskedLMOutput
            outputs = MODELS['text_baseline'](input_ids=input_ids,
                                              attention_mask=attention_mask)
            all_logits = outputs.logits  # [batch, seq_len, vocab_size]
            # Gather logits at the mask positions -> [batch, num_masks, vocab]
            mask_logits = torch.gather(
                all_logits, 1,
                mask_positions.unsqueeze(-1).expand(-1, -1, all_logits.size(-1))
            )
            text_res = format_top_k(mask_logits[:, 0, :])
    else:
        text_res = [("Model not loaded (HF Migration)", 0.0)]

    # 2. Visual Baseline
    if MODELS['visual_baseline'] is not None:
        with torch.no_grad():
            logits = MODELS['visual_baseline'](damaged_images)
            visual_res = format_top_k(logits[:, 0, :])
    else:
        visual_res = [("Model not loaded (custom weight specific)", 0.0)]

    # 3. MMRM
    restored_pil = None
    if MODELS['mmrm'] is not None:
        with torch.no_grad():
            logits, restored_imgs = MODELS['mmrm'](
                input_ids, attention_mask, mask_positions, damaged_images
            )
            # logits: [1, num_masks, vocab_size]
            mmrm_res = format_top_k(logits[:, 0, :])
            # Intermediate restoration output: [1, num_masks, 1, 64, 64]
            restored_pil = tensor_to_pil(restored_imgs[0, 0])
    else:
        mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]

    # Raw top-20 results go into gr.State so the Top-K slider can re-slice
    # them for display without re-running inference.
    raw_results = {
        'zs': zs_res,
        'text': text_res,
        'visual': visual_res,
        'mmrm': mmrm_res
    }

    return (
        input_display_image,
        f"Context: {context_text}\nGround Truth: {ground_truth_text}",
        restored_pil,
        raw_results
    )
# --- Gradio UI ---
with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="sm")) as demo:
    gr.Markdown("# MMRM: Multimodal Multitask Restoring Model")
    gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")

    with gr.Row():
        # --- Left Column: Inputs ---
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")
            with gr.Row():
                # gr.Dropdown accepts (label, value) tuples directly; with
                # type="value" the callback receives the sample index. This
                # replaces the fragile post-construction `.choices` mutation.
                sample_dropdown = gr.Dropdown(
                    choices=sample_options,
                    type="value",
                    label="Select Sample",
                    container=False,
                    scale=3
                )
                run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
            with gr.Row():
                top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Predictions")
            with gr.Row():
                input_image = gr.Image(label="Damaged Input", type="pil", height=250)
            with gr.Row():
                input_text = gr.Textbox(label="Context Info", lines=5)

        # --- Right Column: Outputs ---
        with gr.Column(scale=2):
            gr.Markdown("### Model Predictions")
            with gr.Row():
                with gr.Column(min_width=80):
                    zs_output = gr.Label(num_top_classes=20, label="Zero-shot")
                with gr.Column(min_width=80):
                    text_output = gr.Label(num_top_classes=20, label="Textual")
                with gr.Column(min_width=80):
                    visual_output = gr.Label(num_top_classes=20, label="Visual")
                with gr.Column(min_width=80):
                    mmrm_output = gr.Label(num_top_classes=20, label="MMRM")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Visual Restoration")
                    restored_output = gr.Image(label="MMRM Output", type="pil", height=250)

    # State to hold raw top-20 results for all models.
    # Structure: {"zs": [...], "text": [...], "visual": [...], "mmrm": [...]}
    raw_results_state = gr.State()

    def update_views(raw_results, k):
        """Slice the cached top-20 lists down to top-k dicts for gr.Label."""
        if not raw_results:
            return {}, {}, {}, {}
        k = int(k)

        def slice_res(key):
            # gr.Label expects a {term: score} mapping
            full_list = raw_results.get(key, [])
            return {term: score for term, score in full_list[:k]}

        return (
            slice_res('zs'),
            slice_res('text'),
            slice_res('visual'),
            slice_res('mmrm')
        )

    # Event Chain
    # 1. Run inference -> updates State and Images/Text
    run_event = run_btn.click(
        fn=run_inference,
        inputs=[sample_dropdown],
        outputs=[input_image, input_text, restored_output, raw_results_state]
    )
    # 2. Update Labels based on State and Slider
    #    (triggered by Run success OR Slider change)
    run_event.success(
        fn=update_views,
        inputs=[raw_results_state, top_k_slider],
        outputs=[zs_output, text_output, visual_output, mmrm_output]
    )
    top_k_slider.change(
        fn=update_views,
        inputs=[raw_results_state, top_k_slider],
        outputs=[zs_output, text_output, visual_output, mmrm_output]
    )

if __name__ == "__main__":
    demo.launch()