Shalmoni committed on
Commit
5a6bbaa
·
verified ·
1 Parent(s): 3711439

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -449
app.py CHANGED
@@ -1,475 +1,254 @@
1
- import time, base64, io, os, requests, traceback, binascii, gzip, zlib, bz2, lzma, re
2
- from typing import Optional
 
 
3
  from PIL import Image
4
  import gradio as gr
5
- import imageio.v2 as imageio
6
- import numpy as np
7
-
8
- # =========================
9
- # Stable Horde config
10
- # =========================
11
- HORDE_URL = "https://stablehorde.net/api/v2/generate/async"
12
- HORDE_STATUS = "https://stablehorde.net/api/v2/generate/status/{id}"
13
-
14
- # HF Space secret recommended for priority
15
- HORDE_API_KEY = os.getenv("HORDE_API_KEY", "")
16
- CLIENT_AGENT = "StitchMaster/0.3 (https://huggingface.co/spaces/your-space)"
17
-
18
- DEFAULT_STEPS = 24
19
- DEFAULT_W = 704 # keep defaults under the 715px threshold
20
- DEFAULT_H = 704
21
- POLL_INTERVAL = 2.5
22
- POLL_TIMEOUT = 240 # bump if queues are long
23
- MODEL = None # or set e.g. "SDXL 1.0"
24
-
25
- # ---------------- Helpers ----------------
26
- def _headers():
27
- # Always send an apikey; fallback to anonymous for testing
28
- return {
29
- "Client-Agent": CLIENT_AGENT,
30
- "apikey": HORDE_API_KEY if HORDE_API_KEY else "0000000000"
31
- }
32
-
33
- def pil_to_b64(img_pil: Image.Image) -> str:
34
  buf = io.BytesIO()
35
- img_pil.save(buf, format="PNG")
36
- return base64.b64encode(buf.getvalue()).decode("utf-8")
37
 
38
- def _b64_to_bytes(s: str) -> bytes:
39
- """
40
- Robustly decode base64 / base64url (handles '-' '_' and missing padding).
41
- """
42
- s = (s or "").strip()
43
- s = s.replace("-", "+").replace("_", "/")
44
- pad = (-len(s)) % 4
45
- if pad:
46
- s += "=" * pad
47
- try:
48
- return base64.b64decode(s, validate=False)
49
- except Exception:
50
- return base64.urlsafe_b64decode(s + "=" * ((4 - len(s) % 4) % 4))
51
 
52
- def _maybe_decompress(buf: bytes, dbg: list[str]) -> bytes:
53
- """
54
- If bytes are compressed, try gzip, zlib, bz2, lzma (in that order).
55
- Returns original buf if none succeed.
56
- """
57
- head = buf[:4]
58
- try:
59
- # gzip
60
- if len(buf) > 2 and buf[0:2] == b"\x1f\x8b":
61
- dbg.append("Detected gzip; decompressing…")
62
- return gzip.decompress(buf)
63
- # zlib (78 01/9C/DA)
64
- if len(buf) > 2 and buf[0] == 0x78 and buf[1] in (0x01, 0x5E, 0x9C, 0xDA):
65
- dbg.append("Detected zlib; decompressing…")
66
- return zlib.decompress(buf)
67
- # bz2
68
- if len(buf) > 3 and buf[0:3] == b"BZh":
69
- dbg.append("Detected bz2; decompressing…")
70
- return bz2.decompress(buf)
71
- # lzma/xz
72
- if len(buf) > 6 and buf[0:6] == b"\xfd7zXZ\x00":
73
- dbg.append("Detected lzma/xz; decompressing…")
74
- return lzma.decompress(buf)
75
- except Exception as e:
76
- dbg.append(f"Decompress probe failed: {type(e).__name__}: {e}")
77
- return buf
78
-
79
- def build_prompt(user_text: str, is_first: bool, lock_longshot: bool = True) -> str:
80
- """Compose continuity-aware prompt text."""
81
- user_text = (user_text or "").strip()
82
- longshot_plus = (
83
- "single continuous long shot; no cuts or new shot; no angle switch; "
84
- "smooth camera motion (pan/tilt/zoom only); unbroken continuity"
85
- )
86
- if is_first:
87
- base = f"Opening frame. {user_text}" if user_text else "Opening frame."
88
- if lock_longshot:
89
- base += ". " + longshot_plus
90
- return base
91
- # Subsequent frames
92
- base = (
93
- "Treat the previous frame as a still from the same continuous long shot. "
94
- "Maintain style, subject identity, lighting, and camera continuity. "
95
- f"Generate the next moment: {user_text if user_text else 'advance the action naturally.'}"
96
- )
97
- if lock_longshot:
98
- base += ". " + longshot_plus
99
- return base
100
-
101
- # =========================
102
- # Horde client (txt2img OR img2img)
103
- # =========================
104
- def horde_generate(
105
- prompt: str,
106
- steps: int = DEFAULT_STEPS,
107
- width: int = DEFAULT_W,
108
- height: int = DEFAULT_H,
109
- model: Optional[str] = MODEL,
110
- init_image: Optional[Image.Image] = None,
111
- denoise: float = 0.45, # 0.0 = identical, 1.0 = big change
112
- ):
113
  """
114
- If init_image is provided, tries img2img first (source_image + source_processing='img2img').
115
- Falls back to txt2img if Horde rejects it.
116
  """
117
  dbg = []
118
- if not (prompt and prompt.strip()):
119
- raise gr.Error("Please enter a prompt.")
120
-
121
- def _submit(payload):
122
- sub = requests.post(HORDE_URL, json=payload, headers=_headers(), timeout=30)
123
- # Auto-fallback if KudosUpfront required
124
- if sub.status_code == 403:
125
- try:
126
- body = sub.json()
127
- except Exception:
128
- body = {"message": sub.text}
129
- msg = (body.get("message") or "").lower()
130
- rc = body.get("rc") or ""
131
- if "kudos" in msg or rc == "KudosUpfront":
132
- payload["params"]["steps"] = min(int(payload["params"]["steps"]), 30)
133
- payload["params"]["width"] = min(int(payload["params"]["width"]), 704)
134
- payload["params"]["height"] = min(int(payload["params"]["height"]), 704)
135
- dbg.append("Fallback applied: steps<=30, width/height<=704. Retrying submit…")
136
- sub = requests.post(HORDE_URL, json=payload, headers=_headers(), timeout=30)
137
- return sub
138
-
139
- # ---- try img2img if init_image provided ----
140
- tried_img2img = False
141
- if init_image is not None:
142
- tried_img2img = True
143
- payload = {
144
- "prompt": prompt.strip(),
145
- "params": {
146
- "steps": int(steps),
147
- "width": int(width),
148
- "height": int(height),
149
- "n": 1,
150
- "denoise": float(denoise)
151
- },
152
- "nsfw": False,
153
- "censor_nsfw": True,
154
- "source_processing": "img2img",
155
- "source_image": pil_to_b64(init_image),
156
- "r2": True
157
- }
158
- if model:
159
- payload["models"] = [model]
160
 
161
- try:
162
- submit = _submit(payload)
163
- dbg.append(f"SUBMIT (img2img) status={submit.status_code}")
164
- if submit.status_code >= 300:
165
- dbg.append(f"SUBMIT body={submit.text[:500]}")
166
- submit.raise_for_status()
167
- submit_j = submit.json()
168
- job_id = submit_j.get("id")
169
- if not job_id:
170
- dbg.append(f"SUBMIT json={submit_j}")
171
- raise gr.Error("Horde submit succeeded but no job id returned.")
172
- dbg.append(f"JOB id={job_id}")
173
- return _poll_and_decode(job_id, dbg)
174
- except Exception:
175
- dbg.append("IMG2IMG path failed, falling back to text-only:\n" + traceback.format_exc())
176
-
177
- # ---- txt2img path ----
178
- payload = {
179
- "prompt": prompt.strip(),
180
- "params": {
181
- "steps": int(steps),
182
- "width": int(width),
183
- "height": int(height),
184
- "n": 1
185
- },
186
- "nsfw": False,
187
- "censor_nsfw": True,
188
- "r2": True
189
- }
190
- if model:
191
- payload["models"] = [model]
192
 
193
- try:
194
- submit = _submit(payload)
195
- dbg.append(f"SUBMIT (txt2img{', after img2img fail' if tried_img2img else ''}) status={submit.status_code}")
196
- if submit.status_code >= 300:
197
- dbg.append(f"SUBMIT body={submit.text[:500]}")
198
- submit.raise_for_status()
199
- submit_j = submit.json()
200
- job_id = submit_j.get("id")
201
- if not job_id:
202
- dbg.append(f"SUBMIT json={submit_j}")
203
- raise gr.Error("Horde submit succeeded but no job id returned.")
204
- dbg.append(f"JOB id={job_id}")
205
- return _poll_and_decode(job_id, dbg)
206
- except Exception:
207
- dbg.append("SUBMIT exception:\n" + traceback.format_exc())
208
- return None, "\n".join(dbg)
209
 
210
- def _poll_and_decode(job_id: str, dbg: list[str]):
211
- start = time.time()
212
- while True:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  try:
214
- status_r = requests.get(HORDE_STATUS.format(id=job_id), headers=_headers(), timeout=30)
215
- if status_r.status_code >= 300:
216
- dbg.append(f"POLL status={status_r.status_code}")
217
- dbg.append(f"POLL body={status_r.text[:500]}")
218
- status_r.raise_for_status()
219
- s = status_r.json()
220
-
221
- k = s.get("kudos", "?")
222
- queue = s.get("queue_position", "?")
223
- eta = s.get("wait_time", "?")
224
- dbg.append(f"queue={queue} eta≈{eta}s kudos={k}")
225
-
226
- if s.get("faulted"):
227
- dbg.append(f"FAULT: {s}")
228
- return None, "\n".join(dbg)
229
-
230
- if s.get("done"):
231
- gens = s.get("generations") or []
232
- if not gens:
233
- dbg.append("DONE but no generations returned.")
234
- return None, "\n".join(dbg)
235
-
236
- g0 = gens[0]
237
- dbg.append(f"GEN keys: {list(g0.keys())}")
238
- dbg.append(f"img_type: {g0.get('img_type')}")
239
-
240
- # Prefer URL if present (also check gen_metadata for r2 url)
241
- url = (
242
- g0.get("r2")
243
- or g0.get("url")
244
- or g0.get("src")
245
- or g0.get("image_url")
246
- or (isinstance(g0.get("gen_metadata"), dict) and (g0["gen_metadata"].get("r2") or g0["gen_metadata"].get("url")))
247
- )
248
- if isinstance(url, str) and (url.startswith("http://") or url.startswith("https://")):
249
- dbg.append("Found URL in generation → fetching…")
250
- try:
251
- r = requests.get(url, timeout=60)
252
- r.raise_for_status()
253
- img_bytes = r.content
254
- return _decode_bytes_to_image(img_bytes, dbg)
255
- except Exception as e:
256
- dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
257
- return None, "\n".join(dbg)
258
-
259
- # Base64 (or hex/encoded) path
260
- b64 = g0.get("img")
261
- if not b64:
262
- dbg.append("No 'img' field present.")
263
- return None, "\n".join(dbg)
264
-
265
- # If 'img' looks like a URL string, fetch it
266
- if b64.startswith("http://") or b64.startswith("https://"):
267
- dbg.append("img field is a URL string → fetching…")
268
- try:
269
- r = requests.get(b64, timeout=60)
270
- r.raise_for_status()
271
- img_bytes = r.content
272
- return _decode_bytes_to_image(img_bytes, dbg)
273
- except Exception as e:
274
- dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
275
- return None, "\n".join(dbg)
276
-
277
- # Try base64-url safe first
278
- try:
279
- img_bytes = _b64_to_bytes(b64)
280
- except Exception as e:
281
- dbg.append(f"Base64/urlsafe decode failed: {type(e).__name__}: {e}")
282
- img_bytes = None
283
-
284
- # If that failed or header looks wrong, try hex
285
- if not img_bytes or len(img_bytes) < 8:
286
- if re.fullmatch(r"[0-9a-fA-F]+", b64) and len(b64) % 2 == 0:
287
- dbg.append("img looks like hex → decoding…")
288
- try:
289
- img_bytes = bytes.fromhex(b64)
290
- except Exception as e:
291
- dbg.append(f"Hex decode failed: {type(e).__name__}: {e}")
292
- img_bytes = None
293
-
294
- if not img_bytes:
295
- return None, "\n".join(dbg)
296
-
297
- # Some workers compress payloads—try to decompress if needed
298
- img_bytes = _maybe_decompress(img_bytes, dbg)
299
-
300
- return _decode_bytes_to_image(img_bytes, dbg)
301
-
302
- if time.time() - start > POLL_TIMEOUT:
303
- dbg.append("TIMEOUT waiting for Horde.")
304
- return None, "\n".join(dbg)
305
-
306
- time.sleep(POLL_INTERVAL)
307
-
308
  except Exception:
309
- dbg.append("POLL exception:\n" + traceback.format_exc())
310
- return None, "\n".join(dbg)
311
-
312
- def _decode_bytes_to_image(img_bytes: bytes, dbg: list[str]):
313
- head = img_bytes[:12]
314
- dbg.append(f"header bytes: {head.hex(' ')}")
315
- try:
316
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
317
- return img, "\n".join(dbg)
318
- except Exception as e:
319
- dbg.append(f"PIL decode failed: {type(e).__name__}: {e}")
320
 
321
- try:
322
- arr = imageio.imread(io.BytesIO(img_bytes))
323
- if isinstance(arr, np.ndarray):
324
- if arr.ndim == 2: # grayscale → RGB
325
- arr = np.stack([arr, arr, arr], axis=-1)
326
- elif arr.shape[-1] == 4: # RGBA → RGB
327
- arr = arr[..., :3]
328
- img = Image.fromarray(arr.astype(np.uint8), mode="RGB")
329
- dbg.append("Decoded via imageio fallback.")
330
- return img, "\n".join(dbg)
331
  except Exception as e:
332
- dbg.append(f"imageio decode failed: {type(e).__name__}: {e}")
 
333
 
334
- try:
335
- tmp = f"unknown_img_{int(time.time())}.bin"
336
- with open(tmp, "wb") as f:
337
- f.write(img_bytes)
338
- dbg.append(f"Wrote undecodable bytes to {tmp}")
339
- except Exception as e:
340
- dbg.append(f"Could not write debug bytes: {type(e).__name__}: {e}")
341
-
342
- return None, "\n".join(dbg)
343
-
344
- # =========================
345
- # Gradio glue
346
- # =========================
347
- def generate_opening(prompt_text, steps, size, lock):
348
- w, h = _parse_size(size)
349
- prompt = build_prompt(prompt_text, is_first=True, lock_longshot=lock)
350
- img, debug = horde_generate(prompt, steps=steps, width=w, height=h, init_image=None)
351
- if img is None:
352
- gr.Warning("Generation failed. See debug log for details.")
353
- return img, debug
354
-
355
- def generate_next(prompt_text, steps, size, lock, prev_img, change):
356
- w, h = _parse_size(size)
357
- prompt = build_prompt(prompt_text, is_first=False, lock_longshot=lock)
358
- init_img = prev_img if isinstance(prev_img, Image.Image) else None
359
- img, debug = horde_generate(prompt, steps=steps, width=w, height=h,
360
- init_image=init_img, denoise=float(change))
361
- if img is None:
362
- gr.Warning("Generation failed. See debug log for details.")
363
- return img, debug
364
-
365
- def _parse_size(s):
366
- try:
367
- w, h = [int(x.strip()) for x in str(s).lower().split("x")]
368
- except Exception:
369
- w, h = DEFAULT_W, DEFAULT_H
370
- return w, h
371
-
372
- # =========================
373
- # UI
374
- # =========================
375
- CUSTOM_CSS = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  .gradio-container { padding: 24px; }
377
- /* Rounded prompt boxes */
378
- .prompt-box textarea {
379
- border-radius: 18px !important;
380
- min-height: 90px;
381
- font-size: 16px;
382
- line-height: 1.4;
383
- padding: 14px 16px;
384
- }
385
- /* Pill buttons */
386
- .pill button {
387
- border-radius: 999px !important;
388
- padding: 10px 18px;
389
- font-size: 15px;
390
- }
391
- /* Rounded image boxes */
392
- .image-out .wrap, .image-out .svelte-1ipelgc {
393
- border-radius: 22px !important;
394
- }
395
  """
396
 
397
- with gr.Blocks(css=CUSTOM_CSS, title="Image Checkpoints Stable Horde (txt2img + img2img)") as demo:
398
- gr.Markdown("### Image Checkpoints (Stable Horde) Opening shot + next scenes\n"
399
- "Image 2–4 use the previous output as the init image (img2img) with a continuity slider.")
 
 
 
400
 
401
  with gr.Row():
402
- steps = gr.Slider(8, 50, value=DEFAULT_STEPS, step=1, label="Steps (quality/time)")
403
- size = gr.Dropdown(
404
- choices=["512x512", "704x704", "704x512", "640x640"],
405
- value=f"{DEFAULT_W}x{DEFAULT_H}",
406
- label="Resolution"
407
- )
408
- lock = gr.Checkbox(value=True, label="Lock camera (long shot, no cuts)")
 
 
409
 
410
- # Continuity / denoise slider for img2img steps (2–4)
411
- change = gr.Slider(0.05, 0.95, value=0.45, step=0.05, label="Change from previous frame (denoise)")
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
- # Shared debug panel
414
- debug_box = gr.Code(label="Debug log", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
- # ---- Row 1: Opening shot ----
417
- with gr.Row():
418
- with gr.Column(scale=1, min_width=320):
419
- p1 = gr.Textbox(
420
- placeholder="Describe the opening shot",
421
- lines=4,
422
- label=None,
423
- elem_classes=["prompt-box"]
424
- )
425
- b1 = gr.Button("Generate image 1", elem_classes=["pill"])
426
- with gr.Column(scale=2, min_width=380):
427
- img1 = gr.Image(label="Image 1 output", type="pil", elem_classes=["image-out"])
428
-
429
- # ---- Row 2: Next scene ----
430
- with gr.Row():
431
- with gr.Column(scale=1, min_width=320):
432
- p2 = gr.Textbox(
433
- placeholder="Describe the next scene",
434
- lines=4,
435
- label=None,
436
- elem_classes=["prompt-box"]
437
- )
438
- b2 = gr.Button("Generate image 2", elem_classes=["pill"])
439
- with gr.Column(scale=2, min_width=380):
440
- img2 = gr.Image(label="Image 2 output", type="pil", elem_classes=["image-out"])
441
-
442
- # ---- Row 3: Next scene ----
443
- with gr.Row():
444
- with gr.Column(scale=1, min_width=320):
445
- p3 = gr.Textbox(
446
- placeholder="Describe the next scene",
447
- lines=4,
448
- label=None,
449
- elem_classes=["prompt-box"]
450
- )
451
- b3 = gr.Button("Generate image 3", elem_classes=["pill"])
452
- with gr.Column(scale=2, min_width=380):
453
- img3 = gr.Image(label="Image 3 output", type="pil", elem_classes=["image-out"])
454
-
455
- # ---- Row 4: Next scene ----
456
- with gr.Row():
457
- with gr.Column(scale=1, min_width=320):
458
- p4 = gr.Textbox(
459
- placeholder="Describe the next scene",
460
- lines=4,
461
- label=None,
462
- elem_classes=["prompt-box"]
463
- )
464
- b4 = gr.Button("Generate image 4", elem_classes=["pill"])
465
- with gr.Column(scale=2, min_width=380):
466
- img4 = gr.Image(label="Image 4 output", type="pil", elem_classes=["image-out"])
467
-
468
- # Wire callbacks
469
- b1.click(fn=generate_opening, inputs=[p1, steps, size, lock], outputs=[img1, debug_box])
470
- b2.click(fn=generate_next, inputs=[p2, steps, size, lock, img1, change], outputs=[img2, debug_box])
471
- b3.click(fn=generate_next, inputs=[p3, steps, size, lock, img2, change], outputs=[img3, debug_box])
472
- b4.click(fn=generate_next, inputs=[p4, steps, size, lock, img3, change], outputs=[img4, debug_box])
473
 
474
  if __name__ == "__main__":
475
- demo.queue().launch()
 
1
+ import os, io, time, random, base64, zipfile
2
+ from typing import List, Tuple, Optional
3
+
4
+ import requests
5
  from PIL import Image
6
  import gradio as gr
7
+
8
+ # ========= Config =========
9
+ MAX_FRAMES = 8 # how many upload slots & rows to render
10
+ MODAL_BASE = "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
+
12
+ # ========= Helpers =========
13
+ def _save_video_bytes(data: bytes, tag: str) -> str:
14
+ os.makedirs("/mnt/data", exist_ok=True)
15
+ path = f"/mnt/data/{tag}_{int(time.time())}.mp4"
16
+ with open(path, "wb") as f:
17
+ f.write(data)
18
+ return path
19
+
20
def _png_bytes_from_pil(img: Image.Image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        return sink.getvalue()
24
 
25
def _download_to_bytes(url: str) -> bytes:
    """Fetch *url* and return the raw response body, raising on HTTP errors."""
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    return response.content
 
 
 
 
 
 
 
 
 
29
 
30
def call_modal_i2v(start_img: Image.Image, prompt: str, seed: Optional[int]) -> Tuple[Optional[str], str]:
    """
    Send one start frame to the Modal i2v backend and retrieve the clip.

    The request is a multipart POST carrying the PNG-encoded frame as
    'image_bytes', with the prompt and seed passed as query parameters.
    Handles both response shapes the backend may produce: raw mp4 bytes,
    or JSON containing a video URL / base64 payload.

    Returns:
        (path_to_mp4_or_None, debug_log_text)
    """
    from urllib.parse import quote

    log = []

    # Seed of None/0/-1 means "pick a random one".
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # Prompt goes in the query string, so it must be URL-encoded.
    endpoint = f"{MODAL_BASE}?prompt={quote(prompt)}&seed={seed}"

    upload = {"image_bytes": ("start.png", _png_bytes_from_pil(start_img), "image/png")}

    try:
        response = requests.post(endpoint, files=upload,
                                 headers={"accept": "application/json"}, timeout=600)
        content_type = (response.headers.get("content-type") or "").lower()
        log.append(f"HTTP {response.status_code}; content-type={content_type}")

        # Case A: a non-JSON response is assumed to be the mp4 payload itself.
        if "application/json" not in content_type:
            response.raise_for_status()
            saved = _save_video_bytes(response.content, "pair")
            log.append(f"Saved raw video to {saved}")
            return saved, "\n".join(log)

        # Case B: JSON that points at a URL or embeds the video as base64.
        payload = response.json()
        video_url = payload.get("video_url") or payload.get("url") or payload.get("result") or payload.get("output")
        video_b64 = payload.get("video_b64") or payload.get("videoBase64")

        if video_url and isinstance(video_url, str):
            blob = _download_to_bytes(video_url)
            saved = _save_video_bytes(blob, "pair")
            log.append(f"Downloaded video from {video_url} -> {saved}")
            return saved, "\n".join(log)

        if video_b64 and isinstance(video_b64, str):
            # Restore any stripped base64 padding before decoding.
            missing = (-len(video_b64)) % 4
            if missing:
                video_b64 += "=" * missing
            saved = _save_video_bytes(base64.b64decode(video_b64), "pair")
            log.append("Decoded base64 video.")
            return saved, "\n".join(log)

        # The backend answered with JSON we cannot use; keep a snippet for debugging.
        try:
            log.append(f"Backend JSON: {str(payload)[:500]}")
        except Exception:
            pass
        return None, "\n".join(log)

    except Exception as e:
        log.append(f"Exception: {type(e).__name__}: {e}")
        return None, "\n".join(log)
87
 
88
+ # ========= State handlers =========
89
def add_images(files: List[str], images_state: List[Image.Image], names_state: List[str]):
    """
    Append newly uploaded files to the session state.

    Unreadable files are skipped silently. Returns the updated state plus the
    per-slot preview data the UI needs:
    (images, names, slot_values, slot_labels, slot_visibilities, pair_visibilities)
    """
    imgs = list(images_state)
    names = list(names_state)
    for path in files or []:
        try:
            loaded = Image.open(path).convert("RGB")
            imgs.append(loaded)
            names.append(os.path.basename(path))
        except Exception:
            # Not a decodable image — ignore and move on.
            continue

    count = len(imgs)
    # One entry per fixed slot; empty slots get None and stay hidden.
    img_values = [imgs[i] if i < count else None for i in range(MAX_FRAMES)]
    img_labels = [f"Image {i+1}" for i in range(MAX_FRAMES)]
    img_vis = [i < count for i in range(MAX_FRAMES)]
    # A pair row (i, i+1) is only meaningful once image i+2 exists.
    pair_vis = [i < count - 1 for i in range(MAX_FRAMES - 1)]

    return imgs, names, img_values, img_labels, img_vis, pair_vis
119
+
120
def clear_all():
    """Reset session state: no images, no names, every slot and pair row hidden."""
    empty_values = [None for _ in range(MAX_FRAMES)]
    labels = [f"Image {n+1}" for n in range(MAX_FRAMES)]
    hidden_slots = [False for _ in range(MAX_FRAMES)]
    hidden_pairs = [False for _ in range(MAX_FRAMES - 1)]
    return [], [], empty_values, labels, hidden_slots, hidden_pairs
126
+
127
def stitch_pair(index: int,
                images: List[Image.Image],
                prompt: str,
                seed: int):
    """
    Generate the transition video for the adjacent pair starting at *index*
    (0-based: 0 => images 1&2, 1 => images 2&3, ...).

    The first image of the pair is sent to Modal as the init frame; the user
    prompt is augmented with a continuity hint. Returns (video_path, debug_log).
    """
    needed = index + 2
    if not images or len(images) < needed:
        gr.Warning("Upload more images first.")
        return None, "Not enough images."

    # Fold a continuity hint into whatever the user typed.
    user = (prompt or "").strip()
    extra = f"(Transition between frame {index+1} → {index+2} of the same shot.)"
    final_prompt = f"{user} {extra}".strip()

    path, dbg = call_modal_i2v(images[index], final_prompt, seed)
    if path is None:
        gr.Warning("Stitch failed. See debug log.")
    return path, dbg
148
+
149
+ # ========= UI =========
150
+ CSS = """
151
  .gradio-container { padding: 24px; }
152
+ .pill button { border-radius: 999px !important; padding: 10px 18px; }
153
+ .rounded textarea { border-radius: 16px !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  """
155
 
156
with gr.Blocks(css=CSS, title="Stitch Upload & Stitch Adjacent Pairs") as demo:
    gr.Markdown("## Stitch Upload stills, then generate between-frames videos\n"
                "Upload images in order. For each adjacent pair (1&2, 2&3, ), write a short transition prompt and click **Stitch**.")

    images_state = gr.State([])   # List[PIL.Image]
    names_state = gr.State([])    # List[str]

    with gr.Row():
        # Left column: upload control + fixed preview slots.
        with gr.Column(scale=1, min_width=340):
            uploader = gr.Files(label="Add images (in order)", file_types=["image"], file_count="multiple")
            clear_btn = gr.Button("Clear all", elem_classes=["pill"])
            image_slots = []
            for i in range(MAX_FRAMES):
                image_slots.append(
                    gr.Image(label=f"Image {i+1}", interactive=False, visible=False)
                )

        # Middle column: per-pair prompt + stitch button.
        with gr.Column(scale=1, min_width=340):
            seed_in = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt_boxes = []
            stitch_buttons = []
            for i in range(MAX_FRAMES - 1):
                prompt_boxes.append(
                    gr.Textbox(
                        placeholder=f"Prompt for transition between Image {i+1} & {i+2}",
                        lines=2, label="Prompt", elem_classes=["rounded"], visible=False
                    )
                )
                stitch_buttons.append(
                    gr.Button(f"Stitch {i+1}&{i+2}", elem_classes=["pill"], visible=False)
                )

        # Right column: per-pair video outputs + shared debug log.
        with gr.Column(scale=1, min_width=360):
            video_outputs = []
            for i in range(MAX_FRAMES - 1):
                video_outputs.append(
                    gr.Video(label=f"Video (image {i+1}+{i+2}) output", visible=False)
                )
            debug_box = gr.Code(label="Debug log", interactive=False)

    # ---- Wiring: upload & clear ----
    # add_images/clear_all return list-valued tuples (values, labels, vis),
    # which cannot be mapped 1:1 onto Gradio outputs. These thin adapters
    # convert each return into exactly one update per output component
    # (2 states + MAX_FRAMES image slots), fixing the previous mismatched
    # outputs list that repeated image_slots three times.
    def _slot_updates(values, labels, vis):
        # One gr.Image update per slot carrying value, label and visibility.
        return [
            gr.Image(value=values[i], label=labels[i], visible=vis[i])
            for i in range(MAX_FRAMES)
        ]

    def _on_upload(files, imgs, names):
        new_imgs, new_names, values, labels, vis, _pair_vis = add_images(files, imgs, names)
        return [new_imgs, new_names, *_slot_updates(values, labels, vis)]

    uploader.upload(
        fn=_on_upload,
        inputs=[uploader, images_state, names_state],
        outputs=[images_state, names_state, *image_slots],
        queue=False
    )

    def reflect_row_visibility(images: List[Image.Image]):
        # Show a pair row (prompt/button/video) only when both of its frames exist.
        n = len(images)
        vis = [i < n - 1 for i in range(MAX_FRAMES - 1)]
        return [gr.Textbox(visible=vis[i]) for i in range(MAX_FRAMES - 1)] + \
               [gr.Button(visible=vis[i]) for i in range(MAX_FRAMES - 1)] + \
               [gr.Video(visible=vis[i]) for i in range(MAX_FRAMES - 1)]

    uploader.upload(
        fn=reflect_row_visibility,
        inputs=[images_state],
        outputs=[*prompt_boxes, *stitch_buttons, *video_outputs],
        queue=False
    )

    def _on_clear():
        imgs, names, values, labels, vis, _pair_vis = clear_all()
        return [imgs, names, *_slot_updates(values, labels, vis)]

    clear_btn.click(
        fn=_on_clear,
        inputs=[],
        outputs=[images_state, names_state, *image_slots],
        queue=False
    ).then(
        fn=reflect_row_visibility,
        inputs=[images_state],
        outputs=[*prompt_boxes, *stitch_buttons, *video_outputs],
        queue=False
    )

    # ---- Wiring: per-pair stitchers ----
    # idx=i binds the loop variable at definition time (late-binding pitfall).
    for i in range(MAX_FRAMES - 1):
        stitch_buttons[i].click(
            fn=lambda prompt, seed, imgs, idx=i: stitch_pair(idx, imgs, prompt, int(seed or 0)),
            inputs=[prompt_boxes[i], seed_in, images_state],
            outputs=[video_outputs[i], debug_box]
        )

if __name__ == "__main__":
    demo.queue().launch()