followup
Browse files
app.py
CHANGED
|
@@ -120,7 +120,7 @@ def tensor_to_pil(tensor):
|
|
| 120 |
arr = (arr * 255).astype(np.uint8)
|
| 121 |
return Image.fromarray(arr, mode='L')
|
| 122 |
|
| 123 |
-
def format_top_k(logits, top_k=
|
| 124 |
"""Return list of (token, probability) tuples."""
|
| 125 |
probs = F.softmax(logits, dim=-1)
|
| 126 |
top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
|
|
@@ -133,7 +133,7 @@ def format_top_k(logits, top_k=5):
|
|
| 133 |
|
| 134 |
def run_inference(sample_idx):
|
| 135 |
if dataset is None:
|
| 136 |
-
return None, "Dataset not loaded",
|
| 137 |
|
| 138 |
# Load sample
|
| 139 |
sample_idx = int(sample_idx) # ensure int
|
|
@@ -217,19 +217,19 @@ def run_inference(sample_idx):
|
|
| 217 |
else:
|
| 218 |
mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]
|
| 219 |
|
| 220 |
-
# Format
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
return (
|
| 226 |
input_display_image,
|
| 227 |
f"Context: {context_text}\nGround Truth: {ground_truth_text}",
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
format_output(visual_res),
|
| 231 |
-
format_output(mmrm_res),
|
| 232 |
-
restored_pil
|
| 233 |
)
|
| 234 |
|
| 235 |
|
|
@@ -240,12 +240,12 @@ with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_s
|
|
| 240 |
gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")
|
| 241 |
|
| 242 |
with gr.Row():
|
| 243 |
-
|
| 244 |
with gr.Column(scale=1):
|
| 245 |
gr.Markdown("### Input Selection")
|
| 246 |
with gr.Row():
|
| 247 |
sample_dropdown = gr.Dropdown(
|
| 248 |
-
choices=[x[1] for x in sample_options],
|
| 249 |
type="value",
|
| 250 |
label="Select Sample",
|
| 251 |
container=False,
|
|
@@ -254,6 +254,9 @@ with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_s
|
|
| 254 |
sample_dropdown.choices = sample_options
|
| 255 |
run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
|
| 256 |
|
|
|
|
|
|
|
|
|
|
| 257 |
with gr.Row():
|
| 258 |
input_image = gr.Image(label="Damaged Input", type="pil", height=250)
|
| 259 |
|
|
@@ -265,24 +268,59 @@ with gr.Blocks(title="MMRM Demo", theme=gr.themes.Soft(spacing_size="sm", text_s
|
|
| 265 |
gr.Markdown("### Model Predictions")
|
| 266 |
with gr.Row():
|
| 267 |
with gr.Column(min_width=80):
|
| 268 |
-
zs_output = gr.Label(num_top_classes=
|
| 269 |
with gr.Column(min_width=80):
|
| 270 |
-
text_output = gr.Label(num_top_classes=
|
| 271 |
with gr.Column(min_width=80):
|
| 272 |
-
visual_output = gr.Label(num_top_classes=
|
| 273 |
with gr.Column(min_width=80):
|
| 274 |
-
mmrm_output = gr.Label(num_top_classes=
|
| 275 |
|
| 276 |
with gr.Row():
|
| 277 |
with gr.Column():
|
| 278 |
gr.Markdown("### Visual Restoration")
|
| 279 |
restored_output = gr.Image(label="MMRM Output", type="pil", height=250)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
-
# Event
|
| 282 |
-
|
|
|
|
| 283 |
fn=run_inference,
|
| 284 |
inputs=[sample_dropdown],
|
| 285 |
-
outputs=[input_image, input_text,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
)
|
| 287 |
|
| 288 |
|
|
|
|
| 120 |
arr = (arr * 255).astype(np.uint8)
|
| 121 |
return Image.fromarray(arr, mode='L')
|
| 122 |
|
| 123 |
+
def format_top_k(logits, top_k=20):
|
| 124 |
"""Return list of (token, probability) tuples."""
|
| 125 |
probs = F.softmax(logits, dim=-1)
|
| 126 |
top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
|
|
|
|
| 133 |
|
| 134 |
def run_inference(sample_idx):
|
| 135 |
if dataset is None:
|
| 136 |
+
return None, "Dataset not loaded", None, {}
|
| 137 |
|
| 138 |
# Load sample
|
| 139 |
sample_idx = int(sample_idx) # ensure int
|
|
|
|
| 217 |
else:
|
| 218 |
mmrm_res = [("Model not loaded (custom weight specific)", 0.0)]
|
| 219 |
|
| 220 |
+
# Format raw results into a dictionary for State
|
| 221 |
+
raw_results = {
|
| 222 |
+
'zs': zs_res,
|
| 223 |
+
'text': text_res,
|
| 224 |
+
'visual': visual_res,
|
| 225 |
+
'mmrm': mmrm_res
|
| 226 |
+
}
|
| 227 |
|
| 228 |
return (
|
| 229 |
input_display_image,
|
| 230 |
f"Context: {context_text}\nGround Truth: {ground_truth_text}",
|
| 231 |
+
restored_pil,
|
| 232 |
+
raw_results
|
|
|
|
|
|
|
|
|
|
| 233 |
)
|
| 234 |
|
| 235 |
|
|
|
|
| 240 |
gr.Markdown("Comparing MMRM with baselines on real-world damaged characters.")
|
| 241 |
|
| 242 |
with gr.Row():
|
| 243 |
+
# --- Left Column: Inputs ---
|
| 244 |
with gr.Column(scale=1):
|
| 245 |
gr.Markdown("### Input Selection")
|
| 246 |
with gr.Row():
|
| 247 |
sample_dropdown = gr.Dropdown(
|
| 248 |
+
choices=[x[1] for x in sample_options],
|
| 249 |
type="value",
|
| 250 |
label="Select Sample",
|
| 251 |
container=False,
|
|
|
|
| 254 |
sample_dropdown.choices = sample_options
|
| 255 |
run_btn = gr.Button("Run", variant="primary", scale=1, min_width=60)
|
| 256 |
|
| 257 |
+
with gr.Row():
|
| 258 |
+
top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Predictions")
|
| 259 |
+
|
| 260 |
with gr.Row():
|
| 261 |
input_image = gr.Image(label="Damaged Input", type="pil", height=250)
|
| 262 |
|
|
|
|
| 268 |
gr.Markdown("### Model Predictions")
|
| 269 |
with gr.Row():
|
| 270 |
with gr.Column(min_width=80):
|
| 271 |
+
zs_output = gr.Label(num_top_classes=20, label="Zero-shot")
|
| 272 |
with gr.Column(min_width=80):
|
| 273 |
+
text_output = gr.Label(num_top_classes=20, label="Textual")
|
| 274 |
with gr.Column(min_width=80):
|
| 275 |
+
visual_output = gr.Label(num_top_classes=20, label="Visual")
|
| 276 |
with gr.Column(min_width=80):
|
| 277 |
+
mmrm_output = gr.Label(num_top_classes=20, label="MMRM")
|
| 278 |
|
| 279 |
with gr.Row():
|
| 280 |
with gr.Column():
|
| 281 |
gr.Markdown("### Visual Restoration")
|
| 282 |
restored_output = gr.Image(label="MMRM Output", type="pil", height=250)
|
| 283 |
+
|
| 284 |
+
# State to hold raw top-20 results for all models
|
| 285 |
+
# Structure: {"zs": [...], "text": [...], "visual": [...], "mmrm": [...]}
|
| 286 |
+
raw_results_state = gr.State()
|
| 287 |
+
|
| 288 |
+
def update_views(raw_results, k):
|
| 289 |
+
if not raw_results:
|
| 290 |
+
return {}, {}, {}, {}
|
| 291 |
+
|
| 292 |
+
k = int(k)
|
| 293 |
+
def slice_res(key):
|
| 294 |
+
# Take top k from list of tuples
|
| 295 |
+
full_list = raw_results.get(key, [])
|
| 296 |
+
return {term: score for term, score in full_list[:k]}
|
| 297 |
+
|
| 298 |
+
return (
|
| 299 |
+
slice_res('zs'),
|
| 300 |
+
slice_res('text'),
|
| 301 |
+
slice_res('visual'),
|
| 302 |
+
slice_res('mmrm')
|
| 303 |
+
)
|
| 304 |
|
| 305 |
+
# Event Chain
|
| 306 |
+
# 1. Run inference -> updates State and Images/Text
|
| 307 |
+
run_event = run_btn.click(
|
| 308 |
fn=run_inference,
|
| 309 |
inputs=[sample_dropdown],
|
| 310 |
+
outputs=[input_image, input_text, restored_output, raw_results_state]
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# 2. Update Labels based on State and Slider (triggered by Run success OR Slider change)
|
| 314 |
+
run_event.success(
|
| 315 |
+
fn=update_views,
|
| 316 |
+
inputs=[raw_results_state, top_k_slider],
|
| 317 |
+
outputs=[zs_output, text_output, visual_output, mmrm_output]
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
top_k_slider.change(
|
| 321 |
+
fn=update_views,
|
| 322 |
+
inputs=[raw_results_state, top_k_slider],
|
| 323 |
+
outputs=[zs_output, text_output, visual_output, mmrm_output]
|
| 324 |
)
|
| 325 |
|
| 326 |
|