|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import BertTokenizer |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import os |
|
|
|
|
|
from config import Config |
|
|
from models.mmrm import MMRM, BaselineImageModel, BaselineLanguageModel |
|
|
from evaluation.evaluate_real import RealWorldDataset |
|
|
|
|
|
|
|
|
# Shared runtime globals used by every model and handler in this demo.
config = Config()

# Fall back to CPU when CUDA is absent; config.device may name a GPU device.
device = torch.device(config.device if torch.cuda.is_available() else "cpu")

# NOTE(review): config.roberta_model presumably points at a BERT-compatible
# checkpoint (GuwenBERT, per the log messages below) — confirm in config.py.
tokenizer = BertTokenizer.from_pretrained(config.roberta_model)
|
|
|
|
|
|
|
|
def load_models():
    """Load models. Textual baseline is now loaded from HF Hub.

    Returns a dict with keys 'mmrm', 'text_baseline', 'visual_baseline' and
    'zero_shot'. Entries that require custom local weights, or whose load
    fails, are set to None so the UI can degrade gracefully.
    """
    models = {}
    # Imported lazily so the module can be inspected without transformers
    # fully initialised at import time.
    from transformers import AutoModelForMaskedLM

    # MMRM needs custom local weights that are not distributed with the demo.
    models['mmrm'] = None

    print("Loading Textual Baseline from HF Hub (rexera/mmrm-roberta)...")
    try:
        repo_id = "rexera/mmrm-roberta"
        lm_model = AutoModelForMaskedLM.from_pretrained(repo_id).to(device)
        lm_model.eval()  # inference only — disable dropout etc.
        models['text_baseline'] = lm_model
        print(f"Textual Baseline loaded from {repo_id}")
    except Exception as e:
        # Best-effort: a hub/network failure should not kill the whole demo.
        print(f"Error loading Textual Baseline from HF: {e}")
        models['text_baseline'] = None

    # Visual baseline also depends on custom weights; disabled in this build.
    models['visual_baseline'] = None

    print("Loading Zero-shot Baseline (GuwenBERT)...")
    try:
        # fine_tuned=False → raw pre-trained checkpoint, no task head weights.
        zs_model = BaselineLanguageModel(config, fine_tuned=False).to(device)
        zs_model.eval()
        models['zero_shot'] = zs_model
        print(f"Zero-shot Baseline loaded (Pre-trained {config.roberta_model})")
    except Exception as e:
        print(f"Error loading Zero-shot Baseline: {e}")
        models['zero_shot'] = None

    return models
|
|
|
|
|
|
|
|
# Load all models once at import time; handlers read from this dict.
MODELS = load_models()

print("Loading Real World Data...")
try:
    dataset = RealWorldDataset(config, tokenizer)
    # One dropdown entry per sample: "(label, value)" pairs where the label
    # previews the first 40 characters of the sample's context.
    sample_options = [
        (f"Sample {i}: {dataset.samples[i]['context'][:40]}...", i)
        for i in range(len(dataset))
    ]
except Exception as e:
    # Keep the UI alive even without data; run_inference guards on None.
    print(f"Error loading dataset: {e}")
    dataset = None
    sample_options = []
|
|
|
|
|
|
|
|
|
|
|
def tensor_to_pil(tensor):
    """Convert a grayscale image tensor to a PIL Image.

    Accepts (1, 1, H, W), (1, H, W) or (H, W) float tensors with values
    nominally in [0, 1]; leading singleton dims are squeezed away.

    Returns:
        PIL.Image.Image in mode 'L' (8-bit grayscale).
    """
    # Peel off batch/channel singleton dims down to (H, W).
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    if tensor.dim() == 3:
        tensor = tensor.squeeze(0)

    # Clamp before scaling: model reconstructions can slightly overshoot
    # [0, 1], and out-of-range values would wrap around when cast to uint8
    # (e.g. 1.01 * 255 -> 257 -> 1), corrupting the displayed image.
    arr = tensor.clamp(0.0, 1.0).cpu().detach().numpy()
    arr = (arr * 255).astype(np.uint8)
    return Image.fromarray(arr, mode='L')
|
|
|
|
|
def format_top_k(logits, top_k=20):
    """Return list of (token, probability) tuples.

    `logits` is expected to have a leading batch dimension; only the first
    row is decoded. Tokens come from the module-level tokenizer.
    """
    probs = F.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
    return [
        (tokenizer.decode([idx.item()]), float(p.item()))
        for p, idx in zip(top_probs[0], top_indices[0])
    ]
|
|
|
|
|
def run_inference(sample_idx):
    """Run every available model on one dataset sample.

    Args:
        sample_idx: index into the module-level `dataset` (may arrive as a
            string from the dropdown, hence the int() coercion below).

    Returns:
        4-tuple consumed by the Gradio wiring:
        (damaged input as PIL image, context/ground-truth text, restored
        PIL image or None, dict of raw top-k results keyed by
        'zs'/'text'/'visual'/'mmrm').
    """
    if dataset is None:
        return None, "Dataset not loaded", None, {}

    sample_idx = int(sample_idx)
    batch = dataset[sample_idx]

    # Dataset yields single samples; add a batch dim and move to device.
    input_ids = batch['input_ids'].unsqueeze(0).to(device)
    attention_mask = batch['attention_mask'].unsqueeze(0).to(device)
    mask_positions = batch['mask_positions'].unsqueeze(0).to(device)
    damaged_images = batch['damaged_images'].unsqueeze(0).to(device)

    # Show only the first damaged glyph of the (possibly multi-mask) sample.
    first_image_tensor = damaged_images[0, 0]
    input_display_image = tensor_to_pil(first_image_tensor)

    context_text = dataset.samples[sample_idx]['context']
    ground_truth_labels = dataset.samples[sample_idx]['labels']
    ground_truth_text = " ".join(ground_truth_labels)

    # --- Zero-shot baseline (pre-trained LM, no fine-tuning) ---
    zs_res = []
    if MODELS['zero_shot']:
        with torch.no_grad():
            logits = MODELS['zero_shot'](input_ids, attention_mask, mask_positions)
            # Score only the first masked position.
            mask_logits = logits[:, 0, :]
            zs_res = format_top_k(mask_logits)
    else:
        zs_res = [("Model not loaded", 0.0)]

    # --- Textual baseline (HF masked-LM head over the whole sequence) ---
    text_res = []
    if MODELS['text_baseline']:
        with torch.no_grad():
            outputs = MODELS['text_baseline'](input_ids=input_ids, attention_mask=attention_mask)
            all_logits = outputs.logits

            # Gather the vocab logits at each masked position:
            # (B, L, V) -> (B, num_masks, V).
            batch_size, num_masks = mask_positions.shape
            mask_logits = torch.gather(
                all_logits, 1, mask_positions.unsqueeze(-1).expand(-1, -1, all_logits.size(-1))
            )

            # Again, display only the first mask.
            first_mask_logits = mask_logits[:, 0, :]
            text_res = format_top_k(first_mask_logits)
    else:
        text_res = [("Model not loaded (HF Migration)", 0.0)]

    # --- Visual baseline (image-only classifier) ---
    visual_res = []
    if MODELS['visual_baseline']:
        with torch.no_grad():
            logits = MODELS['visual_baseline'](damaged_images)
            mask_logits = logits[:, 0, :]
            visual_res = format_top_k(mask_logits)
    else:
        visual_res = [("Model not loaded (custom weight specific)", 0.0)]

    # --- Full MMRM (joint prediction + visual restoration) ---
    mmrm_res = []
    restored_pil = None
    if MODELS['mmrm']:
        with torch.no_grad():
            logits, restored_imgs = MODELS['mmrm'](
                input_ids, attention_mask, mask_positions, damaged_images
            )

            mask_logits = logits[:, 0, :]
            mmrm_res = format_top_k(mask_logits)

            # MMRM additionally returns restored glyph images; show the first.
            restored_tensor = restored_imgs[0, 0]
            restored_pil = tensor_to_pil(restored_tensor)
    else:
        mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]

    # Raw (un-truncated) top-k lists; update_views slices them to the
    # user-selected k so the slider works without re-running inference.
    raw_results = {
        'zs': zs_res,
        'text': text_res,
        'visual': visual_res,
        'mmrm': mmrm_res
    }

    return (
        input_display_image,
        f"Context: {context_text}\nGround Truth: {ground_truth_text}",
        restored_pil,
        raw_results
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI layout and event wiring ---
with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="sm")) as demo:
    gr.Markdown("# MMRM: Multimodal Multitask Restoring Model")
    gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")

    with gr.Row():
        # Left column: sample selection + input preview.
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")
            with gr.Row():
                # BUGFIX: Gradio accepts (label, value) tuples directly in
                # `choices`. Previously the dropdown was built from bare
                # indices and then patched via `sample_dropdown.choices = ...`
                # after construction, which mutates the Python object only —
                # the rendered component config is captured at construction
                # time, so the labels never reliably reached the browser.
                sample_dropdown = gr.Dropdown(
                    choices=sample_options,
                    type="value",
                    label="Select Sample",
                    container=False,
                    scale=3
                )
                run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)

            with gr.Row():
                top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Predictions")

            with gr.Row():
                input_image = gr.Image(label="Damaged Input", type="pil", height=250)

            with gr.Row():
                input_text = gr.Textbox(label="Context Info", lines=5)

        # Right column: per-model prediction labels side by side.
        with gr.Column(scale=2):
            gr.Markdown("### Model Predictions")
            with gr.Row():
                with gr.Column(min_width=80):
                    zs_output = gr.Label(num_top_classes=20, label="Zero-shot")
                with gr.Column(min_width=80):
                    text_output = gr.Label(num_top_classes=20, label="Textual")
                with gr.Column(min_width=80):
                    visual_output = gr.Label(num_top_classes=20, label="Visual")
                with gr.Column(min_width=80):
                    mmrm_output = gr.Label(num_top_classes=20, label="MMRM")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Visual Restoration")
                    restored_output = gr.Image(label="MMRM Output", type="pil", height=250)

    # Holds the full top-20 results from the last run so the slider can
    # re-slice them without re-running inference.
    raw_results_state = gr.State()

    def update_views(raw_results, k):
        """Slice the cached raw results down to the top-k for each model."""
        if not raw_results:
            return {}, {}, {}, {}

        k = int(k)

        def slice_res(key):
            # gr.Label expects a {class_name: score} mapping.
            full_list = raw_results.get(key, [])
            return {term: score for term, score in full_list[:k]}

        return (
            slice_res('zs'),
            slice_res('text'),
            slice_res('visual'),
            slice_res('mmrm')
        )

    # Run inference, stash raw results in state ...
    run_event = run_btn.click(
        fn=run_inference,
        inputs=[sample_dropdown],
        outputs=[input_image, input_text, restored_output, raw_results_state]
    )

    # ... then render the label panes from that state once the run succeeds.
    run_event.success(
        fn=update_views,
        inputs=[raw_results_state, top_k_slider],
        outputs=[zs_output, text_output, visual_output, mmrm_output]
    )

    # Moving the slider re-slices the cached results instantly.
    top_k_slider.change(
        fn=update_views,
        inputs=[raw_results_state, top_k_slider],
        outputs=[zs_output, text_output, visual_output, mmrm_output]
    )
|
|
|
|
|
|
|
|
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|