rexera committed on
Commit
cffdcce
·
1 Parent(s): 0a7da3e
Files changed (1) hide show
  1. app.py +58 -20
app.py CHANGED
@@ -120,7 +120,7 @@ def tensor_to_pil(tensor):
120
  arr = (arr * 255).astype(np.uint8)
121
  return Image.fromarray(arr, mode='L')
122
 
123
- def format_top_k(logits, top_k=5):
124
  """Return list of (token, probability) tuples."""
125
  probs = F.softmax(logits, dim=-1)
126
  top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
@@ -133,7 +133,7 @@ def format_top_k(logits, top_k=5):
133
 
134
  def run_inference(sample_idx):
135
  if dataset is None:
136
- return None, "Dataset not loaded", [], [], [], [], None
137
 
138
  # Load sample
139
  sample_idx = int(sample_idx) # ensure int
@@ -217,19 +217,19 @@ def run_inference(sample_idx):
217
  else:
218
  mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]
219
 
220
- # Format text outputs as string or dictionary for Label
221
- def format_output(results):
222
- out_dict = {k: v for k, v in results}
223
- return out_dict
 
 
 
224
 
225
  return (
226
  input_display_image,
227
  f"Context: {context_text}\nGround Truth: {ground_truth_text}",
228
- format_output(zs_res),
229
- format_output(text_res),
230
- format_output(visual_res),
231
- format_output(mmrm_res),
232
- restored_pil
233
  )
234
 
235
 
@@ -240,12 +240,12 @@ with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_s
240
  gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")
241
 
242
  with gr.Row():
243
- # --- Left Column: Inputs ---
244
  with gr.Column(scale=1):
245
  gr.Markdown("### Input Selection")
246
  with gr.Row():
247
  sample_dropdown = gr.Dropdown(
248
- choices=[x[1] for x in sample_options], # Use index as value
249
  type="value",
250
  label="Select Sample",
251
  container=False,
@@ -254,6 +254,9 @@ with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_s
254
  sample_dropdown.choices = sample_options
255
  run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
256
 
 
 
 
257
  with gr.Row():
258
  input_image = gr.Image(label="Damaged Input", type="pil", height=250)
259
 
@@ -265,24 +268,59 @@ with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_s
265
  gr.Markdown("### Model Predictions")
266
  with gr.Row():
267
  with gr.Column(min_width=80):
268
- zs_output = gr.Label(num_top_classes=3, label="Zero-shot")
269
  with gr.Column(min_width=80):
270
- text_output = gr.Label(num_top_classes=3, label="Textual")
271
  with gr.Column(min_width=80):
272
- visual_output = gr.Label(num_top_classes=3, label="Visual")
273
  with gr.Column(min_width=80):
274
- mmrm_output = gr.Label(num_top_classes=3, label="MMRM")
275
 
276
  with gr.Row():
277
  with gr.Column():
278
  gr.Markdown("### Visual Restoration")
279
  restored_output = gr.Image(label="MMRM Output", type="pil", height=250)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
- # Event
282
- run_btn.click(
 
283
  fn=run_inference,
284
  inputs=[sample_dropdown],
285
- outputs=[input_image, input_text, zs_output, text_output, visual_output, mmrm_output, restored_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  )
287
 
288
 
 
120
  arr = (arr * 255).astype(np.uint8)
121
  return Image.fromarray(arr, mode='L')
122
 
123
+ def format_top_k(logits, top_k=20):
124
  """Return list of (token, probability) tuples."""
125
  probs = F.softmax(logits, dim=-1)
126
  top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
 
133
 
134
  def run_inference(sample_idx):
135
  if dataset is None:
136
+ return None, "Dataset not loaded", None, {}
137
 
138
  # Load sample
139
  sample_idx = int(sample_idx) # ensure int
 
217
  else:
218
  mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]
219
 
220
+ # Format raw results into a dictionary for State
221
+ raw_results = {
222
+ 'zs': zs_res,
223
+ 'text': text_res,
224
+ 'visual': visual_res,
225
+ 'mmrm': mmrm_res
226
+ }
227
 
228
  return (
229
  input_display_image,
230
  f"Context: {context_text}\nGround Truth: {ground_truth_text}",
231
+ restored_pil,
232
+ raw_results
 
 
 
233
  )
234
 
235
 
 
240
  gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")
241
 
242
  with gr.Row():
243
+ # --- Left Column: Inputs ---
244
  with gr.Column(scale=1):
245
  gr.Markdown("### Input Selection")
246
  with gr.Row():
247
  sample_dropdown = gr.Dropdown(
248
+ choices=[x[1] for x in sample_options],
249
  type="value",
250
  label="Select Sample",
251
  container=False,
 
254
  sample_dropdown.choices = sample_options
255
  run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
256
 
257
+ with gr.Row():
258
+ top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Predictions")
259
+
260
  with gr.Row():
261
  input_image = gr.Image(label="Damaged Input", type="pil", height=250)
262
 
 
268
  gr.Markdown("### Model Predictions")
269
  with gr.Row():
270
  with gr.Column(min_width=80):
271
+ zs_output = gr.Label(num_top_classes=20, label="Zero-shot")
272
  with gr.Column(min_width=80):
273
+ text_output = gr.Label(num_top_classes=20, label="Textual")
274
  with gr.Column(min_width=80):
275
+ visual_output = gr.Label(num_top_classes=20, label="Visual")
276
  with gr.Column(min_width=80):
277
+ mmrm_output = gr.Label(num_top_classes=20, label="MMRM")
278
 
279
  with gr.Row():
280
  with gr.Column():
281
  gr.Markdown("### Visual Restoration")
282
  restored_output = gr.Image(label="MMRM Output", type="pil", height=250)
283
+
284
+ # State to hold raw top-20 results for all models
285
+ # Structure: {"zs": [...], "text": [...], "visual": [...], "mmrm": [...]}
286
+ raw_results_state = gr.State()
287
+
288
+ def update_views(raw_results, k):
289
+ if not raw_results:
290
+ return {}, {}, {}, {}
291
+
292
+ k = int(k)
293
+ def slice_res(key):
294
+ # Take top k from list of tuples
295
+ full_list = raw_results.get(key, [])
296
+ return {term: score for term, score in full_list[:k]}
297
+
298
+ return (
299
+ slice_res('zs'),
300
+ slice_res('text'),
301
+ slice_res('visual'),
302
+ slice_res('mmrm')
303
+ )
304
 
305
+ # Event Chain
306
+ # 1. Run inference -> updates State and Images/Text
307
+ run_event = run_btn.click(
308
  fn=run_inference,
309
  inputs=[sample_dropdown],
310
+ outputs=[input_image, input_text, restored_output, raw_results_state]
311
+ )
312
+
313
+ # 2. Update Labels based on State and Slider (triggered by Run success OR Slider change)
314
+ run_event.success(
315
+ fn=update_views,
316
+ inputs=[raw_results_state, top_k_slider],
317
+ outputs=[zs_output, text_output, visual_output, mmrm_output]
318
+ )
319
+
320
+ top_k_slider.change(
321
+ fn=update_views,
322
+ inputs=[raw_results_state, top_k_slider],
323
+ outputs=[zs_output, text_output, visual_output, mmrm_output]
324
  )
325
 
326