cp524 commited on
Commit
2fe9c08
·
1 Parent(s): 4789d91

Implement saved output loading for examples

Browse files
Files changed (1) hide show
  1. app.py +63 -1
app.py CHANGED
@@ -12,6 +12,7 @@ from src.smc.inference import (
12
  SMCGradInferenceConfig,
13
  FTInferenceConfig,
14
  )
 
15
 
16
  GALLERY_HEIGHT = "224px"
17
 
@@ -46,6 +47,49 @@ def _format_inference_output(out) -> str:
46
  except Exception:
47
  return "Could not parse inference output"
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # --- Per-method runner functions ---
51
  def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
@@ -150,7 +194,7 @@ with gr.Blocks() as demo:
150
  prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
151
  run_button = gr.Button("Run", variant="primary")
152
 
153
- gr.Examples(examples=examples, inputs=prompt)
154
 
155
  # --- Pretrained method row ---
156
  with gr.Row():
@@ -308,6 +352,24 @@ with gr.Blocks() as demo:
308
  inputs=[prompt, ft_negative_prompt, ft_CFG, ft_steps],
309
  outputs=[ft_gallery, ft_info],
310
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  # Enable Gradio queue to allow parallel execution of multiple handlers. Set concurrency
313
  # to 2 (one per method) — increase if you add more methods.
 
12
  SMCGradInferenceConfig,
13
  FTInferenceConfig,
14
  )
15
+ from run_examples import get_out_if_exists
16
 
17
  GALLERY_HEIGHT = "224px"
18
 
 
47
  except Exception:
48
  return "Could not parse inference output"
49
 
50
+ def try_load_saved_outputs(prompt):
51
+ """
52
+ Check for saved outputs for the given prompt for each method and return
53
+ (pretrained_gallery, pretrained_info, smc_gallery, smc_info, ft_gallery, ft_info).
54
+
55
+ If no saved output exists for a method, returns an empty gallery and
56
+ \"No saved output\" for info for that method.
57
+ """
58
+ try:
59
+ # Pretrained
60
+ pre_cfg = PretrainedInferenceConfig(prompt=prompt)
61
+ pre_out = get_out_if_exists("pretrained", pre_cfg)
62
+ if pre_out is not None:
63
+ pre_gallery = pre_out.images
64
+ pre_info = _format_inference_output(pre_out)
65
+ else:
66
+ pre_gallery, pre_info = [], "No saved output"
67
+
68
+ # SMC-grad
69
+ smc_cfg = SMCGradInferenceConfig(prompt=prompt)
70
+ smc_out = get_out_if_exists("smc_grad", smc_cfg)
71
+ if smc_out is not None:
72
+ smc_gallery = smc_out.images
73
+ smc_info = _format_inference_output(smc_out)
74
+ else:
75
+ smc_gallery, smc_info = [], "No saved output"
76
+
77
+ # FT
78
+ ft_cfg = FTInferenceConfig(prompt=prompt)
79
+ ft_out = get_out_if_exists("ft", ft_cfg)
80
+ if ft_out is not None:
81
+ ft_gallery = ft_out.images
82
+ ft_info = _format_inference_output(ft_out)
83
+ else:
84
+ ft_gallery, ft_info = [], "No saved output"
85
+
86
+ return pre_gallery, pre_info, smc_gallery, smc_info, ft_gallery, ft_info
87
+
88
+ except Exception as e:
89
+ # Don't crash the UI; print the traceback and return empty placeholders
90
+ traceback.print_exc()
91
+ return [], "Error checking saved outputs", [], "Error checking saved outputs", [], "Error checking saved outputs"
92
+
93
 
94
  # --- Per-method runner functions ---
95
  def run_pretrained_ui(prompt, pretrained_negative_prompt, pretrained_CFG, pretrained_steps):
 
194
  prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
195
  run_button = gr.Button("Run", variant="primary")
196
 
197
+ examples_widget = gr.Examples(examples=examples, inputs=prompt)
198
 
199
  # --- Pretrained method row ---
200
  with gr.Row():
 
352
  inputs=[prompt, ft_negative_prompt, ft_CFG, ft_steps],
353
  outputs=[ft_gallery, ft_info],
354
  )
355
+
356
+ # Trigger when an example is selected
357
+ examples_widget.load_input_event.then(
358
+ fn=try_load_saved_outputs,
359
+ inputs=[prompt],
360
+ outputs=[
361
+ pretrained_gallery, pretrained_info,
362
+ smc_grad_gallery, smc_grad_info,
363
+ ft_gallery, ft_info,
364
+ ],
365
+ )
366
+
367
+ # Trigger once on page load for the initial prompt value (so example[0] loads on startup)
368
+ demo.load(
369
+ fn=try_load_saved_outputs,
370
+ inputs=[prompt],
371
+ outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info, ft_gallery, ft_info],
372
+ )
373
 
374
  # Enable Gradio queue to allow parallel execution of multiple handlers. Set concurrency
375
  # to 2 (one per method) — increase if you add more methods.