LevyJonas commited on
Commit
fdedbc8
·
verified ·
1 Parent(s): f0f8a0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -19
app.py CHANGED
@@ -1,15 +1,37 @@
 
1
  import gradio as gr
2
- from pipeline import run_search_and_generate, load_from_hf # reuse pipeline loader
 
 
 
3
 
4
  # --- Quick Starter file paths (must exist in your HF dataset repo) ---
5
  QS_1_PATH = "images/LakeWater/LakeWater_000550.jpg"
6
  QS_2_PATH = "images/DenseForest/DenseForest_000000.jpg"
7
  QS_3_PATH = "images/ResidentialDense/ResidentialDense_001050.jpg"
8
 
9
- def run_app(img1, img2, img3, img4, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  try:
 
 
 
 
11
  retrieved, gen_i2i, gen_t2i, info = run_search_and_generate(
12
- user_imgs=[img1, img2, img3, img4],
13
  user_prompt=prompt,
14
  k_retrieve=k_retrieve,
15
  n_i2i=n_i2i,
@@ -19,37 +41,79 @@ def run_app(img1, img2, img3, img4, prompt, k_retrieve, n_i2i, n_t2i, strength_i
19
  gen_size=gen_size,
20
  seed=int(seed),
21
  )
 
22
  retr_gallery = [(r["img"], f"{r['label']} | cos={r['sim']:.3f}") for r in retrieved]
23
  i2i_gallery = [(im, f"img2img #{i+1}") for i, im in enumerate(gen_i2i)]
24
  t2i_gallery = [(im, f"txt2img #{i+1}") for i, im in enumerate(gen_t2i)]
25
  summary = "\n".join([f"{k}: {v}" for k, v in info.items()])
 
26
  return retr_gallery, i2i_gallery, t2i_gallery, summary
 
27
  except Exception as e:
28
  return [], [], [], f"Error: {e}"
29
 
 
30
  def quickstarter_fill(rel_path):
31
- # Fill only Image 1 by default (user can add more)
32
- return load_from_hf(rel_path), None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  with gr.Blocks(title="Satellite Patch: Retrieve + Generate") as demo:
35
  gr.Markdown(
36
  "# Satellite Patch: Retrieve + Generate\n"
37
- "Upload **up to 4** satellite patches + write a prompt → retrieve similar images and generate new variants.\n\n"
38
- "**Goal:** help predict / hallucinate a missing satellite patch from context."
39
  )
40
 
 
41
  gr.Markdown("### Quick Starters (1-click examples)")
42
  with gr.Row():
43
- qs1 = gr.Button("Quick Starter: LakeWater")
44
- qs2 = gr.Button("Quick Starter: DenseForest")
45
- qs3 = gr.Button("Quick Starter: ResidentialDense")
46
 
47
  with gr.Row():
 
48
  with gr.Column(scale=1):
49
- img1 = gr.Image(type="pil", label="Input Image 1 (required)")
50
- img2 = gr.Image(type="pil", label="Input Image 2 (optional)")
51
- img3 = gr.Image(type="pil", label="Input Image 3 (optional)")
52
- img4 = gr.Image(type="pil", label="Input Image 4 (optional)")
 
 
53
 
54
  prompt = gr.Textbox(
55
  label="Prompt (required for generation)",
@@ -68,21 +132,35 @@ with gr.Blocks(title="Satellite Patch: Retrieve + Generate") as demo:
68
 
69
  btn = gr.Button("Run")
70
 
 
71
  with gr.Column(scale=2):
72
  out_retr = gr.Gallery(label="Retrieved from Dataset", columns=5, height=260)
73
  out_i2i = gr.Gallery(label="Generated (img2img)", columns=5, height=260)
74
  out_t2i = gr.Gallery(label="Generated (txt2img)", columns=5, height=260)
75
  out_txt = gr.Textbox(label="Summary", lines=8)
76
 
 
77
  btn.click(
78
  run_app,
79
- inputs=[img1, img2, img3, img4, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed],
80
  outputs=[out_retr, out_i2i, out_t2i, out_txt],
81
  )
82
 
83
- # Quick starter: fill Image 1, then user can click Run
84
- qs1.click(quickstarter_fill, inputs=[gr.State(QS_1_PATH)], outputs=[img1, img2, img3, img4])
85
- qs2.click(quickstarter_fill, inputs=[gr.State(QS_2_PATH)], outputs=[img1, img2, img3, img4])
86
- qs3.click(quickstarter_fill, inputs=[gr.State(QS_3_PATH)], outputs=[img1, img2, img3, img4])
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  demo.launch()
 
1
+ # app.py (clean, no duplicates)
2
  import gradio as gr
3
+ from PIL import Image
4
+
5
+ # Pipeline (must support: run_search_and_generate(user_imgs=[...], user_prompt=...))
6
+ from pipeline import run_search_and_generate, load_from_hf
7
 
8
  # --- Quick Starter file paths (must exist in your HF dataset repo) ---
9
  QS_1_PATH = "images/LakeWater/LakeWater_000550.jpg"
10
  QS_2_PATH = "images/DenseForest/DenseForest_000000.jpg"
11
  QS_3_PATH = "images/ResidentialDense/ResidentialDense_001050.jpg"
12
 
13
+
14
+ def _files_to_pil_list(files, n_use: int):
15
+ """Convert uploaded files to a list of PIL images (use first n_use, capped to 1..4)."""
16
+ if not files:
17
+ return []
18
+ n = max(1, min(4, int(n_use), len(files)))
19
+ imgs = []
20
+ for f in files[:n]:
21
+ path = f.name if hasattr(f, "name") else str(f)
22
+ imgs.append(Image.open(path).convert("RGB"))
23
+ return imgs
24
+
25
+
26
+ def run_app(files, n_user_imgs, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed):
27
+ """Main app handler."""
28
  try:
29
+ user_imgs = _files_to_pil_list(files, n_user_imgs)
30
+ if len(user_imgs) == 0:
31
+ return [], [], [], "Error: Please upload at least 1 image."
32
+
33
  retrieved, gen_i2i, gen_t2i, info = run_search_and_generate(
34
+ user_imgs=user_imgs,
35
  user_prompt=prompt,
36
  k_retrieve=k_retrieve,
37
  n_i2i=n_i2i,
 
41
  gen_size=gen_size,
42
  seed=int(seed),
43
  )
44
+
45
  retr_gallery = [(r["img"], f"{r['label']} | cos={r['sim']:.3f}") for r in retrieved]
46
  i2i_gallery = [(im, f"img2img #{i+1}") for i, im in enumerate(gen_i2i)]
47
  t2i_gallery = [(im, f"txt2img #{i+1}") for i, im in enumerate(gen_t2i)]
48
  summary = "\n".join([f"{k}: {v}" for k, v in info.items()])
49
+
50
  return retr_gallery, i2i_gallery, t2i_gallery, summary
51
+
52
  except Exception as e:
53
  return [], [], [], f"Error: {e}"
54
 
55
+
56
  def quickstarter_fill(rel_path):
57
+ """
58
+ 1-click quick starter:
59
+ loads a dataset image from HF and returns it as the ONLY uploaded file.
60
+ The user can then adjust the slider (#images to use) and/or upload more.
61
+ """
62
+ img = load_from_hf(rel_path)
63
+ return [img] # This will populate the gr.Files-like input (we will use a hidden state)
64
+
65
+
66
+ # NOTE:
67
+ # Gradio's gr.Files expects actual uploaded files, not PIL objects.
68
+ # So for Quick Starters, we will route through a separate hidden Image input
69
+ # and run the pipeline directly (no file upload required).
70
+
71
+ def run_quickstarter(rel_path, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed):
72
+ try:
73
+ img = load_from_hf(rel_path)
74
+ retrieved, gen_i2i, gen_t2i, info = run_search_and_generate(
75
+ user_imgs=[img],
76
+ user_prompt=prompt,
77
+ k_retrieve=k_retrieve,
78
+ n_i2i=n_i2i,
79
+ n_t2i=n_t2i,
80
+ strength_i2i=strength_i2i,
81
+ steps=steps,
82
+ gen_size=gen_size,
83
+ seed=int(seed),
84
+ )
85
+ retr_gallery = [(r["img"], f"{r['label']} | cos={r['sim']:.3f}") for r in retrieved]
86
+ i2i_gallery = [(im, f"img2img #{i+1}") for i, im in enumerate(gen_i2i)]
87
+ t2i_gallery = [(im, f"txt2img #{i+1}") for i, im in enumerate(gen_t2i)]
88
+ summary = "\n".join([f"{k}: {v}" for k, v in info.items()])
89
+ return retr_gallery, i2i_gallery, t2i_gallery, summary
90
+ except Exception as e:
91
+ return [], [], [], f"Error: {e}"
92
+
93
 
94
  with gr.Blocks(title="Satellite Patch: Retrieve + Generate") as demo:
95
  gr.Markdown(
96
  "# Satellite Patch: Retrieve + Generate\n"
97
+ "Upload **up to 4 satellite patches** + write a prompt → retrieve similar images and generate new variants.\n\n"
98
+ "**Quick Start:** click one of the buttons below to instantly see an example (no upload required)."
99
  )
100
 
101
+ # --- Quick Starters (1-click) ---
102
  gr.Markdown("### Quick Starters (1-click examples)")
103
  with gr.Row():
104
+ qs1 = gr.Button("LakeWater")
105
+ qs2 = gr.Button("DenseForest")
106
+ qs3 = gr.Button("ResidentialDense")
107
 
108
  with gr.Row():
109
+ # --- Inputs ---
110
  with gr.Column(scale=1):
111
+ files = gr.Files(
112
+ label="Upload up to 4 satellite patch images",
113
+ file_types=["image"],
114
+ file_count="multiple"
115
+ )
116
+ n_user_imgs = gr.Slider(1, 4, value=1, step=1, label="How many uploaded images to use (1–4)")
117
 
118
  prompt = gr.Textbox(
119
  label="Prompt (required for generation)",
 
132
 
133
  btn = gr.Button("Run")
134
 
135
+ # --- Outputs ---
136
  with gr.Column(scale=2):
137
  out_retr = gr.Gallery(label="Retrieved from Dataset", columns=5, height=260)
138
  out_i2i = gr.Gallery(label="Generated (img2img)", columns=5, height=260)
139
  out_t2i = gr.Gallery(label="Generated (txt2img)", columns=5, height=260)
140
  out_txt = gr.Textbox(label="Summary", lines=8)
141
 
142
+ # Normal run (user-uploaded files)
143
  btn.click(
144
  run_app,
145
+ inputs=[files, n_user_imgs, prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed],
146
  outputs=[out_retr, out_i2i, out_t2i, out_txt],
147
  )
148
 
149
+ # Quick Starter runs (1-click)
150
+ qs1.click(
151
+ run_quickstarter,
152
+ inputs=[gr.State(QS_1_PATH), prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed],
153
+ outputs=[out_retr, out_i2i, out_t2i, out_txt],
154
+ )
155
+ qs2.click(
156
+ run_quickstarter,
157
+ inputs=[gr.State(QS_2_PATH), prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed],
158
+ outputs=[out_retr, out_i2i, out_t2i, out_txt],
159
+ )
160
+ qs3.click(
161
+ run_quickstarter,
162
+ inputs=[gr.State(QS_3_PATH), prompt, k_retrieve, n_i2i, n_t2i, strength_i2i, steps, gen_size, seed],
163
+ outputs=[out_retr, out_i2i, out_t2i, out_txt],
164
+ )
165
 
166
  demo.launch()