Shalmoni commited on
Commit
ac99ac3
Β·
verified Β·
1 Parent(s): 85904ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -84
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os, json, uuid, re
2
  from datetime import datetime
3
  import gradio as gr
@@ -71,7 +72,7 @@ def _lazy_model_tok():
71
  _model = AutoModelForCausalLM.from_pretrained(
72
  STORYBOARD_MODEL,
73
  device_map="auto",
74
- torch_dtype=preferred_dtype, # <- correct kwarg
75
  trust_remote_code=True,
76
  use_safetensors=True
77
  )
@@ -111,7 +112,6 @@ def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_
111
  "Output must start with <JSON> and end with </JSON>.\n"
112
  )
113
 
114
-
115
  def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
116
  return (
117
  "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
@@ -125,7 +125,7 @@ def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_le
125
  f" \"fps\": {default_fps},\n"
126
  " \"steps\": 30,\n"
127
  " \"seed\": null,\n"
128
- ' "negative": ""\n'
129
  "}\n"
130
  )
131
 
@@ -170,14 +170,20 @@ def _extract_json_array(text: str) -> str:
170
  if start == -1:
171
  return ""
172
  depth = 0
 
 
173
  for i in range(start, len(text)):
174
  ch = text[i]
175
- if ch == "[":
176
- depth += 1
177
- elif ch == "]":
178
- depth -= 1
179
- if depth == 0:
180
- return text[start:i+1].strip()
 
 
 
 
181
  return ""
182
 
183
  def _normalize_shots(shots_raw, default_fps: int, default_len: int):
@@ -241,81 +247,116 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
241
  return _normalize_shots(shots_raw, default_fps, default_len)
242
 
243
  # =========================
244
- # IMAGE GEN (ZeroGPU) β€” sd-turbo t2i + img2img chaining
245
  # =========================
246
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
 
 
 
 
247
 
248
- SD_MODEL = os.getenv("SD_MODEL", "stabilityai/sd-turbo")
 
249
  _sd_t2i = None
250
  _sd_i2i = None
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  def _lazy_sd_pipes():
 
 
253
  global _sd_t2i, _sd_i2i
254
  if _sd_t2i is not None and _sd_i2i is not None:
255
  return _sd_t2i, _sd_i2i
256
-
257
- use_cuda = torch.cuda.is_available()
258
- dtype = torch.float16 if use_cuda else torch.float32
259
  hf_token = os.getenv("HF_TOKEN", None)
260
-
261
  _sd_t2i = StableDiffusionPipeline.from_pretrained(
262
- SD_MODEL,
263
- torch_dtype=dtype,
264
- safety_checker=None,
265
- feature_extractor=None,
266
- use_safetensors=True,
267
- low_cpu_mem_usage=False,
268
- token=hf_token
269
  )
270
- if use_cuda:
271
- _sd_t2i = _sd_t2i.to("cuda")
272
-
273
  _sd_i2i = StableDiffusionImg2ImgPipeline(
274
- vae=_sd_t2i.vae,
275
- text_encoder=_sd_t2i.text_encoder,
276
- tokenizer=_sd_t2i.tokenizer,
277
- unet=_sd_t2i.unet,
278
- scheduler=_sd_t2i.scheduler,
279
- safety_checker=None,
280
- feature_extractor=None
281
  )
282
- if use_cuda:
283
- _sd_i2i = _sd_i2i.to("cuda")
284
-
285
  return _sd_t2i, _sd_i2i
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
288
  pdir = project_dir(pid)
289
  out = os.path.join(pdir, "keyframes", f"shot_{shot_id:02d}.png")
290
  img.save(out)
291
  return out
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  @spaces.GPU(duration=180)
294
  def generate_keyframe_image(
295
  pid: str,
296
  shot_idx: int,
297
  shots: list,
298
- t2i_steps: int = 6, # first shot
299
- i2i_steps: int = 10, # subsequent shots
300
- i2i_strength: float = 0.65, # change vs consistency
301
- guidance_scale: float = 0.5,
302
- width: int = 512,
303
- height: int = 512
304
  ):
305
  """
306
  Generate image for shots[shot_idx].
307
- - shot 0: text2img (few steps)
308
- - shot k>0: img2img from previous approved image with higher strength/steps
309
- Seed is kept SAME across all shots (stored in shots[i]['seed']).
 
310
  """
311
- t2i, i2i = _lazy_sd_pipes()
312
  shot = shots[shot_idx]
313
 
314
  prompt = (shot.get("description") or "").strip()
315
  negative = shot.get("negative") or ""
316
  seed = shot.get("seed", None)
317
 
318
- device = "cuda" if torch.cuda.is_available() else "cpu"
319
  gen = torch.Generator(device)
320
  if isinstance(seed, int):
321
  gen = gen.manual_seed(int(seed))
@@ -323,40 +364,57 @@ def generate_keyframe_image(
323
  width = max(256, min(1024, int(width)))
324
  height = max(256, min(1024, int(height)))
325
 
326
- if shot_idx == 0 or not shots[shot_idx - 1].get("image_path"):
327
- out = t2i(
328
- prompt=prompt,
329
- negative_prompt=negative,
330
- guidance_scale=guidance_scale,
331
- num_inference_steps=int(max(1, t2i_steps)),
332
- generator=gen,
333
- width=width,
334
- height=height
335
- ).images[0]
336
  else:
337
- prev_path = shots[shot_idx - 1].get("image_path")
338
- if prev_path and os.path.exists(prev_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  init_image = Image.open(prev_path).convert("RGB")
340
- strength = float(i2i_strength)
341
- strength = min(max(strength, 0.50), 0.90)
342
  out = i2i(
343
  prompt=prompt,
344
- negative_prompt=negative,
345
  image=init_image,
346
- guidance_scale=guidance_scale,
347
- strength=strength,
348
- num_inference_steps=int(max(2, i2i_steps)),
349
  generator=gen
350
  ).images[0]
351
- else:
 
 
352
  out = t2i(
353
  prompt=prompt,
354
  negative_prompt=negative,
355
- guidance_scale=guidance_scale,
356
- num_inference_steps=int(max(1, t2i_steps)),
357
  generator=gen,
358
- width=width,
359
- height=height
 
 
 
 
 
 
 
 
 
360
  ).images[0]
361
 
362
  saved_path = _save_keyframe(pid, int(shot["id"]), out)
@@ -392,7 +450,7 @@ def df_to_shots(df: pd.DataFrame) -> list:
392
  # =========================
393
  with gr.Blocks() as demo:
394
  gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
395
- gr.Markdown("Edit storyboard prompts, then generate keyframes. Shots 2+ use the previous approved image for consistency. A single project seed is locked for a cohesive look.")
396
 
397
  # State
398
  project = gr.State(None)
@@ -439,11 +497,11 @@ with gr.Blocks() as demo:
439
  with gr.Row():
440
  gen_btn = gr.Button("Generate / Regenerate", variant="primary")
441
  approve_next_btn = gr.Button("Approve & Next β†’", variant="secondary")
442
- # tuning controls
443
  with gr.Row():
444
- img_strength = gr.Slider(0.40, 0.90, value=0.65, step=0.05, label="Change vs Consistency (img2img strength)")
445
- img_steps = gr.Slider(4, 20, value=10, step=1, label="Img2Img Steps")
446
- guidance = gr.Slider(0.0, 2.0, value=0.5, step=0.05, label="Guidance Scale")
447
  with gr.Row():
448
  prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
449
  out_img = gr.Image(label="Generated image", type="filepath")
@@ -473,7 +531,6 @@ with gr.Blocks() as demo:
473
  p["shots"] = shots
474
  p["meta"]["updated"] = now_iso()
475
  save_project(p)
476
- # Enable Save Edits after storyboard exists
477
  return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable)."), gr.update(interactive=True)
478
 
479
  propose_btn.click(
@@ -503,11 +560,8 @@ with gr.Blocks() as demo:
503
 
504
  # lock a single seed for the project:
505
  proj_seed = None
506
- # override if user supplied:
507
  if proj_seed_override not in [None, ""] and str(proj_seed_override).isdigit():
508
  proj_seed = int(proj_seed_override)
509
-
510
- # otherwise use existing project meta seed or find one in shots:
511
  if proj_seed is None:
512
  proj_seed = p.get("meta", {}).get("seed", None)
513
  if proj_seed is None:
@@ -518,7 +572,6 @@ with gr.Blocks() as demo:
518
  if proj_seed is None:
519
  proj_seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
520
 
521
- # apply to all shots missing seed
522
  for s in shots:
523
  if not isinstance(s.get("seed"), int):
524
  s["seed"] = proj_seed
@@ -549,19 +602,19 @@ with gr.Blocks() as demo:
549
  shots = p["shots"]
550
  if idx < 0 or idx >= len(shots): raise gr.Error("Invalid shot index.")
551
  shots[idx]["description"] = current_prompt # allow tweaking
552
- prev_path = shots[idx-1]["image_path"] if idx > 0 else None
553
 
554
  img_path = generate_keyframe_image(
555
  p["meta"]["id"],
556
  int(idx),
557
  shots,
558
- t2i_steps=6,
559
  i2i_steps=int(i2i_steps_val),
560
  i2i_strength=float(i2i_strength_val),
561
  guidance_scale=float(guidance_val),
562
- width=512,
563
- height=512
564
  )
 
565
  return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
566
 
567
  gen_btn.click(
 
1
+ # app.py
2
  import os, json, uuid, re
3
  from datetime import datetime
4
  import gradio as gr
 
72
  _model = AutoModelForCausalLM.from_pretrained(
73
  STORYBOARD_MODEL,
74
  device_map="auto",
75
+ torch_dtype=preferred_dtype,
76
  trust_remote_code=True,
77
  use_safetensors=True
78
  )
 
112
  "Output must start with <JSON> and end with </JSON>.\n"
113
  )
114
 
 
115
  def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
116
  return (
117
  "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
 
125
  f" \"fps\": {default_fps},\n"
126
  " \"steps\": 30,\n"
127
  " \"seed\": null,\n"
128
+ ' "negative": ""\n'
129
  "}\n"
130
  )
131
 
 
170
  if start == -1:
171
  return ""
172
  depth = 0
173
+ in_str = False
174
+ prev = ""
175
  for i in range(start, len(text)):
176
  ch = text[i]
177
+ if ch == '"' and prev != '\\':
178
+ in_str = not in_str
179
+ if not in_str:
180
+ if ch == "[":
181
+ depth += 1
182
+ elif ch == "]":
183
+ depth -= 1
184
+ if depth == 0:
185
+ return text[start:i+1].strip()
186
+ prev = ch
187
  return ""
188
 
189
  def _normalize_shots(shots_raw, default_fps: int, default_len: int):
 
247
  return _normalize_shots(shots_raw, default_fps, default_len)
248
 
249
  # =========================
250
+ # IMAGE GEN β€” FLUX first, SD-Turbo fallback
251
  # =========================
252
+ USE_CUDA = torch.cuda.is_available()
253
+ DTYPE = torch.float16 if USE_CUDA else torch.float32
254
+
255
+ FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-Nano") # or "black-forest-labs/FLUX.1-dev"
256
+ SD_MODEL = os.getenv("SD_MODEL", "stabilityai/sd-turbo")
257
 
258
+ _flux_t2i = None
259
+ _flux_i2i = None
260
  _sd_t2i = None
261
  _sd_i2i = None
262
+ _have_flux = None
263
+
264
+ def _lazy_flux_pipes():
265
+ # Returns (t2i, i2i) or raises
266
+ from diffusers import FluxPipeline, FluxImg2ImgPipeline
267
+ global _flux_t2i, _flux_i2i
268
+ if _flux_t2i is not None and _flux_i2i is not None:
269
+ return _flux_t2i, _flux_i2i
270
+ _flux_t2i = FluxPipeline.from_pretrained(FLUX_MODEL, torch_dtype=DTYPE, use_safetensors=True)
271
+ if USE_CUDA: _flux_t2i = _flux_t2i.to("cuda")
272
+ _flux_i2i = FluxImg2ImgPipeline.from_pretrained(FLUX_MODEL, torch_dtype=DTYPE, use_safetensors=True)
273
+ if USE_CUDA: _flux_i2i = _flux_i2i.to("cuda")
274
+ return _flux_t2i, _flux_i2i
275
 
276
  def _lazy_sd_pipes():
277
+ # Returns (t2i, i2i)
278
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
279
  global _sd_t2i, _sd_i2i
280
  if _sd_t2i is not None and _sd_i2i is not None:
281
  return _sd_t2i, _sd_i2i
 
 
 
282
  hf_token = os.getenv("HF_TOKEN", None)
 
283
  _sd_t2i = StableDiffusionPipeline.from_pretrained(
284
+ SD_MODEL, torch_dtype=DTYPE, safety_checker=None, feature_extractor=None,
285
+ use_safetensors=True, low_cpu_mem_usage=False, token=hf_token
 
 
 
 
 
286
  )
287
+ if USE_CUDA: _sd_t2i = _sd_t2i.to("cuda")
 
 
288
  _sd_i2i = StableDiffusionImg2ImgPipeline(
289
+ vae=_sd_t2i.vae, text_encoder=_sd_t2i.text_encoder, tokenizer=_sd_t2i.tokenizer,
290
+ unet=_sd_t2i.unet, scheduler=_sd_t2i.scheduler,
291
+ safety_checker=None, feature_extractor=None
 
 
 
 
292
  )
293
+ if USE_CUDA: _sd_i2i = _sd_i2i.to("cuda")
 
 
294
  return _sd_t2i, _sd_i2i
295
 
296
+ def _try_get_pipes():
297
+ """Prefer FLUX; fall back to SD-Turbo. Returns (mode, t2i, i2i) where mode in {'flux','sd'}."""
298
+ global _have_flux
299
+ if _have_flux is None:
300
+ try:
301
+ t2i, i2i = _lazy_flux_pipes()
302
+ _have_flux = True
303
+ return "flux", t2i, i2i
304
+ except Exception as e:
305
+ _have_flux = False
306
+ if _have_flux:
307
+ return "flux", *_lazy_flux_pipes()
308
+ else:
309
+ return "sd", *_lazy_sd_pipes()
310
+
311
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
312
  pdir = project_dir(pid)
313
  out = os.path.join(pdir, "keyframes", f"shot_{shot_id:02d}.png")
314
  img.save(out)
315
  return out
316
 
317
+ def _significant_change(curr_desc: str, prev_desc: str) -> bool:
318
+ """
319
+ Heuristic: if symmetric difference of tokens is large -> treat as a new scene,
320
+ so we should text2img (seed keeps style) instead of img2img.
321
+ """
322
+ if not prev_desc: return True
323
+ a = set(re.findall(r"\w+", curr_desc.lower()))
324
+ b = set(re.findall(r"\w+", prev_desc.lower()))
325
+ # weights: boost composition-y words
326
+ comp_words = {"wide","close","low","high","overhead","aerial","profile","left","right","center",
327
+ "portrait","landscape","long","establishing","macro","tilt","dutch","angle",
328
+ "night","day","sunset","sunrise","noon","backlit","rim","key","fill"}
329
+ delta = a.symmetric_difference(b)
330
+ score = len(delta) + 2 * len((a ^ b) & comp_words)
331
+ return score >= 12 # tune threshold 10–16
332
+
333
  @spaces.GPU(duration=180)
334
  def generate_keyframe_image(
335
  pid: str,
336
  shot_idx: int,
337
  shots: list,
338
+ t2i_steps: int = 14, # FLUX likes 12–20
339
+ i2i_steps: int = 16,
340
+ i2i_strength: float = 0.8, # higher = follow prompt more
341
+ guidance_scale: float = 3.0, # FLUX sweet spot ~2.5–3.5
342
+ width: int = 640,
343
+ height: int = 640
344
  ):
345
  """
346
  Generate image for shots[shot_idx].
347
+ - shot 0: text2img
348
+ - shot k>0: smart chaining
349
+ * if significant change: text2img (same seed for style)
350
+ * else: img2img from previous approved image
351
  """
352
+ mode, t2i, i2i = _try_get_pipes()
353
  shot = shots[shot_idx]
354
 
355
  prompt = (shot.get("description") or "").strip()
356
  negative = shot.get("negative") or ""
357
  seed = shot.get("seed", None)
358
 
359
+ device = "cuda" if USE_CUDA else "cpu"
360
  gen = torch.Generator(device)
361
  if isinstance(seed, int):
362
  gen = gen.manual_seed(int(seed))
 
364
  width = max(256, min(1024, int(width)))
365
  height = max(256, min(1024, int(height)))
366
 
367
+ # decide chaining
368
+ use_prev = False
369
+ prev_path = shots[shot_idx - 1].get("image_path") if shot_idx > 0 else None
370
+ if shot_idx == 0 or not prev_path or not os.path.exists(prev_path):
371
+ use_prev = False
 
 
 
 
 
372
  else:
373
+ prev_desc = shots[shot_idx - 1].get("description") or ""
374
+ use_prev = not _significant_change(prompt, prev_desc)
375
+
376
+ # invoke
377
+ if mode == "flux":
378
+ if not use_prev:
379
+ out = t2i(
380
+ prompt=prompt,
381
+ negative_prompt=negative or None,
382
+ num_inference_steps=int(max(8, t2i_steps)),
383
+ guidance_scale=float(max(2.0, guidance_scale)),
384
+ generator=gen,
385
+ width=width, height=height
386
+ ).images[0]
387
+ else:
388
  init_image = Image.open(prev_path).convert("RGB")
 
 
389
  out = i2i(
390
  prompt=prompt,
391
+ negative_prompt=negative or None,
392
  image=init_image,
393
+ strength=float(min(max(i2i_strength, 0.5), 0.95)),
394
+ num_inference_steps=int(max(10, i2i_steps)),
395
+ guidance_scale=float(max(2.0, guidance_scale)),
396
  generator=gen
397
  ).images[0]
398
+ else:
399
+ # SD-turbo fallback (keep your original behavior but with less mushy defaults)
400
+ if not use_prev:
401
  out = t2i(
402
  prompt=prompt,
403
  negative_prompt=negative,
404
+ guidance_scale=1.0,
405
+ num_inference_steps=int(max(6, t2i_steps//2)),
406
  generator=gen,
407
+ width=width, height=height
408
+ ).images[0]
409
+ else:
410
+ init_image = Image.open(prev_path).convert("RGB")
411
+ out = i2i(
412
+ prompt=prompt,
413
+ negative_prompt=negative,
414
+ image=init_image,
415
+ strength=float(min(max(i2i_strength, 0.55), 0.9)),
416
+ num_inference_steps=int(max(8, i2i_steps//2)),
417
+ generator=gen
418
  ).images[0]
419
 
420
  saved_path = _save_keyframe(pid, int(shot["id"]), out)
 
450
  # =========================
451
  with gr.Blocks() as demo:
452
  gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
453
+ gr.Markdown("Edit storyboard prompts, then generate keyframes. **Smart chaining**: only reuse the previous image if the new prompt is similar; otherwise we regenerate from text with the same seed for style consistency.")
454
 
455
  # State
456
  project = gr.State(None)
 
497
  with gr.Row():
498
  gen_btn = gr.Button("Generate / Regenerate", variant="primary")
499
  approve_next_btn = gr.Button("Approve & Next β†’", variant="secondary")
500
+ # tuning controls (defaults tuned for FLUX; fallback will downshift)
501
  with gr.Row():
502
+ img_strength = gr.Slider(0.50, 0.95, value=0.80, step=0.05, label="Change vs Consistency (img2img strength)")
503
+ img_steps = gr.Slider(8, 28, value=16, step=1, label="Inference Steps (img2img)")
504
+ guidance = gr.Slider(2.0, 4.0, value=3.0, step=0.1, label="Guidance Scale")
505
  with gr.Row():
506
  prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
507
  out_img = gr.Image(label="Generated image", type="filepath")
 
531
  p["shots"] = shots
532
  p["meta"]["updated"] = now_iso()
533
  save_project(p)
 
534
  return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable)."), gr.update(interactive=True)
535
 
536
  propose_btn.click(
 
560
 
561
  # lock a single seed for the project:
562
  proj_seed = None
 
563
  if proj_seed_override not in [None, ""] and str(proj_seed_override).isdigit():
564
  proj_seed = int(proj_seed_override)
 
 
565
  if proj_seed is None:
566
  proj_seed = p.get("meta", {}).get("seed", None)
567
  if proj_seed is None:
 
572
  if proj_seed is None:
573
  proj_seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
574
 
 
575
  for s in shots:
576
  if not isinstance(s.get("seed"), int):
577
  s["seed"] = proj_seed
 
602
  shots = p["shots"]
603
  if idx < 0 or idx >= len(shots): raise gr.Error("Invalid shot index.")
604
  shots[idx]["description"] = current_prompt # allow tweaking
 
605
 
606
  img_path = generate_keyframe_image(
607
  p["meta"]["id"],
608
  int(idx),
609
  shots,
610
+ t2i_steps=14, # tuned for FLUX
611
  i2i_steps=int(i2i_steps_val),
612
  i2i_strength=float(i2i_strength_val),
613
  guidance_scale=float(guidance_val),
614
+ width=640,
615
+ height=640
616
  )
617
+ prev_path = shots[idx-1]["image_path"] if idx > 0 else None
618
  return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
619
 
620
  gen_btn.click(