comdoleger commited on
Commit
1b14f71
·
verified ·
1 Parent(s): 221da0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -88
app.py CHANGED
@@ -10,14 +10,20 @@ from PIL import Image
10
  BTEN_API_KEY = os.getenv("API_KEY")
11
  URL = os.getenv("URL")
12
 
13
-
14
  def image_to_base64(image: Image.Image) -> str:
 
15
  with io.BytesIO() as buffer:
16
  image.save(buffer, format="PNG")
17
  return base64.b64encode(buffer.getvalue()).decode("utf-8")
18
 
19
 
20
  def ensure_image(img) -> Image.Image:
 
 
 
 
 
 
21
  if isinstance(img, Image.Image):
22
  return img
23
  elif isinstance(img, str):
@@ -28,7 +34,19 @@ def ensure_image(img) -> Image.Image:
28
  raise ValueError("Cannot convert input to a PIL Image.")
29
 
30
 
31
- def call_baseten_generate(image: Image.Image, prompt: str, steps: int, strength: float, height: int, width: int, lora_name: str, remove_bg: bool) -> Image.Image | None:
 
 
 
 
 
 
 
 
 
 
 
 
32
  image = ensure_image(image)
33
  b64_image = image_to_base64(image)
34
  payload = {
@@ -41,10 +59,14 @@ def call_baseten_generate(image: Image.Image, prompt: str, steps: int, strength:
41
  "lora_name": lora_name,
42
  "bgrm": remove_bg,
43
  }
44
- headers = {"Authorization": f"Api-Key {BTEN_API_KEY or os.getenv('API_KEY')}"}
 
 
 
45
  try:
46
  if not URL:
47
  raise ValueError("The URL environment variable is not set.")
 
48
  response = requests.post(URL, headers=headers, json=payload)
49
  if response.status_code == 200:
50
  data = response.json()
@@ -61,17 +83,20 @@ def call_baseten_generate(image: Image.Image, prompt: str, steps: int, strength:
61
  return None
62
 
63
 
64
- # ================== MODE CONFIG =====================
65
 
66
- Mode = TypedDict("Mode", {
67
- "model": str,
68
- "prompt": str,
69
- "default_strength": float,
70
- "default_height": int,
71
- "default_width": int,
72
- "models": list[str],
73
- "remove_bg": bool,
74
- })
 
 
 
75
 
76
  MODE_DEFAULTS: dict[str, Mode] = {
77
  "Subject Generation": {
@@ -80,7 +105,12 @@ MODE_DEFAULTS: dict[str, Mode] = {
80
  "default_strength": 1.2,
81
  "default_height": 512,
82
  "default_width": 512,
83
- "models": ["zendsd_512_146000", "subject_99000_512", "zen_26000_512"],
 
 
 
 
 
84
  "remove_bg": True,
85
  },
86
  "Background Generation": {
@@ -89,7 +119,20 @@ MODE_DEFAULTS: dict[str, Mode] = {
89
  "default_strength": 1.2,
90
  "default_height": 1024,
91
  "default_width": 1024,
92
- "models": ["bgwlight_15000_1024", "bg_canny_58000_1024", "gen_back_7000_1024"],
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  "remove_bg": True,
94
  },
95
  "Canny": {
@@ -107,7 +150,9 @@ MODE_DEFAULTS: dict[str, Mode] = {
107
  "default_strength": 1.2,
108
  "default_height": 1024,
109
  "default_width": 1024,
110
- "models": ["depth_9800_1024"],
 
 
111
  "remove_bg": True,
112
  },
113
  "Deblurring": {
@@ -116,87 +161,63 @@ MODE_DEFAULTS: dict[str, Mode] = {
116
  "default_strength": 1.2,
117
  "default_height": 1024,
118
  "default_width": 1024,
119
- "models": ["deblurr_1024_10000"],
120
  "remove_bg": False,
121
  },
122
  }
123
 
124
- # ================== PRESET EXAMPLES =====================
125
-
126
- MODE_EXAMPLES = {
127
- "Subject Generation": [
128
- ["assets/subj1.jpg", "Close-up portrait of a fruit bowl", "assets/subj1_out.jpg"],
129
- ["assets/subj2.jpg", "A penguin standing in snow", "assets/subj2_out.jpg"],
130
- ["assets/subj3.jpg", "A cat with glowing eyes", "assets/subj3_out.jpg"],
131
- ["assets/subj4.jpg", "A child playing with bubbles", "assets/subj4_out.jpg"],
132
- ["assets/subj5.jpg", "A stylish young man in neon lights", "assets/subj5_out.jpg"],
133
- ["assets/subj6.jpg", "Old man with a mysterious look", "assets/subj6_out.jpg"],
134
- ],
135
-
136
- "Background Generation": [
137
- ["assets/bg1.jpg", "Modern living room with plants", "assets/bg1_out.jpg"],
138
- ["assets/bg2.jpg", "Fantasy forest background", "assets/bg2_out.jpg"],
139
- ["assets/bg3.jpg", "Futuristic cityscape", "assets/bg3_out.jpg"],
140
- ["assets/bg4.jpg", "Minimalist white studio", "assets/bg4_out.jpg"],
141
- ["assets/bg5.jpg", "Snowy mountain landscape", "assets/bg5_out.jpg"],
142
- ["assets/bg6.jpg", "Golden sunset over the sea", "assets/bg6_out.jpg"],
143
- ],
144
-
145
- "Canny": [
146
- ["assets/canny1.jpg", "A neon cyberpunk city skyline", "assets/canny1_out.jpg"],
147
- ["assets/canny2.jpg", "A robot walking in the fog", "assets/canny2_out.jpg"],
148
- ["assets/canny3.jpg", "A futuristic vehicle parked under a bridge", "assets/canny3_out.jpg"],
149
- ["assets/canny4.jpg", "Sci-fi lab interior with glowing machinery", "assets/canny4_out.jpg"],
150
- ["assets/canny5.jpg", "A portrait of a woman outlined in neon", "assets/canny5_out.jpg"],
151
- ["assets/canny6.jpg", "Post-apocalyptic abandoned street", "assets/canny6_out.jpg"],
152
- ],
153
-
154
- "Depth": [
155
- ["assets/depth1.jpg", "A narrow alleyway with deep perspective", "assets/depth1_out.jpg"],
156
- ["assets/depth2.jpg", "A mountain road vanishing into the distance", "assets/depth2_out.jpg"],
157
- ["assets/depth3.jpg", "A hallway with strong depth of field", "assets/depth3_out.jpg"],
158
- ["assets/depth4.jpg", "A misty forest path stretching far away", "assets/depth4_out.jpg"],
159
- ["assets/depth5.jpg", "A bridge over a deep canyon", "assets/depth5_out.jpg"],
160
- ["assets/depth6.jpg", "An underground tunnel with receding arches", "assets/depth6_out.jpg"],
161
- ],
162
-
163
- "Deblurring": [
164
- ["assets/deblur1.jpg", "", "assets/deblur1_out.jpg"],
165
- ["assets/deblur2.jpg", "", "assets/deblur2_out.jpg"],
166
- ["assets/deblur3.jpg", "", "assets/deblur3_out.jpg"],
167
- ["assets/deblur4.jpg", "", "assets/deblur4_out.jpg"],
168
- ["assets/deblur5.jpg", "", "assets/deblur5_out.jpg"],
169
- ["assets/deblur6.jpg", "", "assets/deblur6_out.jpg"],
170
- ],
171
- }
172
-
173
-
174
- # ================== UI =====================
175
 
176
  header = """
177
  <h1>🌍 ZenCtrl / FLUX</h1>
178
  <div align="center" style="line-height: 1;">
179
- <a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg"></a>
180
- <a href="https://huggingface.co/spaces/fotographerai/ZenCtrl" target="_blank"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg"></a>
181
- <a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord"></a>
 
 
182
  </div>
183
  """
184
 
 
 
 
185
  with gr.Blocks(title="🌍 ZenCtrl") as demo:
186
  gr.HTML(header)
187
- gr.Markdown("# ZenCtrl Demo")
188
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  with gr.Tabs():
190
  for mode in MODE_DEFAULTS:
191
  with gr.Tab(mode):
192
  defaults = MODE_DEFAULTS[mode]
193
  gr.Markdown(f"### {mode} Mode")
 
194
 
195
  with gr.Row():
196
- with gr.Column(scale=2):
197
- input_image = gr.Image(label="Input Image", type="pil")
 
 
 
 
 
 
198
  generate_button = gr.Button("Generate")
199
- with gr.Blocks():
200
  model_dropdown = gr.Dropdown(
201
  label="Model",
202
  choices=defaults["models"],
@@ -208,12 +229,20 @@ with gr.Blocks(title="🌍 ZenCtrl") as demo:
208
  )
209
 
210
  with gr.Column(scale=2):
211
- output_image = gr.Image(label="Generated Image", type="pil")
 
 
 
 
 
 
212
 
 
213
  prompt_box = gr.Textbox(
214
  label="Prompt", value=defaults["prompt"], lines=2
215
  )
216
 
 
217
  with gr.Accordion("Generation Parameters", open=False):
218
  with gr.Row():
219
  step_slider = gr.Slider(
@@ -242,8 +271,26 @@ with gr.Blocks(title="🌍 ZenCtrl") as demo:
242
  label="Width",
243
  )
244
 
245
- def on_generate_click(model_name, prompt, steps, strength, height, width, remove_bg, image):
246
- return call_baseten_generate(image, prompt, steps, strength, height, width, model_name, remove_bg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  generate_button.click(
249
  fn=on_generate_click,
@@ -258,16 +305,9 @@ with gr.Blocks(title="🌍 ZenCtrl") as demo:
258
  input_image,
259
  ],
260
  outputs=[output_image],
 
261
  )
262
 
263
- # ---------------- Templates --------------------
264
- gr.Dataset(
265
- label="Presets (Input / Prompt / Output)",
266
- headers=["Input", "Prompt", "Output"],
267
- components=[input_image, prompt_box, output_image],
268
- samples=MODE_EXAMPLES.get(mode, []),
269
- samples_per_page=6,
270
- )
271
 
272
  if __name__ == "__main__":
273
- demo.launch()
 
10
  BTEN_API_KEY = os.getenv("API_KEY")
11
  URL = os.getenv("URL")
12
 
 
13
  def image_to_base64(image: Image.Image) -> str:
14
+ """Convert a PIL image to a base64-encoded PNG string."""
15
  with io.BytesIO() as buffer:
16
  image.save(buffer, format="PNG")
17
  return base64.b64encode(buffer.getvalue()).decode("utf-8")
18
 
19
 
20
  def ensure_image(img) -> Image.Image:
21
+ """
22
+ Ensure the input is a PIL Image.
23
+ If it's already a PIL Image, return it.
24
+ If it's a string (file path), open it.
25
+ If it's a dict with a "name" key, open the file at that path.
26
+ """
27
  if isinstance(img, Image.Image):
28
  return img
29
  elif isinstance(img, str):
 
34
  raise ValueError("Cannot convert input to a PIL Image.")
35
 
36
 
37
+ def call_baseten_generate(
38
+ image: Image.Image,
39
+ prompt: str,
40
+ steps: int,
41
+ strength: float,
42
+ height: int,
43
+ width: int,
44
+ lora_name: str,
45
+ remove_bg: bool,
46
+ ) -> Image.Image | None:
47
+ """
48
+ Call the Baseten /predict endpoint with provided parameters and return the generated image.
49
+ """
50
  image = ensure_image(image)
51
  b64_image = image_to_base64(image)
52
  payload = {
 
59
  "lora_name": lora_name,
60
  "bgrm": remove_bg,
61
  }
62
+ if not BTEN_API_KEY:
63
+ headers = {"Authorization": f"Api-Key {os.getenv('API_KEY')}"}
64
+ else:
65
+ headers = {"Authorization": f"Api-Key {BTEN_API_KEY}"}
66
  try:
67
  if not URL:
68
  raise ValueError("The URL environment variable is not set.")
69
+
70
  response = requests.post(URL, headers=headers, json=payload)
71
  if response.status_code == 200:
72
  data = response.json()
 
83
  return None
84
 
85
 
86
+ # Mode defaults for each tab.
87
 
88
+ Mode = TypedDict(
89
+ "Mode",
90
+ {
91
+ "model": str,
92
+ "prompt": str,
93
+ "default_strength": float,
94
+ "default_height": int,
95
+ "default_width": int,
96
+ "models": list[str],
97
+ "remove_bg": bool,
98
+ },
99
+ )
100
 
101
  MODE_DEFAULTS: dict[str, Mode] = {
102
  "Subject Generation": {
 
105
  "default_strength": 1.2,
106
  "default_height": 512,
107
  "default_width": 512,
108
+ "models": [
109
+ "zendsd_512_146000",
110
+ "subject_99000_512",
111
+ # "zen_pers_11000",
112
+ "zen_26000_512",
113
+ ],
114
  "remove_bg": True,
115
  },
116
  "Background Generation": {
 
119
  "default_strength": 1.2,
120
  "default_height": 1024,
121
  "default_width": 1024,
122
+ "models": [
123
+ "bgwlight_15000_1024",
124
+ # "rmgb_12000_1024",
125
+ "bg_canny_58000_1024",
126
+ # "gen_back_3000_1024",
127
+ "gen_back_7000_1024",
128
+ # "gen_bckgnd_18000_512",
129
+ # "gen_bckgnd_18000_512",
130
+ # "loose_25000_512",
131
+ # "looser_23000_1024",
132
+ # "looser_bg_gen_21000_1280",
133
+ # "old_looser_46000_1024",
134
+ # "relight_bg_gen_31000_1024",
135
+ ],
136
  "remove_bg": True,
137
  },
138
  "Canny": {
 
150
  "default_strength": 1.2,
151
  "default_height": 1024,
152
  "default_width": 1024,
153
+ "models": [
154
+ "depth_9800_1024",
155
+ ],
156
  "remove_bg": True,
157
  },
158
  "Deblurring": {
 
161
  "default_strength": 1.2,
162
  "default_height": 1024,
163
  "default_width": 1024,
164
+ "models": ["deblurr_1024_10000"], # "slight_deblurr_18000",
165
  "remove_bg": False,
166
  },
167
  }
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  header = """
171
  <h1>🌍 ZenCtrl / FLUX</h1>
172
  <div align="center" style="line-height: 1;">
173
+ <a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank" style="margin: 2px;" name="github_repo_link"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg" alt="GitHub Repo" style="display: inline-block; vertical-align: middle;"></a>
174
+ <a href="https://huggingface.co/spaces/fotographerai/ZenCtrl" target="_blank" name="huggingface_space_link"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg" alt="HuggingFace Space" style="display: inline-block; vertical-align: middle;"></a>
175
+ <a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank" style="margin: 2px;" name="discord_link"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord" alt="Discord" style="display: inline-block; vertical-align: middle;"></a>
176
+ <a href="https://fotographer.ai/" target="_blank" style="margin: 2px;" name="lp_link"><img src="https://img.shields.io/badge/Website-Landing_Page-blue" alt="LP" style="display: inline-block; vertical-align: middle;"></a>
177
+ <a href="https://x.com/FotographerAI" target="_blank" style="margin: 2px;" name="twitter_link"><img src="https://img.shields.io/twitter/follow/FotographerAI?style=social" alt="X" style="display: inline-block; vertical-align: middle;"></a>
178
  </div>
179
  """
180
 
181
+ defaults = MODE_DEFAULTS["Subject Generation"]
182
+
183
+
184
  with gr.Blocks(title="🌍 ZenCtrl") as demo:
185
  gr.HTML(header)
186
+ gr.Markdown(
187
+ """
188
+ # ZenCtrl Demo
189
+ [WIP] One Agent to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image—without fine-tuning.
190
+ We are first releasing some of the task specific weights and will release the codes soon.
191
+ The goal is to unify all of the visual content generation tasks with a single LLM...
192
+
193
+ **Modes:**
194
+ - **Subject Generation:** Focuses on generating detailed subject portraits.
195
+ - **Background Generation:** Creates dynamic, vibrant backgrounds:
196
+ You can generate part of the image from sketch while keeping part of it as it is.
197
+ - **Canny:** Emphasizes strong edge detection.
198
+ - **Depth:** Produces images with realistic depth and perspective.
199
+
200
+ For more details, shoot us a message on discord.
201
+ """
202
+ )
203
  with gr.Tabs():
204
  for mode in MODE_DEFAULTS:
205
  with gr.Tab(mode):
206
  defaults = MODE_DEFAULTS[mode]
207
  gr.Markdown(f"### {mode} Mode")
208
+ gr.Markdown(f"**Default Model:** {defaults['model']}")
209
 
210
  with gr.Row():
211
+ with gr.Column(scale=2, min_width=370):
212
+ input_image = gr.Image(
213
+ label="Upload Image",
214
+ type="pil",
215
+ scale=3,
216
+ height=370,
217
+ min_width=100,
218
+ )
219
  generate_button = gr.Button("Generate")
220
+ with gr.Blocks(title="Options"):
221
  model_dropdown = gr.Dropdown(
222
  label="Model",
223
  choices=defaults["models"],
 
229
  )
230
 
231
  with gr.Column(scale=2):
232
+ output_image = gr.Image(
233
+ label="Generated Image",
234
+ type="pil",
235
+ height=573,
236
+ scale=4,
237
+ min_width=100,
238
+ )
239
 
240
+ gr.Markdown("#### Prompt")
241
  prompt_box = gr.Textbox(
242
  label="Prompt", value=defaults["prompt"], lines=2
243
  )
244
 
245
+ # Wrap generation parameters in an Accordion for collapsible view.
246
  with gr.Accordion("Generation Parameters", open=False):
247
  with gr.Row():
248
  step_slider = gr.Slider(
 
271
  label="Width",
272
  )
273
 
274
+ def on_generate_click(
275
+ model_name,
276
+ prompt,
277
+ steps,
278
+ strength,
279
+ height,
280
+ width,
281
+ remove_bg,
282
+ image,
283
+ ):
284
+ return call_baseten_generate(
285
+ image,
286
+ prompt,
287
+ steps,
288
+ strength,
289
+ height,
290
+ width,
291
+ model_name,
292
+ remove_bg,
293
+ )
294
 
295
  generate_button.click(
296
  fn=on_generate_click,
 
305
  input_image,
306
  ],
307
  outputs=[output_image],
308
+ concurrency_limit=None
309
  )
310
 
 
 
 
 
 
 
 
 
311
 
312
  if __name__ == "__main__":
313
+ demo.launch()