Shalmoni commited on
Commit
86e0290
·
verified ·
1 Parent(s): c4bcf4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -23
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import time, base64, io, os, requests, traceback, binascii
2
  from typing import Optional
3
  from PIL import Image
4
  import gradio as gr
@@ -13,7 +13,7 @@ HORDE_STATUS = "https://stablehorde.net/api/v2/generate/status/{id}"
13
 
14
  # HF Space secret recommended for priority
15
  HORDE_API_KEY = os.getenv("HORDE_API_KEY", "")
16
- CLIENT_AGENT = "StitchMaster/0.2 (https://huggingface.co/spaces/your-space)"
17
 
18
  DEFAULT_STEPS = 24
19
  DEFAULT_W = 704 # keep defaults under the 715px threshold
@@ -22,6 +22,7 @@ POLL_INTERVAL = 2.5
22
  POLL_TIMEOUT = 240 # bump if queues are long
23
  MODEL = None # or set e.g. "SDXL 1.0"
24
 
 
25
  def _headers():
26
  # Always send an apikey; fallback to anonymous for testing
27
  return {
@@ -46,9 +47,35 @@ def _b64_to_bytes(s: str) -> bytes:
46
  try:
47
  return base64.b64decode(s, validate=False)
48
  except Exception:
49
- # final fallback using urlsafe decoder
50
  return base64.urlsafe_b64decode(s + "=" * ((4 - len(s) % 4) % 4))
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def build_prompt(user_text: str, is_first: bool, lock_longshot: bool = True) -> str:
53
  """Compose continuity-aware prompt text."""
54
  user_text = (user_text or "").strip()
@@ -72,7 +99,7 @@ def build_prompt(user_text: str, is_first: bool, lock_longshot: bool = True) ->
72
  return base
73
 
74
  # =========================
75
- # Horde client with debugging (txt2img OR img2img)
76
  # =========================
77
  def horde_generate(
78
  prompt: str,
@@ -105,7 +132,7 @@ def horde_generate(
105
  payload["params"]["steps"] = min(int(payload["params"]["steps"]), 30)
106
  payload["params"]["width"] = min(int(payload["params"]["width"]), 704)
107
  payload["params"]["height"] = min(int(payload["params"]["height"]), 704)
108
- dbg.append("Fallback applied: steps<=30, width/height<=704. Retrying submit...")
109
  sub = requests.post(HORDE_URL, json=payload, headers=_headers(), timeout=30)
110
  return sub
111
 
@@ -143,7 +170,6 @@ def horde_generate(
143
  dbg.append(f"SUBMIT json={submit_j}")
144
  raise gr.Error("Horde submit succeeded but no job id returned.")
145
  dbg.append(f"JOB id={job_id}")
146
- # Poll & decode
147
  return _poll_and_decode(job_id, dbg)
148
  except Exception:
149
  dbg.append("IMG2IMG path failed, falling back to text-only:\n" + traceback.format_exc())
@@ -211,8 +237,14 @@ def _poll_and_decode(job_id: str, dbg: list[str]):
211
  dbg.append(f"GEN keys: {list(g0.keys())}")
212
  dbg.append(f"img_type: {g0.get('img_type')}")
213
 
214
- # Prefer URL if present
215
- url = g0.get("r2") or g0.get("url") or g0.get("src") or g0.get("image_url")
 
 
 
 
 
 
216
  if isinstance(url, str) and (url.startswith("http://") or url.startswith("https://")):
217
  dbg.append("Found URL in generation → fetching…")
218
  try:
@@ -224,28 +256,46 @@ def _poll_and_decode(job_id: str, dbg: list[str]):
224
  dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
225
  return None, "\n".join(dbg)
226
 
227
- # Base64 branch
228
  b64 = g0.get("img")
229
  if not b64:
230
  dbg.append("No 'img' field present.")
231
  return None, "\n".join(dbg)
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  try:
234
  img_bytes = _b64_to_bytes(b64)
235
  except Exception as e:
236
  dbg.append(f"Base64/urlsafe decode failed: {type(e).__name__}: {e}")
237
- # Try interpret as URL text after decode
238
- try:
239
- txt = _b64_to_bytes(b64).decode("utf-8", "ignore").strip()
240
- if txt.startswith("http"):
241
- r = requests.get(txt, timeout=60)
242
- r.raise_for_status()
243
- img_bytes = r.content
244
- else:
245
- return None, "\n".join(dbg)
246
- except Exception as e2:
247
- dbg.append(f"Secondary b64→text URL parse failed: {type(e2).__name__}: {e2}")
248
- return None, "\n".join(dbg)
 
 
 
 
 
249
 
250
  return _decode_bytes_to_image(img_bytes, dbg)
251
 
@@ -422,5 +472,4 @@ with gr.Blocks(css=CUSTOM_CSS, title="Image Checkpoints – Stable Horde (txt2im
422
  b4.click(fn=generate_next, inputs=[p4, steps, size, lock, img3, change], outputs=[img4, debug_box])
423
 
424
  if __name__ == "__main__":
425
- demo.queue().launch()
426
-
 
1
+ import time, base64, io, os, requests, traceback, binascii, gzip, zlib, bz2, lzma, re
2
  from typing import Optional
3
  from PIL import Image
4
  import gradio as gr
 
13
 
14
  # HF Space secret recommended for priority
15
  HORDE_API_KEY = os.getenv("HORDE_API_KEY", "")
16
+ CLIENT_AGENT = "StitchMaster/0.3 (https://huggingface.co/spaces/your-space)"
17
 
18
  DEFAULT_STEPS = 24
19
  DEFAULT_W = 704 # keep defaults under the 715px threshold
 
22
  POLL_TIMEOUT = 240 # bump if queues are long
23
  MODEL = None # or set e.g. "SDXL 1.0"
24
 
25
+ # ---------------- Helpers ----------------
26
  def _headers():
27
  # Always send an apikey; fallback to anonymous for testing
28
  return {
 
47
  try:
48
  return base64.b64decode(s, validate=False)
49
  except Exception:
 
50
  return base64.urlsafe_b64decode(s + "=" * ((4 - len(s) % 4) % 4))
51
 
52
+ def _maybe_decompress(buf: bytes, dbg: list[str]) -> bytes:
53
+ """
54
+ If bytes are compressed, try gzip, zlib, bz2, lzma (in that order).
55
+ Returns original buf if none succeed.
56
+ """
57
+ head = buf[:4]
58
+ try:
59
+ # gzip
60
+ if len(buf) > 2 and buf[0:2] == b"\x1f\x8b":
61
+ dbg.append("Detected gzip; decompressing…")
62
+ return gzip.decompress(buf)
63
+ # zlib (78 01/9C/DA)
64
+ if len(buf) > 2 and buf[0] == 0x78 and buf[1] in (0x01, 0x5E, 0x9C, 0xDA):
65
+ dbg.append("Detected zlib; decompressing…")
66
+ return zlib.decompress(buf)
67
+ # bz2
68
+ if len(buf) > 3 and buf[0:3] == b"BZh":
69
+ dbg.append("Detected bz2; decompressing…")
70
+ return bz2.decompress(buf)
71
+ # lzma/xz
72
+ if len(buf) > 6 and buf[0:6] == b"\xfd7zXZ\x00":
73
+ dbg.append("Detected lzma/xz; decompressing…")
74
+ return lzma.decompress(buf)
75
+ except Exception as e:
76
+ dbg.append(f"Decompress probe failed: {type(e).__name__}: {e}")
77
+ return buf
78
+
79
  def build_prompt(user_text: str, is_first: bool, lock_longshot: bool = True) -> str:
80
  """Compose continuity-aware prompt text."""
81
  user_text = (user_text or "").strip()
 
99
  return base
100
 
101
  # =========================
102
+ # Horde client (txt2img OR img2img)
103
  # =========================
104
  def horde_generate(
105
  prompt: str,
 
132
  payload["params"]["steps"] = min(int(payload["params"]["steps"]), 30)
133
  payload["params"]["width"] = min(int(payload["params"]["width"]), 704)
134
  payload["params"]["height"] = min(int(payload["params"]["height"]), 704)
135
+ dbg.append("Fallback applied: steps<=30, width/height<=704. Retrying submit")
136
  sub = requests.post(HORDE_URL, json=payload, headers=_headers(), timeout=30)
137
  return sub
138
 
 
170
  dbg.append(f"SUBMIT json={submit_j}")
171
  raise gr.Error("Horde submit succeeded but no job id returned.")
172
  dbg.append(f"JOB id={job_id}")
 
173
  return _poll_and_decode(job_id, dbg)
174
  except Exception:
175
  dbg.append("IMG2IMG path failed, falling back to text-only:\n" + traceback.format_exc())
 
237
  dbg.append(f"GEN keys: {list(g0.keys())}")
238
  dbg.append(f"img_type: {g0.get('img_type')}")
239
 
240
+ # Prefer URL if present (also check gen_metadata for r2 url)
241
+ url = (
242
+ g0.get("r2")
243
+ or g0.get("url")
244
+ or g0.get("src")
245
+ or g0.get("image_url")
246
+ or (isinstance(g0.get("gen_metadata"), dict) and (g0["gen_metadata"].get("r2") or g0["gen_metadata"].get("url")))
247
+ )
248
  if isinstance(url, str) and (url.startswith("http://") or url.startswith("https://")):
249
  dbg.append("Found URL in generation → fetching…")
250
  try:
 
256
  dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
257
  return None, "\n".join(dbg)
258
 
259
+ # Base64 (or hex/encoded) path
260
  b64 = g0.get("img")
261
  if not b64:
262
  dbg.append("No 'img' field present.")
263
  return None, "\n".join(dbg)
264
 
265
+ # If 'img' looks like a URL string, fetch it
266
+ if b64.startswith("http://") or b64.startswith("https://"):
267
+ dbg.append("img field is a URL string → fetching…")
268
+ try:
269
+ r = requests.get(b64, timeout=60)
270
+ r.raise_for_status()
271
+ img_bytes = r.content
272
+ return _decode_bytes_to_image(img_bytes, dbg)
273
+ except Exception as e:
274
+ dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
275
+ return None, "\n".join(dbg)
276
+
277
+ # Try base64-url safe first
278
  try:
279
  img_bytes = _b64_to_bytes(b64)
280
  except Exception as e:
281
  dbg.append(f"Base64/urlsafe decode failed: {type(e).__name__}: {e}")
282
+ img_bytes = None
283
+
284
+ # If that failed or header looks wrong, try hex
285
+ if not img_bytes or len(img_bytes) < 8:
286
+ if re.fullmatch(r"[0-9a-fA-F]+", b64) and len(b64) % 2 == 0:
287
+ dbg.append("img looks like hex → decoding…")
288
+ try:
289
+ img_bytes = bytes.fromhex(b64)
290
+ except Exception as e:
291
+ dbg.append(f"Hex decode failed: {type(e).__name__}: {e}")
292
+ img_bytes = None
293
+
294
+ if not img_bytes:
295
+ return None, "\n".join(dbg)
296
+
297
+ # Some workers compress payloads—try to decompress if needed
298
+ img_bytes = _maybe_decompress(img_bytes, dbg)
299
 
300
  return _decode_bytes_to_image(img_bytes, dbg)
301
 
 
472
  b4.click(fn=generate_next, inputs=[p4, steps, size, lock, img3, change], outputs=[img4, debug_box])
473
 
474
  if __name__ == "__main__":
475
+ demo.queue().launch()