Shalmoni commited on
Commit
809f869
·
verified ·
1 Parent(s): fdaee5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -39
app.py CHANGED
@@ -1,8 +1,7 @@
1
- import time, base64, io, os, requests, traceback
2
  from typing import Optional
3
- from PIL import Image, UnidentifiedImageError
4
  import gradio as gr
5
- import binascii
6
  import imageio.v2 as imageio
7
  import numpy as np
8
 
@@ -52,7 +51,9 @@ def horde_txt2img(prompt: str,
52
  "n": 1
53
  },
54
  "nsfw": False,
55
- "censor_nsfw": True
 
 
56
  }
57
  if model:
58
  payload["models"] = [model]
@@ -93,7 +94,7 @@ def horde_txt2img(prompt: str,
93
  dbg.append("SUBMIT exception:\n" + traceback.format_exc())
94
  return None, "\n".join(dbg)
95
 
96
- # -------- Poll --------
97
  start = time.time()
98
  while True:
99
  try:
@@ -120,50 +121,65 @@ def horde_txt2img(prompt: str,
120
  dbg.append("DONE but no generations returned.")
121
  return None, "\n".join(dbg)
122
 
123
- b64 = gens[0].get("img")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if not b64:
125
- dbg.append(f"GEN fields: {list(gens[0].keys())}")
126
  return None, "\n".join(dbg)
127
 
128
- # ---- robust decode path ----
 
 
 
 
 
 
 
 
 
 
 
 
129
  # 1) fix base64 padding if needed
130
  pad = (-len(b64)) % 4
131
  if pad:
132
  b64 = b64 + ("=" * pad)
133
-
134
  try:
135
  img_bytes = base64.b64decode(b64, validate=False)
136
  except binascii.Error as e:
137
  dbg.append(f"Base64 decode error: {e}")
138
- return None, "\n".join(dbg)
139
-
140
- # 2) log magic header to debug
141
- head = img_bytes[:12]
142
- dbg.append(f"header bytes: {head.hex(' ')}")
143
-
144
- # 3) try Pillow first
145
- try:
146
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
147
- return img, "\n".join(dbg)
148
- except Exception as e:
149
- dbg.append(f"PIL decode failed: {type(e).__name__}: {e}")
150
-
151
- # 4) imageio fallback (handles webp/png edge cases)
152
- try:
153
- arr = imageio.imread(io.BytesIO(img_bytes))
154
- if isinstance(arr, np.ndarray):
155
- if arr.ndim == 2: # grayscale → RGB
156
- arr = np.stack([arr, arr, arr], axis=-1)
157
- elif arr.shape[-1] == 4: # RGBA → RGB
158
- arr = arr[..., :3]
159
- img = Image.fromarray(arr.astype(np.uint8), mode="RGB")
160
- dbg.append("Decoded via imageio fallback.")
161
- return img, "\n".join(dbg)
162
- except Exception as e:
163
- dbg.append(f"imageio decode failed: {type(e).__name__}: {e}")
164
-
165
- dbg.append("PIL could not decode image bytes.")
166
- return None, "\n".join(dbg)
167
 
168
  if time.time() - start > POLL_TIMEOUT:
169
  dbg.append("TIMEOUT waiting for Horde.")
@@ -175,6 +191,43 @@ def horde_txt2img(prompt: str,
175
  dbg.append("POLL exception:\n" + traceback.format_exc())
176
  return None, "\n".join(dbg)
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def generate_image(prompt, steps, size):
179
  # size like "704x704"
180
  try:
@@ -229,7 +282,7 @@ with gr.Blocks(css=CUSTOM_CSS, title="Image Checkpoints – Stable Horde") as de
229
  debug_box = gr.Code(label="Debug log", interactive=False)
230
 
231
  prompt_boxes, gen_buttons, img_outputs = [], [], []
232
- for i in range(1, 4 + 1):
233
  with gr.Row():
234
  with gr.Column(scale=1, min_width=320):
235
  p = gr.Textbox(
 
1
+ import time, base64, io, os, requests, traceback, binascii
2
  from typing import Optional
3
+ from PIL import Image
4
  import gradio as gr
 
5
  import imageio.v2 as imageio
6
  import numpy as np
7
 
 
51
  "n": 1
52
  },
53
  "nsfw": False,
54
+ "censor_nsfw": True,
55
+ # Ask Horde to return a CDN URL if available (many deployments support this)
56
+ "r2": True
57
  }
58
  if model:
59
  payload["models"] = [model]
 
94
  dbg.append("SUBMIT exception:\n" + traceback.format_exc())
95
  return None, "\n".join(dbg)
96
 
97
+ # -------- Poll --------
98
  start = time.time()
99
  while True:
100
  try:
 
121
  dbg.append("DONE but no generations returned.")
122
  return None, "\n".join(dbg)
123
 
124
+ g0 = gens[0]
125
+ dbg.append(f"GEN keys: {list(g0.keys())}")
126
+ dbg.append(f"img_type: {g0.get('img_type')}")
127
+
128
+ # Prefer URL fields if present
129
+ url = g0.get("r2") or g0.get("url") or g0.get("src") or g0.get("image_url")
130
+ if isinstance(url, str) and (url.startswith("http://") or url.startswith("https://")):
131
+ dbg.append("Found URL in generation → fetching…")
132
+ try:
133
+ r = requests.get(url, timeout=60)
134
+ r.raise_for_status()
135
+ img_bytes = r.content
136
+ except Exception as e:
137
+ dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
138
+ return None, "\n".join(dbg)
139
+ return _decode_bytes_to_image(img_bytes, dbg)
140
+
141
+ # Else fall back to base64 field
142
+ b64 = g0.get("img")
143
  if not b64:
144
+ dbg.append("No 'img' field present.")
145
  return None, "\n".join(dbg)
146
 
147
+ # If 'img' looks like URL text (rare), just fetch it
148
+ if b64.startswith("http://") or b64.startswith("https://"):
149
+ dbg.append("img field is a URL string → fetching…")
150
+ try:
151
+ r = requests.get(b64, timeout=60)
152
+ r.raise_for_status()
153
+ img_bytes = r.content
154
+ except Exception as e:
155
+ dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
156
+ return None, "\n".join(dbg)
157
+ return _decode_bytes_to_image(img_bytes, dbg)
158
+
159
+ # Base64 path
160
  # 1) fix base64 padding if needed
161
  pad = (-len(b64)) % 4
162
  if pad:
163
  b64 = b64 + ("=" * pad)
 
164
  try:
165
  img_bytes = base64.b64decode(b64, validate=False)
166
  except binascii.Error as e:
167
  dbg.append(f"Base64 decode error: {e}")
168
+ # Try to interpret as text (maybe it's a URL encoded in base64)
169
+ try:
170
+ txt = base64.b64decode(b64 + "==", validate=False).decode("utf-8", "ignore").strip()
171
+ if txt.startswith("http"):
172
+ dbg.append("Base64 decoded to text URL → fetching…")
173
+ r = requests.get(txt, timeout=60)
174
+ r.raise_for_status()
175
+ img_bytes = r.content
176
+ else:
177
+ return None, "\n".join(dbg)
178
+ except Exception as e2:
179
+ dbg.append(f"Secondary b64/text parse failed: {type(e2).__name__}: {e2}")
180
+ return None, "\n".join(dbg)
181
+
182
+ return _decode_bytes_to_image(img_bytes, dbg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  if time.time() - start > POLL_TIMEOUT:
185
  dbg.append("TIMEOUT waiting for Horde.")
 
191
  dbg.append("POLL exception:\n" + traceback.format_exc())
192
  return None, "\n".join(dbg)
193
 
194
+ def _decode_bytes_to_image(img_bytes: bytes, dbg: list[str]):
195
+ # Log header
196
+ head = img_bytes[:12]
197
+ dbg.append(f"header bytes: {head.hex(' ')}")
198
+
199
+ # Try Pillow first
200
+ try:
201
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
202
+ return img, "\n".join(dbg)
203
+ except Exception as e:
204
+ dbg.append(f"PIL decode failed: {type(e).__name__}: {e}")
205
+
206
+ # Fallback: imageio
207
+ try:
208
+ arr = imageio.imread(io.BytesIO(img_bytes))
209
+ if isinstance(arr, np.ndarray):
210
+ if arr.ndim == 2: # grayscale → RGB
211
+ arr = np.stack([arr, arr, arr], axis=-1)
212
+ elif arr.shape[-1] == 4: # RGBA → RGB
213
+ arr = arr[..., :3]
214
+ img = Image.fromarray(arr.astype(np.uint8), mode="RGB")
215
+ dbg.append("Decoded via imageio fallback.")
216
+ return img, "\n".join(dbg)
217
+ except Exception as e:
218
+ dbg.append(f"imageio decode failed: {type(e).__name__}: {e}")
219
+
220
+ # Last resort: save bytes for inspection
221
+ try:
222
+ tmp = f"unknown_img_{int(time.time())}.bin"
223
+ with open(tmp, "wb") as f:
224
+ f.write(img_bytes)
225
+ dbg.append(f"Wrote undecodable bytes to {tmp}")
226
+ except Exception as e:
227
+ dbg.append(f"Could not write debug bytes: {type(e).__name__}: {e}")
228
+
229
+ return None, "\n".join(dbg)
230
+
231
  def generate_image(prompt, steps, size):
232
  # size like "704x704"
233
  try:
 
282
  debug_box = gr.Code(label="Debug log", interactive=False)
283
 
284
  prompt_boxes, gen_buttons, img_outputs = [], [], []
285
+ for i in range(1, 5):
286
  with gr.Row():
287
  with gr.Column(scale=1, min_width=320):
288
  p = gr.Textbox(