Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,475 +1,254 @@
|
|
| 1 |
-
import
|
| 2 |
-
from typing import Optional
|
|
|
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
import gradio as gr
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
#
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
DEFAULT_W = 704 # keep defaults under the 715px threshold
|
| 20 |
-
DEFAULT_H = 704
|
| 21 |
-
POLL_INTERVAL = 2.5
|
| 22 |
-
POLL_TIMEOUT = 240 # bump if queues are long
|
| 23 |
-
MODEL = None # or set e.g. "SDXL 1.0"
|
| 24 |
-
|
| 25 |
-
# ---------------- Helpers ----------------
|
| 26 |
-
def _headers():
|
| 27 |
-
# Always send an apikey; fallback to anonymous for testing
|
| 28 |
-
return {
|
| 29 |
-
"Client-Agent": CLIENT_AGENT,
|
| 30 |
-
"apikey": HORDE_API_KEY if HORDE_API_KEY else "0000000000"
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
def pil_to_b64(img_pil: Image.Image) -> str:
|
| 34 |
buf = io.BytesIO()
|
| 35 |
-
|
| 36 |
-
return
|
| 37 |
|
| 38 |
-
def
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
s = (s or "").strip()
|
| 43 |
-
s = s.replace("-", "+").replace("_", "/")
|
| 44 |
-
pad = (-len(s)) % 4
|
| 45 |
-
if pad:
|
| 46 |
-
s += "=" * pad
|
| 47 |
-
try:
|
| 48 |
-
return base64.b64decode(s, validate=False)
|
| 49 |
-
except Exception:
|
| 50 |
-
return base64.urlsafe_b64decode(s + "=" * ((4 - len(s) % 4) % 4))
|
| 51 |
|
| 52 |
-
def
|
| 53 |
-
"""
|
| 54 |
-
If bytes are compressed, try gzip, zlib, bz2, lzma (in that order).
|
| 55 |
-
Returns original buf if none succeed.
|
| 56 |
-
"""
|
| 57 |
-
head = buf[:4]
|
| 58 |
-
try:
|
| 59 |
-
# gzip
|
| 60 |
-
if len(buf) > 2 and buf[0:2] == b"\x1f\x8b":
|
| 61 |
-
dbg.append("Detected gzip; decompressing…")
|
| 62 |
-
return gzip.decompress(buf)
|
| 63 |
-
# zlib (78 01/9C/DA)
|
| 64 |
-
if len(buf) > 2 and buf[0] == 0x78 and buf[1] in (0x01, 0x5E, 0x9C, 0xDA):
|
| 65 |
-
dbg.append("Detected zlib; decompressing…")
|
| 66 |
-
return zlib.decompress(buf)
|
| 67 |
-
# bz2
|
| 68 |
-
if len(buf) > 3 and buf[0:3] == b"BZh":
|
| 69 |
-
dbg.append("Detected bz2; decompressing…")
|
| 70 |
-
return bz2.decompress(buf)
|
| 71 |
-
# lzma/xz
|
| 72 |
-
if len(buf) > 6 and buf[0:6] == b"\xfd7zXZ\x00":
|
| 73 |
-
dbg.append("Detected lzma/xz; decompressing…")
|
| 74 |
-
return lzma.decompress(buf)
|
| 75 |
-
except Exception as e:
|
| 76 |
-
dbg.append(f"Decompress probe failed: {type(e).__name__}: {e}")
|
| 77 |
-
return buf
|
| 78 |
-
|
| 79 |
-
def build_prompt(user_text: str, is_first: bool, lock_longshot: bool = True) -> str:
|
| 80 |
-
"""Compose continuity-aware prompt text."""
|
| 81 |
-
user_text = (user_text or "").strip()
|
| 82 |
-
longshot_plus = (
|
| 83 |
-
"single continuous long shot; no cuts or new shot; no angle switch; "
|
| 84 |
-
"smooth camera motion (pan/tilt/zoom only); unbroken continuity"
|
| 85 |
-
)
|
| 86 |
-
if is_first:
|
| 87 |
-
base = f"Opening frame. {user_text}" if user_text else "Opening frame."
|
| 88 |
-
if lock_longshot:
|
| 89 |
-
base += ". " + longshot_plus
|
| 90 |
-
return base
|
| 91 |
-
# Subsequent frames
|
| 92 |
-
base = (
|
| 93 |
-
"Treat the previous frame as a still from the same continuous long shot. "
|
| 94 |
-
"Maintain style, subject identity, lighting, and camera continuity. "
|
| 95 |
-
f"Generate the next moment: {user_text if user_text else 'advance the action naturally.'}"
|
| 96 |
-
)
|
| 97 |
-
if lock_longshot:
|
| 98 |
-
base += ". " + longshot_plus
|
| 99 |
-
return base
|
| 100 |
-
|
| 101 |
-
# =========================
|
| 102 |
-
# Horde client (txt2img OR img2img)
|
| 103 |
-
# =========================
|
| 104 |
-
def horde_generate(
|
| 105 |
-
prompt: str,
|
| 106 |
-
steps: int = DEFAULT_STEPS,
|
| 107 |
-
width: int = DEFAULT_W,
|
| 108 |
-
height: int = DEFAULT_H,
|
| 109 |
-
model: Optional[str] = MODEL,
|
| 110 |
-
init_image: Optional[Image.Image] = None,
|
| 111 |
-
denoise: float = 0.45, # 0.0 = identical, 1.0 = big change
|
| 112 |
-
):
|
| 113 |
"""
|
| 114 |
-
|
| 115 |
-
|
| 116 |
"""
|
| 117 |
dbg = []
|
| 118 |
-
if
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def _submit(payload):
|
| 122 |
-
sub = requests.post(HORDE_URL, json=payload, headers=_headers(), timeout=30)
|
| 123 |
-
# Auto-fallback if KudosUpfront required
|
| 124 |
-
if sub.status_code == 403:
|
| 125 |
-
try:
|
| 126 |
-
body = sub.json()
|
| 127 |
-
except Exception:
|
| 128 |
-
body = {"message": sub.text}
|
| 129 |
-
msg = (body.get("message") or "").lower()
|
| 130 |
-
rc = body.get("rc") or ""
|
| 131 |
-
if "kudos" in msg or rc == "KudosUpfront":
|
| 132 |
-
payload["params"]["steps"] = min(int(payload["params"]["steps"]), 30)
|
| 133 |
-
payload["params"]["width"] = min(int(payload["params"]["width"]), 704)
|
| 134 |
-
payload["params"]["height"] = min(int(payload["params"]["height"]), 704)
|
| 135 |
-
dbg.append("Fallback applied: steps<=30, width/height<=704. Retrying submit…")
|
| 136 |
-
sub = requests.post(HORDE_URL, json=payload, headers=_headers(), timeout=30)
|
| 137 |
-
return sub
|
| 138 |
-
|
| 139 |
-
# ---- try img2img if init_image provided ----
|
| 140 |
-
tried_img2img = False
|
| 141 |
-
if init_image is not None:
|
| 142 |
-
tried_img2img = True
|
| 143 |
-
payload = {
|
| 144 |
-
"prompt": prompt.strip(),
|
| 145 |
-
"params": {
|
| 146 |
-
"steps": int(steps),
|
| 147 |
-
"width": int(width),
|
| 148 |
-
"height": int(height),
|
| 149 |
-
"n": 1,
|
| 150 |
-
"denoise": float(denoise)
|
| 151 |
-
},
|
| 152 |
-
"nsfw": False,
|
| 153 |
-
"censor_nsfw": True,
|
| 154 |
-
"source_processing": "img2img",
|
| 155 |
-
"source_image": pil_to_b64(init_image),
|
| 156 |
-
"r2": True
|
| 157 |
-
}
|
| 158 |
-
if model:
|
| 159 |
-
payload["models"] = [model]
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
if submit.status_code >= 300:
|
| 165 |
-
dbg.append(f"SUBMIT body={submit.text[:500]}")
|
| 166 |
-
submit.raise_for_status()
|
| 167 |
-
submit_j = submit.json()
|
| 168 |
-
job_id = submit_j.get("id")
|
| 169 |
-
if not job_id:
|
| 170 |
-
dbg.append(f"SUBMIT json={submit_j}")
|
| 171 |
-
raise gr.Error("Horde submit succeeded but no job id returned.")
|
| 172 |
-
dbg.append(f"JOB id={job_id}")
|
| 173 |
-
return _poll_and_decode(job_id, dbg)
|
| 174 |
-
except Exception:
|
| 175 |
-
dbg.append("IMG2IMG path failed, falling back to text-only:\n" + traceback.format_exc())
|
| 176 |
-
|
| 177 |
-
# ---- txt2img path ----
|
| 178 |
-
payload = {
|
| 179 |
-
"prompt": prompt.strip(),
|
| 180 |
-
"params": {
|
| 181 |
-
"steps": int(steps),
|
| 182 |
-
"width": int(width),
|
| 183 |
-
"height": int(height),
|
| 184 |
-
"n": 1
|
| 185 |
-
},
|
| 186 |
-
"nsfw": False,
|
| 187 |
-
"censor_nsfw": True,
|
| 188 |
-
"r2": True
|
| 189 |
-
}
|
| 190 |
-
if model:
|
| 191 |
-
payload["models"] = [model]
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
dbg.append(f"SUBMIT (txt2img{', after img2img fail' if tried_img2img else ''}) status={submit.status_code}")
|
| 196 |
-
if submit.status_code >= 300:
|
| 197 |
-
dbg.append(f"SUBMIT body={submit.text[:500]}")
|
| 198 |
-
submit.raise_for_status()
|
| 199 |
-
submit_j = submit.json()
|
| 200 |
-
job_id = submit_j.get("id")
|
| 201 |
-
if not job_id:
|
| 202 |
-
dbg.append(f"SUBMIT json={submit_j}")
|
| 203 |
-
raise gr.Error("Horde submit succeeded but no job id returned.")
|
| 204 |
-
dbg.append(f"JOB id={job_id}")
|
| 205 |
-
return _poll_and_decode(job_id, dbg)
|
| 206 |
-
except Exception:
|
| 207 |
-
dbg.append("SUBMIT exception:\n" + traceback.format_exc())
|
| 208 |
-
return None, "\n".join(dbg)
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
try:
|
| 214 |
-
|
| 215 |
-
if status_r.status_code >= 300:
|
| 216 |
-
dbg.append(f"POLL status={status_r.status_code}")
|
| 217 |
-
dbg.append(f"POLL body={status_r.text[:500]}")
|
| 218 |
-
status_r.raise_for_status()
|
| 219 |
-
s = status_r.json()
|
| 220 |
-
|
| 221 |
-
k = s.get("kudos", "?")
|
| 222 |
-
queue = s.get("queue_position", "?")
|
| 223 |
-
eta = s.get("wait_time", "?")
|
| 224 |
-
dbg.append(f"queue={queue} eta≈{eta}s kudos={k}")
|
| 225 |
-
|
| 226 |
-
if s.get("faulted"):
|
| 227 |
-
dbg.append(f"FAULT: {s}")
|
| 228 |
-
return None, "\n".join(dbg)
|
| 229 |
-
|
| 230 |
-
if s.get("done"):
|
| 231 |
-
gens = s.get("generations") or []
|
| 232 |
-
if not gens:
|
| 233 |
-
dbg.append("DONE but no generations returned.")
|
| 234 |
-
return None, "\n".join(dbg)
|
| 235 |
-
|
| 236 |
-
g0 = gens[0]
|
| 237 |
-
dbg.append(f"GEN keys: {list(g0.keys())}")
|
| 238 |
-
dbg.append(f"img_type: {g0.get('img_type')}")
|
| 239 |
-
|
| 240 |
-
# Prefer URL if present (also check gen_metadata for r2 url)
|
| 241 |
-
url = (
|
| 242 |
-
g0.get("r2")
|
| 243 |
-
or g0.get("url")
|
| 244 |
-
or g0.get("src")
|
| 245 |
-
or g0.get("image_url")
|
| 246 |
-
or (isinstance(g0.get("gen_metadata"), dict) and (g0["gen_metadata"].get("r2") or g0["gen_metadata"].get("url")))
|
| 247 |
-
)
|
| 248 |
-
if isinstance(url, str) and (url.startswith("http://") or url.startswith("https://")):
|
| 249 |
-
dbg.append("Found URL in generation → fetching…")
|
| 250 |
-
try:
|
| 251 |
-
r = requests.get(url, timeout=60)
|
| 252 |
-
r.raise_for_status()
|
| 253 |
-
img_bytes = r.content
|
| 254 |
-
return _decode_bytes_to_image(img_bytes, dbg)
|
| 255 |
-
except Exception as e:
|
| 256 |
-
dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
|
| 257 |
-
return None, "\n".join(dbg)
|
| 258 |
-
|
| 259 |
-
# Base64 (or hex/encoded) path
|
| 260 |
-
b64 = g0.get("img")
|
| 261 |
-
if not b64:
|
| 262 |
-
dbg.append("No 'img' field present.")
|
| 263 |
-
return None, "\n".join(dbg)
|
| 264 |
-
|
| 265 |
-
# If 'img' looks like a URL string, fetch it
|
| 266 |
-
if b64.startswith("http://") or b64.startswith("https://"):
|
| 267 |
-
dbg.append("img field is a URL string → fetching…")
|
| 268 |
-
try:
|
| 269 |
-
r = requests.get(b64, timeout=60)
|
| 270 |
-
r.raise_for_status()
|
| 271 |
-
img_bytes = r.content
|
| 272 |
-
return _decode_bytes_to_image(img_bytes, dbg)
|
| 273 |
-
except Exception as e:
|
| 274 |
-
dbg.append(f"URL fetch failed: {type(e).__name__}: {e}")
|
| 275 |
-
return None, "\n".join(dbg)
|
| 276 |
-
|
| 277 |
-
# Try base64-url safe first
|
| 278 |
-
try:
|
| 279 |
-
img_bytes = _b64_to_bytes(b64)
|
| 280 |
-
except Exception as e:
|
| 281 |
-
dbg.append(f"Base64/urlsafe decode failed: {type(e).__name__}: {e}")
|
| 282 |
-
img_bytes = None
|
| 283 |
-
|
| 284 |
-
# If that failed or header looks wrong, try hex
|
| 285 |
-
if not img_bytes or len(img_bytes) < 8:
|
| 286 |
-
if re.fullmatch(r"[0-9a-fA-F]+", b64) and len(b64) % 2 == 0:
|
| 287 |
-
dbg.append("img looks like hex → decoding…")
|
| 288 |
-
try:
|
| 289 |
-
img_bytes = bytes.fromhex(b64)
|
| 290 |
-
except Exception as e:
|
| 291 |
-
dbg.append(f"Hex decode failed: {type(e).__name__}: {e}")
|
| 292 |
-
img_bytes = None
|
| 293 |
-
|
| 294 |
-
if not img_bytes:
|
| 295 |
-
return None, "\n".join(dbg)
|
| 296 |
-
|
| 297 |
-
# Some workers compress payloads—try to decompress if needed
|
| 298 |
-
img_bytes = _maybe_decompress(img_bytes, dbg)
|
| 299 |
-
|
| 300 |
-
return _decode_bytes_to_image(img_bytes, dbg)
|
| 301 |
-
|
| 302 |
-
if time.time() - start > POLL_TIMEOUT:
|
| 303 |
-
dbg.append("TIMEOUT waiting for Horde.")
|
| 304 |
-
return None, "\n".join(dbg)
|
| 305 |
-
|
| 306 |
-
time.sleep(POLL_INTERVAL)
|
| 307 |
-
|
| 308 |
except Exception:
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
def _decode_bytes_to_image(img_bytes: bytes, dbg: list[str]):
|
| 313 |
-
head = img_bytes[:12]
|
| 314 |
-
dbg.append(f"header bytes: {head.hex(' ')}")
|
| 315 |
-
try:
|
| 316 |
-
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
| 317 |
-
return img, "\n".join(dbg)
|
| 318 |
-
except Exception as e:
|
| 319 |
-
dbg.append(f"PIL decode failed: {type(e).__name__}: {e}")
|
| 320 |
|
| 321 |
-
try:
|
| 322 |
-
arr = imageio.imread(io.BytesIO(img_bytes))
|
| 323 |
-
if isinstance(arr, np.ndarray):
|
| 324 |
-
if arr.ndim == 2: # grayscale → RGB
|
| 325 |
-
arr = np.stack([arr, arr, arr], axis=-1)
|
| 326 |
-
elif arr.shape[-1] == 4: # RGBA → RGB
|
| 327 |
-
arr = arr[..., :3]
|
| 328 |
-
img = Image.fromarray(arr.astype(np.uint8), mode="RGB")
|
| 329 |
-
dbg.append("Decoded via imageio fallback.")
|
| 330 |
-
return img, "\n".join(dbg)
|
| 331 |
except Exception as e:
|
| 332 |
-
dbg.append(f"
|
|
|
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
.gradio-container { padding: 24px; }
|
| 377 |
-
|
| 378 |
-
.
|
| 379 |
-
border-radius: 18px !important;
|
| 380 |
-
min-height: 90px;
|
| 381 |
-
font-size: 16px;
|
| 382 |
-
line-height: 1.4;
|
| 383 |
-
padding: 14px 16px;
|
| 384 |
-
}
|
| 385 |
-
/* Pill buttons */
|
| 386 |
-
.pill button {
|
| 387 |
-
border-radius: 999px !important;
|
| 388 |
-
padding: 10px 18px;
|
| 389 |
-
font-size: 15px;
|
| 390 |
-
}
|
| 391 |
-
/* Rounded image boxes */
|
| 392 |
-
.image-out .wrap, .image-out .svelte-1ipelgc {
|
| 393 |
-
border-radius: 22px !important;
|
| 394 |
-
}
|
| 395 |
"""
|
| 396 |
|
| 397 |
-
with gr.Blocks(css=
|
| 398 |
-
gr.Markdown("
|
| 399 |
-
"
|
|
|
|
|
|
|
|
|
|
| 400 |
|
| 401 |
with gr.Row():
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
-
#
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
)
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
# ---- Row 4: Next scene ----
|
| 456 |
-
with gr.Row():
|
| 457 |
-
with gr.Column(scale=1, min_width=320):
|
| 458 |
-
p4 = gr.Textbox(
|
| 459 |
-
placeholder="Describe the next scene",
|
| 460 |
-
lines=4,
|
| 461 |
-
label=None,
|
| 462 |
-
elem_classes=["prompt-box"]
|
| 463 |
-
)
|
| 464 |
-
b4 = gr.Button("Generate image 4", elem_classes=["pill"])
|
| 465 |
-
with gr.Column(scale=2, min_width=380):
|
| 466 |
-
img4 = gr.Image(label="Image 4 output", type="pil", elem_classes=["image-out"])
|
| 467 |
-
|
| 468 |
-
# Wire callbacks
|
| 469 |
-
b1.click(fn=generate_opening, inputs=[p1, steps, size, lock], outputs=[img1, debug_box])
|
| 470 |
-
b2.click(fn=generate_next, inputs=[p2, steps, size, lock, img1, change], outputs=[img2, debug_box])
|
| 471 |
-
b3.click(fn=generate_next, inputs=[p3, steps, size, lock, img2, change], outputs=[img3, debug_box])
|
| 472 |
-
b4.click(fn=generate_next, inputs=[p4, steps, size, lock, img3, change], outputs=[img4, debug_box])
|
| 473 |
|
| 474 |
if __name__ == "__main__":
|
| 475 |
-
demo.queue().launch()
|
|
|
|
| 1 |
+
import os, io, time, random, base64, zipfile
|
| 2 |
+
from typing import List, Tuple, Optional
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
from PIL import Image
|
| 6 |
import gradio as gr
|
| 7 |
+
|
| 8 |
+
# ========= Config =========
MAX_FRAMES = 8   # how many upload slots & rows to render
# Modal i2v inference endpoint; called by call_modal_i2v() below.
MODAL_BASE = "https://moonmath-ai--moonmath-i2v-backend-moonmathinference-run.modal.run"
|
| 11 |
+
|
| 12 |
+
# ========= Helpers =========
|
| 13 |
+
def _save_video_bytes(data: bytes, tag: str) -> str:
    """Persist raw MP4 bytes to disk and return the file path.

    The filename embeds *tag* and the current UNIX timestamp so repeated
    calls do not overwrite each other (collisions are still possible
    within the same second).

    Args:
        data: raw video bytes to write.
        tag: short label embedded in the filename.

    Returns:
        Absolute path of the written ``.mp4`` file.
    """
    import tempfile

    target_dir = "/mnt/data"
    try:
        os.makedirs(target_dir, exist_ok=True)
    except OSError:
        # /mnt/data may not exist or be writable outside the Space
        # container; fall back to the system temp directory so the
        # caller still gets a playable file path.
        target_dir = tempfile.gettempdir()
    path = os.path.join(target_dir, f"{tag}_{int(time.time())}.mp4")
    with open(path, "wb") as f:
        f.write(data)
    return path
|
| 19 |
+
|
| 20 |
+
def _png_bytes_from_pil(img: Image.Image) -> bytes:
    """Serialize a PIL image to PNG-encoded bytes."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()
|
| 24 |
|
| 25 |
+
def _download_to_bytes(url: str) -> bytes:
    """Fetch *url* and return the response body; raises on HTTP errors."""
    # 180 s timeout: video files can be large and the backend CDN slow.
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    return response.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
def call_modal_i2v(start_img: Image.Image, prompt: str, seed: Optional[int]) -> Tuple[Optional[str], str]:
    """
    POST to Modal with multipart 'image_bytes' and query args prompt & seed.
    Returns (mp4_path_or_None, debug_log).
    """
    dbg = []
    # Treat None/0/-1 as "no seed chosen" and pick a random 31-bit seed.
    if seed in (None, 0, -1):
        seed = random.randint(1, 2**31 - 1)

    # Build URL (encode prompt)
    from urllib.parse import quote
    url = f"{MODAL_BASE}?prompt={quote(prompt)}&seed={seed}"

    # Multipart upload: the backend reads the start frame from 'image_bytes'.
    files = {"image_bytes": ("start.png", _png_bytes_from_pil(start_img), "image/png")}
    headers = {"accept": "application/json"}

    try:
        # Long timeout (600 s): video generation can take several minutes.
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()
        dbg.append(f"HTTP {resp.status_code}; content-type={ctype}")

        # Case A: raw bytes (not JSON) — the body is the video itself.
        if "application/json" not in ctype:
            resp.raise_for_status()
            path = _save_video_bytes(resp.content, "pair")
            dbg.append(f"Saved raw video to {path}")
            return path, "\n".join(dbg)

        # Case B: JSON containing url or base64.
        data = resp.json()
        # Backends vary in key naming; probe the common variants in order.
        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        video_b64 = data.get("video_b64") or data.get("videoBase64")

        if video_url and isinstance(video_url, str):
            b = _download_to_bytes(video_url)
            path = _save_video_bytes(b, "pair")
            dbg.append(f"Downloaded video from {video_url} -> {path}")
            return path, "\n".join(dbg)

        if video_b64 and isinstance(video_b64, str):
            # Re-pad base64 in case the backend stripped '=' padding.
            pad = (-len(video_b64)) % 4
            if pad: video_b64 += "=" * pad
            b = base64.b64decode(video_b64)
            path = _save_video_bytes(b, "pair")
            dbg.append("Decoded base64 video.")
            return path, "\n".join(dbg)

        # Nothing usable returned — log a truncated view of the JSON body.
        try:
            dbg.append(f"Backend JSON: {str(data)[:500]}")
        except Exception:
            pass
        return None, "\n".join(dbg)

    except Exception as e:
        # Best-effort API: report the failure in the debug log rather than
        # raising into the Gradio handler.
        dbg.append(f"Exception: {type(e).__name__}: {e}")
        return None, "\n".join(dbg)
|
| 87 |
|
| 88 |
+
# ========= State handlers =========
|
| 89 |
+
def add_images(files: List[str], images_state: List[Image.Image], names_state: List[str]):
    """
    Append uploads to state; return updated previews and row visibilities.
    """
    imgs = list(images_state)
    names = list(names_state)
    for filepath in files or []:
        try:
            # Normalize everything to RGB so downstream PNG encoding is uniform.
            loaded = Image.open(filepath).convert("RGB")
            imgs.append(loaded)
            names.append(os.path.basename(filepath))
        except Exception:
            # Skip anything PIL cannot open.
            continue

    # Outputs to update: image slots, labels, visibilities; pair rows
    # become visible up to len(imgs) - 1.
    count = len(imgs)
    img_values = [imgs[i] if i < count else None for i in range(MAX_FRAMES)]
    img_labels = [f"Image {i+1}" for i in range(MAX_FRAMES)]
    img_vis = [i < count for i in range(MAX_FRAMES)]
    pair_vis = [i < count - 1 for i in range(MAX_FRAMES - 1)]

    return imgs, names, img_values, img_labels, img_vis, pair_vis
|
| 119 |
+
|
| 120 |
+
def clear_all():
    """Reset both state lists and hide every image slot and pair row."""
    empty_slots = [None] * MAX_FRAMES
    labels = [f"Image {i+1}" for i in range(MAX_FRAMES)]
    slots_hidden = [False] * MAX_FRAMES
    pairs_hidden = [False] * (MAX_FRAMES - 1)
    return [], [], empty_slots, labels, slots_hidden, pairs_hidden
|
| 126 |
+
|
| 127 |
+
def stitch_pair(index: int,
                images: List[Image.Image],
                prompt: str,
                seed: int):
    """Generate the transition video for the 0-based pair *index*.

    Pair 0 stitches images 1 & 2, pair 1 stitches 2 & 3, and so on.
    The *first* image of the pair is sent to Modal as the init image.
    """
    available = len(images) if images else 0
    if available < index + 2:
        gr.Warning("Upload more images first.")
        return None, "Not enough images."

    # Keep the user's text, then append a continuity hint for the backend.
    user = (prompt or "").strip()
    extra = f"(Transition between frame {index+1} → {index+2} of the same shot.)"
    final_prompt = " ".join([user, extra]).strip()

    path, dbg = call_modal_i2v(images[index], final_prompt, seed)
    if path is None:
        gr.Warning("Stitch failed. See debug log.")
    return path, dbg
|
| 148 |
+
|
| 149 |
+
# ========= UI =========
|
| 150 |
+
CSS = """
|
| 151 |
.gradio-container { padding: 24px; }
|
| 152 |
+
.pill button { border-radius: 999px !important; padding: 10px 18px; }
|
| 153 |
+
.rounded textarea { border-radius: 16px !important; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
"""
|
| 155 |
|
| 156 |
+
with gr.Blocks(css=CSS, title="Stitch — Upload & Stitch Adjacent Pairs") as demo:
    gr.Markdown("## Stitch — Upload stills, then generate between-frames videos\n"
                "Upload images in order. For each adjacent pair (1&2, 2&3, …), write a short transition prompt and click **Stitch**.")

    # Session state: the ordered PIL images and their original filenames.
    images_state = gr.State([])   # List[PIL.Image]
    names_state = gr.State([])    # List[str]

    with gr.Row():
        # Left column: image slots
        with gr.Column(scale=1, min_width=340):
            uploader = gr.Files(label="Add images (in order)", file_types=["image"], file_count="multiple")
            clear_btn = gr.Button("Clear all", elem_classes=["pill"])
            # One hidden preview slot per possible frame; made visible as uploads arrive.
            image_slots = []
            for i in range(MAX_FRAMES):
                image_slots.append(
                    gr.Image(label=f"Image {i+1}", interactive=False, visible=False)
                )

        # Middle column: per-pair prompt + button
        with gr.Column(scale=1, min_width=340):
            seed_in = gr.Number(value=0, precision=0, label="Seed (0 = random)")
            prompt_boxes = []
            stitch_buttons = []
            for i in range(MAX_FRAMES - 1):
                prompt_boxes.append(
                    gr.Textbox(
                        placeholder=f"Prompt for transition between Image {i+1} & {i+2}",
                        lines=2, label="Prompt", elem_classes=["rounded"], visible=False
                    )
                )
                stitch_buttons.append(
                    gr.Button(f"Stitch {i+1}&{i+2}", elem_classes=["pill"], visible=False)
                )

        # Right column: per-pair video outputs + shared debug
        with gr.Column(scale=1, min_width=360):
            video_outputs = []
            for i in range(MAX_FRAMES - 1):
                video_outputs.append(
                    gr.Video(label=f"Video (image {i+1}+{i+2}) output", visible=False)
                )
            debug_box = gr.Code(label="Debug log", interactive=False)

    # ---- Wiring: upload & clear ----
    # NOTE(review): add_images returns 6 values (four of which are lists),
    # but this outputs list enumerates 2 state objects plus ~31 components,
    # with image_slots repeated three times. Gradio maps return values to
    # outputs positionally and typically rejects duplicate output components
    # — verify this handler actually updates the slots as intended.
    uploader.upload(
        fn=add_images,
        inputs=[uploader, images_state, names_state],
        outputs=[
            images_state, names_state,
            # image values, labels, visibilities
            *image_slots,                 # values (Image components accept PIL Image)
            *[s for s in image_slots],    # labels: set via .label below (we'll hack via .update)
            *[s for s in image_slots],    # visibility
            *[b for b in stitch_buttons]  # visibility for rows (we'll mirror to prompt/video too)
        ],
        queue=False
    )

    # NOTE: Gradio can't directly set multiple attributes with one function return to each component slot,
    # so we will do a lightweight post-upload JS update using .update. Simpler: tie visibility of prompt/video
    # to the corresponding button's visibility in another handler:

    def reflect_row_visibility(images: List[Image.Image]):
        # Show pair row i only when both image i and i+1 exist.
        n = len(images)
        vis = [i < n-1 for i in range(MAX_FRAMES-1)]
        # return prompt visibilities, button visibilities, video visibilities
        return [gr.Textbox(visible=vis[i]) for i in range(MAX_FRAMES-1)] + \
               [gr.Button(visible=vis[i]) for i in range(MAX_FRAMES-1)] + \
               [gr.Video(visible=vis[i]) for i in range(MAX_FRAMES-1)]

    # Second upload handler: runs after add_images has updated images_state.
    # NOTE(review): relies on both .upload handlers firing in registration
    # order so images_state is fresh here — confirm against the Gradio version in use.
    uploader.upload(
        fn=reflect_row_visibility,
        inputs=[images_state],
        outputs=[*prompt_boxes, *stitch_buttons, *video_outputs],
        queue=False
    )

    # NOTE(review): as with uploader.upload above, this outputs list repeats
    # image_slots three times against clear_all's 6 return values — verify.
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[images_state, names_state, *image_slots, *image_slots, *image_slots, *stitch_buttons],
        queue=False
    ).then(
        fn=lambda imgs: reflect_row_visibility(imgs),
        inputs=[images_state],
        outputs=[*prompt_boxes, *stitch_buttons, *video_outputs],
        queue=False
    )

    # ---- Wiring: per-pair stitchers ----
    for i in range(MAX_FRAMES - 1):
        stitch_buttons[i].click(
            # idx=i binds the loop variable per-iteration (avoids the
            # late-binding closure pitfall).
            fn=lambda prompt, seed, imgs, idx=i: stitch_pair(idx, imgs, prompt, int(seed or 0)),
            inputs=[prompt_boxes[i], seed_in, images_state],
            outputs=[video_outputs[i], debug_box]
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
if __name__ == "__main__":
    # Enable the request queue (needed for long-running stitch jobs) and serve the UI.
    demo.queue().launch()
|