File size: 12,239 Bytes
c527d7d
87224ba
 
 
 
 
 
 
 
 
 
c527d7d
87224ba
 
 
 
c527d7d
87224ba
 
4796948
87224ba
4796948
87224ba
4796948
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4796948
 
 
7e84e35
4796948
 
 
 
 
 
7e84e35
4796948
7e84e35
87224ba
4796948
87224ba
 
 
 
 
 
 
 
 
 
 
4796948
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cffdcce
87224ba
 
 
 
 
 
 
 
 
 
 
 
cffdcce
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4796948
87224ba
 
 
4796948
 
 
 
 
 
 
 
 
 
 
 
 
 
87224ba
4796948
87224ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cffdcce
 
 
 
 
 
 
87224ba
 
 
 
cffdcce
 
87224ba
 
 
 
 
0a7da3e
87224ba
0a7da3e
87224ba
 
cffdcce
0a7da3e
 
 
 
cffdcce
0a7da3e
 
 
 
 
 
 
87224ba
cffdcce
 
 
0a7da3e
 
87224ba
0a7da3e
 
 
 
 
 
 
 
cffdcce
0a7da3e
cffdcce
0a7da3e
cffdcce
0a7da3e
cffdcce
87224ba
0a7da3e
 
 
 
cffdcce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87224ba
cffdcce
 
 
87224ba
 
cffdcce
 
 
 
 
 
 
 
 
 
 
 
 
 
87224ba
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import BertTokenizer
from PIL import Image
import numpy as np
import os
# Local imports (demo directory structure)
from config import Config
from models.mmrm import MMRM, BaselineImageModel, BaselineLanguageModel
from evaluation.evaluate_real import RealWorldDataset

# --- Global Setup ---
# Demo-wide singletons shared by model loading, data loading, and inference.
config = Config()
# Use the configured device only when CUDA is actually available; else CPU.
device = torch.device(config.device if torch.cuda.is_available() else "cpu")
# NOTE(review): a BertTokenizer is built from `config.roberta_model`;
# presumably that checkpoint ships a BERT-style vocab (e.g. GuwenBERT) — confirm.
tokenizer = BertTokenizer.from_pretrained(config.roberta_model)

# --- Model Loading ---
def load_models():
    """Load every demo model, tolerating individual failures.

    Returns:
        dict with keys 'mmrm', 'text_baseline', 'visual_baseline',
        'zero_shot'. An entry is None when the model is intentionally
        disabled or failed to load; callers must check before use.
    """
    # Local import keeps the heavy transformers machinery out of module scope.
    from transformers import AutoModelForMaskedLM

    models = {}

    # 1. MMRM — intentionally disabled for now (requires a project-specific
    #    checkpoint that is not distributed with this demo).
    models['mmrm'] = None

    # 2. Textual Baseline (fine-tuned RoBERTa) — migrated to the HF Hub in
    #    standard RobertaForMaskedLM format, so AutoModelForMaskedLM works.
    print("Loading Textual Baseline from HF Hub (rexera/mmrm-roberta)...")
    try:
        repo_id = "rexera/mmrm-roberta"
        lm_model = AutoModelForMaskedLM.from_pretrained(repo_id).to(device)
        lm_model.eval()
        models['text_baseline'] = lm_model
        print(f"Textual Baseline loaded from {repo_id}")
    except Exception as e:
        # Best-effort: the demo still runs with this model disabled.
        print(f"Error loading Textual Baseline from HF: {e}")
        models['text_baseline'] = None

    # 3. Visual Baseline (ResNet) — intentionally disabled for now
    #    (requires a project-specific checkpoint).
    models['visual_baseline'] = None

    # 4. Zero-shot Baseline — pre-trained GuwenBERT pulled directly from HF;
    #    fine_tuned=False selects the stock MLM head.
    print("Loading Zero-shot Baseline (GuwenBERT)...")
    try:
        zs_model = BaselineLanguageModel(config, fine_tuned=False).to(device)
        zs_model.eval()
        models['zero_shot'] = zs_model
        print(f"Zero-shot Baseline loaded (Pre-trained {config.roberta_model})")
    except Exception as e:
        print(f"Error loading Zero-shot Baseline: {e}")
        models['zero_shot'] = None

    return models

# Load models globally (or lazily if needed, but global is fine for demo script)
# NOTE: this runs at import time and may hit the network (HF Hub downloads).
MODELS = load_models()

# --- Data Loading ---
print("Loading Real World Data...")
try:
    dataset = RealWorldDataset(config, tokenizer)
    # Create a mapping for dropdown: "Index: Context preview..."
    sample_options = []
    for i in range(len(dataset)):
        # Provide a snippet of context for the dropdown
        context = dataset.samples[i]['context']
        # Show the first 40 chars of the context as a preview label
        label = f"Sample {i}: {context[:40]}..."
        sample_options.append((label, i))
except Exception as e:
    # Degrade gracefully: the UI still builds, just with an empty dropdown.
    print(f"Error loading dataset: {e}")
    dataset = None
    sample_options = []

# --- Helper Functions ---

def tensor_to_pil(tensor):
    """Convert a float tensor in [0, 1] (e.g. [1, 64, 64]) to a greyscale PIL image.

    Leading batch and channel dimensions of size 1 are stripped so the
    image is always built from a 2-D uint8 array.
    """
    # Drop an optional batch dim, then an optional single-channel dim.
    for dims in (4, 3):
        if tensor.dim() == dims:
            tensor = tensor.squeeze(0)

    pixels = (tensor.detach().cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(pixels, mode='L')

def format_top_k(logits, top_k=20):
    """Return the top-k predictions as a list of (token, probability) tuples.

    `logits` is expected to have a leading batch dimension of 1; only the
    first batch entry is decoded.
    """
    probabilities = F.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probabilities, top_k, dim=-1)
    return [
        (tokenizer.decode([token_id.item()]), float(prob.item()))
        for prob, token_id in zip(top_probs[0], top_ids[0])
    ]

def run_inference(sample_idx):
    """Run all loaded models on one dataset sample.

    Args:
        sample_idx: index into the global `dataset` (may arrive as str/float
            from the Gradio dropdown).

    Returns:
        4-tuple matching the Gradio outputs: (damaged input as PIL image,
        context + ground-truth text, MMRM-restored PIL image or None,
        dict of raw top-20 predictions keyed 'zs'/'text'/'visual'/'mmrm').
    """
    if dataset is None:
        return None, "Dataset not loaded", None, {}

    # Load sample
    sample_idx = int(sample_idx)  # dropdown values are not guaranteed to be int
    batch = dataset[sample_idx]

    # Prepare inputs: add a leading batch dimension of 1.
    input_ids = batch['input_ids'].unsqueeze(0).to(device)
    attention_mask = batch['attention_mask'].unsqueeze(0).to(device)
    mask_positions = batch['mask_positions'].unsqueeze(0).to(device)
    damaged_images = batch['damaged_images'].unsqueeze(0).to(device)
    # damaged_images shape: [1, num_masks, 1, 64, 64]

    # Display the damaged glyph for the first masked position.
    first_image_tensor = damaged_images[0, 0]  # [1, 64, 64]
    input_display_image = tensor_to_pil(first_image_tensor)

    context_text = dataset.samples[sample_idx]['context']
    ground_truth_labels = dataset.samples[sample_idx]['labels']
    ground_truth_text = " ".join(ground_truth_labels)

    # --- Inference ---
    # Each model is optional; a placeholder entry is shown when not loaded.

    # 0. Zero-shot Baseline
    if MODELS['zero_shot']:
        with torch.no_grad():
            logits = MODELS['zero_shot'](input_ids, attention_mask, mask_positions)
            # logits: [1, num_masks, vocab_size]; only the first mask is shown.
            mask_logits = logits[:, 0, :]
            zs_res = format_top_k(mask_logits)
    else:
        zs_res = [("Model not loaded", 0.0)]

    # 1. Textual Baseline (Fine-tuned HF Model)
    if MODELS['text_baseline']:
        with torch.no_grad():
            # Standard HF model returns MaskedLMOutput
            outputs = MODELS['text_baseline'](input_ids=input_ids, attention_mask=attention_mask)
            all_logits = outputs.logits  # [batch, seq_len, vocab_size]

            # Gather the logits at the mask positions.
            # mask_positions: [1, num_masks]
            mask_logits = torch.gather(
                all_logits, 1, mask_positions.unsqueeze(-1).expand(-1, -1, all_logits.size(-1))
            )  # [batch, num_masks, vocab_size]

            # Take first mask for display
            first_mask_logits = mask_logits[:, 0, :]
            text_res = format_top_k(first_mask_logits)
    else:
        text_res = [("Model not loaded (HF Migration)", 0.0)]

    # 2. Visual Baseline
    if MODELS['visual_baseline']:
        with torch.no_grad():
            logits = MODELS['visual_baseline'](damaged_images)
            mask_logits = logits[:, 0, :]
            visual_res = format_top_k(mask_logits)
    else:
        visual_res = [("Model not loaded (custom weight specific)", 0.0)]

    # 3. MMRM (also produces an intermediate restored glyph image)
    restored_pil = None
    if MODELS['mmrm']:
        with torch.no_grad():
            logits, restored_imgs = MODELS['mmrm'](
                input_ids, attention_mask, mask_positions, damaged_images
            )
            # logits: [1, num_masks, vocab_size]
            mask_logits = logits[:, 0, :]
            mmrm_res = format_top_k(mask_logits)

            # restored_imgs: [1, num_masks, 1, 64, 64]
            restored_pil = tensor_to_pil(restored_imgs[0, 0])
    else:
        mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]

    # Raw top-20 results per model, cached in gr.State so the top-k slider
    # can re-slice the display without re-running inference.
    raw_results = {
        'zs': zs_res,
        'text': text_res,
        'visual': visual_res,
        'mmrm': mmrm_res
    }

    return (
        input_display_image,
        f"Context: {context_text}\nGround Truth: {ground_truth_text}",
        restored_pil,
        raw_results
    )


# --- Gradio UI ---

with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_size="sm")) as demo:
    gr.Markdown("# MMRM: Multimodal Multitask Restoring Model")
    gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")
    
    with gr.Row():
        # --- Left Column: Inputs ---
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")
            with gr.Row():
                # gr.Dropdown accepts (label, value) tuples directly; passing
                # them here replaces the previous post-construction
                # `sample_dropdown.choices = ...` assignment, which is not
                # reflected in the rendered frontend.
                sample_dropdown = gr.Dropdown(
                    choices=sample_options,
                    type="value",
                    label="Select Sample",
                    container=False,
                    scale=3
                )
                run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
            
            with gr.Row():
                top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Predictions")

            with gr.Row():
                input_image = gr.Image(label="Damaged Input", type="pil", height=250)
            
            with gr.Row():
                input_text = gr.Textbox(label="Context Info", lines=5)

        # --- Right Column: Outputs ---
        with gr.Column(scale=2):
            gr.Markdown("### Model Predictions")
            with gr.Row():
                with gr.Column(min_width=80):
                    zs_output = gr.Label(num_top_classes=20, label="Zero-shot")
                with gr.Column(min_width=80):
                    text_output = gr.Label(num_top_classes=20, label="Textual")
                with gr.Column(min_width=80):
                    visual_output = gr.Label(num_top_classes=20, label="Visual")
                with gr.Column(min_width=80):
                    mmrm_output = gr.Label(num_top_classes=20, label="MMRM")
            
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Visual Restoration")
                    restored_output = gr.Image(label="MMRM Output", type="pil", height=250)
    
    # State to hold raw top-20 results for all models.
    # Structure: {"zs": [...], "text": [...], "visual": [...], "mmrm": [...]}
    raw_results_state = gr.State()

    def update_views(raw_results, k):
        """Slice the cached raw predictions down to top-k for the Label widgets."""
        if not raw_results:
            return {}, {}, {}, {}
        
        k = int(k)

        def slice_res(key):
            # gr.Label expects a {label: confidence} mapping.
            full_list = raw_results.get(key, [])
            return {term: score for term, score in full_list[:k]}
            
        return (
            slice_res('zs'),
            slice_res('text'),
            slice_res('visual'),
            slice_res('mmrm')
        )

    # Event chain:
    # 1. Run inference -> updates the State plus the images/text panels.
    run_event = run_btn.click(
        fn=run_inference,
        inputs=[sample_dropdown],
        outputs=[input_image, input_text, restored_output, raw_results_state]
    )
    
    # 2. Re-slice the label panels whenever inference succeeds or the slider
    #    moves; a slider change never re-runs inference.
    run_event.success(
        fn=update_views,
        inputs=[raw_results_state, top_k_slider],
        outputs=[zs_output, text_output, visual_output, mmrm_output]
    )
    
    top_k_slider.change(
        fn=update_views,
        inputs=[raw_results_state, top_k_slider],
        outputs=[zs_output, text_output, visual_output, mmrm_output]
    )


if __name__ == "__main__":
    # Start the Gradio server (blocking) when run as a script.
    demo.launch()