Shalmoni committed on
Commit
5362213
·
verified ·
1 Parent(s): 76a13c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -101
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py
2
  import os, json, uuid, re
3
  from datetime import datetime
4
  import gradio as gr
@@ -247,22 +247,17 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
247
  return _normalize_shots(shots_raw, default_fps, default_len)
248
 
249
  # =========================
250
- # IMAGE GEN — FLUX first, SD-Turbo fallback
251
  # =========================
252
  USE_CUDA = torch.cuda.is_available()
253
  DTYPE = torch.float16 if USE_CUDA else torch.float32
254
 
255
- FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-Nano") # or "black-forest-labs/FLUX.1-dev"
256
- SD_MODEL = os.getenv("SD_MODEL", "stabilityai/sd-turbo")
257
 
258
  _flux_t2i = None
259
  _flux_i2i = None
260
- _sd_t2i = None
261
- _sd_i2i = None
262
- _have_flux = None
263
 
264
  def _lazy_flux_pipes():
265
- # Returns (t2i, i2i) or raises
266
  from diffusers import FluxPipeline, FluxImg2ImgPipeline
267
  global _flux_t2i, _flux_i2i
268
  if _flux_t2i is not None and _flux_i2i is not None:
@@ -273,40 +268,15 @@ def _lazy_flux_pipes():
273
  if USE_CUDA: _flux_i2i = _flux_i2i.to("cuda")
274
  return _flux_t2i, _flux_i2i
275
 
276
- def _lazy_sd_pipes():
277
- # Returns (t2i, i2i)
278
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
279
- global _sd_t2i, _sd_i2i
280
- if _sd_t2i is not None and _sd_i2i is not None:
281
- return _sd_t2i, _sd_i2i
282
- hf_token = os.getenv("HF_TOKEN", None)
283
- _sd_t2i = StableDiffusionPipeline.from_pretrained(
284
- SD_MODEL, torch_dtype=DTYPE, safety_checker=None, feature_extractor=None,
285
- use_safetensors=True, low_cpu_mem_usage=False, token=hf_token
286
- )
287
- if USE_CUDA: _sd_t2i = _sd_t2i.to("cuda")
288
- _sd_i2i = StableDiffusionImg2ImgPipeline(
289
- vae=_sd_t2i.vae, text_encoder=_sd_t2i.text_encoder, tokenizer=_sd_t2i.tokenizer,
290
- unet=_sd_t2i.unet, scheduler=_sd_t2i.scheduler,
291
- safety_checker=None, feature_extractor=None
292
- )
293
- if USE_CUDA: _sd_i2i = _sd_i2i.to("cuda")
294
- return _sd_t2i, _sd_i2i
295
-
296
- def _try_get_pipes():
297
- """Prefer FLUX; fall back to SD-Turbo. Returns (mode, t2i, i2i) where mode in {'flux','sd'}."""
298
- global _have_flux
299
- if _have_flux is None:
300
- try:
301
- t2i, i2i = _lazy_flux_pipes()
302
- _have_flux = True
303
- return "flux", t2i, i2i
304
- except Exception as e:
305
- _have_flux = False
306
- if _have_flux:
307
- return "flux", *_lazy_flux_pipes()
308
- else:
309
- return "sd", *_lazy_sd_pipes()
310
 
311
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
312
  pdir = project_dir(pid)
@@ -316,42 +286,47 @@ def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
316
 
317
  def _significant_change(curr_desc: str, prev_desc: str) -> bool:
318
  """
319
- Heuristic: if symmetric difference of tokens is large -> treat as a new scene,
320
- so we should text2img (seed keeps style) instead of img2img.
321
  """
322
  if not prev_desc: return True
323
  a = set(re.findall(r"\w+", curr_desc.lower()))
324
  b = set(re.findall(r"\w+", prev_desc.lower()))
325
- # weights: boost composition-y words
326
  comp_words = {"wide","close","low","high","overhead","aerial","profile","left","right","center",
327
  "portrait","landscape","long","establishing","macro","tilt","dutch","angle",
328
  "night","day","sunset","sunrise","noon","backlit","rim","key","fill"}
329
  delta = a.symmetric_difference(b)
330
  score = len(delta) + 2 * len((a ^ b) & comp_words)
331
- return score >= 12 # tune threshold 10–16
332
 
333
  @spaces.GPU(duration=180)
334
  def generate_keyframe_image(
335
  pid: str,
336
  shot_idx: int,
337
  shots: list,
338
- t2i_steps: int = 14, # FLUX likes 12–20
339
- i2i_steps: int = 16,
340
- i2i_strength: float = 0.8, # higher = follow prompt more
341
- guidance_scale: float = 3.0, # FLUX sweet spot ~2.5–3.5
342
  width: int = 640,
343
  height: int = 640
344
  ):
345
  """
346
- Generate image for shots[shot_idx].
347
  - shot 0: text2img
348
  - shot k>0: smart chaining
349
  * if significant change: text2img (same seed for style)
350
  * else: img2img from previous approved image
351
  """
352
- mode, t2i, i2i = _try_get_pipes()
353
- shot = shots[shot_idx]
 
 
 
 
 
354
 
 
355
  prompt = (shot.get("description") or "").strip()
356
  negative = shot.get("negative") or ""
357
  seed = shot.get("seed", None)
@@ -373,49 +348,27 @@ def generate_keyframe_image(
373
  prev_desc = shots[shot_idx - 1].get("description") or ""
374
  use_prev = not _significant_change(prompt, prev_desc)
375
 
376
- # invoke
377
- if mode == "flux":
378
- if not use_prev:
379
- out = t2i(
380
- prompt=prompt,
381
- negative_prompt=negative or None,
382
- num_inference_steps=int(max(8, t2i_steps)),
383
- guidance_scale=float(max(2.0, guidance_scale)),
384
- generator=gen,
385
- width=width, height=height
386
- ).images[0]
387
- else:
388
- init_image = Image.open(prev_path).convert("RGB")
389
- out = i2i(
390
- prompt=prompt,
391
- negative_prompt=negative or None,
392
- image=init_image,
393
- strength=float(min(max(i2i_strength, 0.5), 0.95)),
394
- num_inference_steps=int(max(10, i2i_steps)),
395
- guidance_scale=float(max(2.0, guidance_scale)),
396
- generator=gen
397
- ).images[0]
398
  else:
399
- # SD-turbo fallback (keep your original behavior but with less mushy defaults)
400
- if not use_prev:
401
- out = t2i(
402
- prompt=prompt,
403
- negative_prompt=negative,
404
- guidance_scale=1.0,
405
- num_inference_steps=int(max(6, t2i_steps//2)),
406
- generator=gen,
407
- width=width, height=height
408
- ).images[0]
409
- else:
410
- init_image = Image.open(prev_path).convert("RGB")
411
- out = i2i(
412
- prompt=prompt,
413
- negative_prompt=negative,
414
- image=init_image,
415
- strength=float(min(max(i2i_strength, 0.55), 0.9)),
416
- num_inference_steps=int(max(8, i2i_steps//2)),
417
- generator=gen
418
- ).images[0]
419
 
420
  saved_path = _save_keyframe(pid, int(shot["id"]), out)
421
  return saved_path
@@ -449,8 +402,13 @@ def df_to_shots(df: pd.DataFrame) -> list:
449
  # Gradio UI
450
  # =========================
451
  with gr.Blocks() as demo:
452
- gr.Markdown("# 🎬 Storyboard → Keyframes → Videos → Export")
453
- gr.Markdown("Edit storyboard prompts, then generate keyframes. **Smart chaining**: only reuse the previous image if the new prompt is similar; otherwise we regenerate from text with the same seed for style consistency.")
 
 
 
 
 
454
 
455
  # State
456
  project = gr.State(None)
@@ -497,10 +455,10 @@ with gr.Blocks() as demo:
497
  with gr.Row():
498
  gen_btn = gr.Button("Generate / Regenerate", variant="primary")
499
  approve_next_btn = gr.Button("Approve & Next →", variant="secondary")
500
- # tuning controls (defaults tuned for FLUX; fallback will downshift)
501
  with gr.Row():
502
- img_strength = gr.Slider(0.50, 0.95, value=0.80, step=0.05, label="Change vs Consistency (img2img strength)")
503
- img_steps = gr.Slider(8, 28, value=16, step=1, label="Inference Steps (img2img)")
504
  guidance = gr.Slider(2.0, 4.0, value=3.0, step=0.1, label="Guidance Scale")
505
  with gr.Row():
506
  prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
@@ -607,7 +565,7 @@ with gr.Blocks() as demo:
607
  p["meta"]["id"],
608
  int(idx),
609
  shots,
610
- t2i_steps=14, # tuned for FLUX
611
  i2i_steps=int(i2i_steps_val),
612
  i2i_strength=float(i2i_strength_val),
613
  guidance_scale=float(guidance_val),
@@ -673,4 +631,5 @@ with gr.Blocks() as demo:
673
  load_btn.click(on_load, inputs=[load_file], outputs=[project, sb_status, shots_df, proj_seed_box])
674
 
675
  if __name__ == "__main__":
 
676
  demo.launch()
 
1
+ # app.py (FLUX-only, smart chaining)
2
  import os, json, uuid, re
3
  from datetime import datetime
4
  import gradio as gr
 
247
  return _normalize_shots(shots_raw, default_fps, default_len)
248
 
249
  # =========================
250
+ # IMAGE GEN — FLUX only (no fallback)
251
  # =========================
252
  USE_CUDA = torch.cuda.is_available()
253
  DTYPE = torch.float16 if USE_CUDA else torch.float32
254
 
255
+ FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-Nano")
 
256
 
257
  _flux_t2i = None
258
  _flux_i2i = None
 
 
 
259
 
260
  def _lazy_flux_pipes():
 
261
  from diffusers import FluxPipeline, FluxImg2ImgPipeline
262
  global _flux_t2i, _flux_i2i
263
  if _flux_t2i is not None and _flux_i2i is not None:
 
268
  if USE_CUDA: _flux_i2i = _flux_i2i.to("cuda")
269
  return _flux_t2i, _flux_i2i
270
 
271
+ def _flux_healthcheck():
272
+ try:
273
+ _lazy_flux_pipes()
274
+ return True
275
+ except Exception as e:
276
+ raise RuntimeError(
277
+ f"FLUX failed to initialize: {e}\n"
278
+ f"FLUX_MODEL='{FLUX_MODEL}'. If the repo is gated/private, set HF_TOKEN in env."
279
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
282
  pdir = project_dir(pid)
 
286
 
287
  def _significant_change(curr_desc: str, prev_desc: str) -> bool:
288
  """
289
+ If token-level symmetric difference is large, treat as a new scene:
290
+ do text2img (same seed) instead of img2img to avoid 'mush'.
291
  """
292
  if not prev_desc: return True
293
  a = set(re.findall(r"\w+", curr_desc.lower()))
294
  b = set(re.findall(r"\w+", prev_desc.lower()))
 
295
  comp_words = {"wide","close","low","high","overhead","aerial","profile","left","right","center",
296
  "portrait","landscape","long","establishing","macro","tilt","dutch","angle",
297
  "night","day","sunset","sunrise","noon","backlit","rim","key","fill"}
298
  delta = a.symmetric_difference(b)
299
  score = len(delta) + 2 * len((a ^ b) & comp_words)
300
+ return score >= 10 # more eager to break chaining
301
 
302
  @spaces.GPU(duration=180)
303
  def generate_keyframe_image(
304
  pid: str,
305
  shot_idx: int,
306
  shots: list,
307
+ t2i_steps: int = 16, # FLUX: 12–20
308
+ i2i_steps: int = 18, # FLUX: 14–22
309
+ i2i_strength: float = 0.85, # higher -> follow prompt more
310
+ guidance_scale: float = 3.0, # FLUX sweet spot: ~2.8–3.2
311
  width: int = 640,
312
  height: int = 640
313
  ):
314
  """
315
+ Generate image for shots[shot_idx] using FLUX only.
316
  - shot 0: text2img
317
  - shot k>0: smart chaining
318
  * if significant change: text2img (same seed for style)
319
  * else: img2img from previous approved image
320
  """
321
+ try:
322
+ t2i, i2i = _lazy_flux_pipes()
323
+ except Exception as e:
324
+ raise gr.Error(
325
+ f"FLUX failed to load: {e}\n"
326
+ "Set FLUX_MODEL (e.g., 'black-forest-labs/FLUX.1-Nano') and ensure HF_TOKEN if required."
327
+ )
328
 
329
+ shot = shots[shot_idx]
330
  prompt = (shot.get("description") or "").strip()
331
  negative = shot.get("negative") or ""
332
  seed = shot.get("seed", None)
 
348
  prev_desc = shots[shot_idx - 1].get("description") or ""
349
  use_prev = not _significant_change(prompt, prev_desc)
350
 
351
+ # generate
352
+ if not use_prev:
353
+ out = t2i(
354
+ prompt=prompt,
355
+ negative_prompt=negative or None,
356
+ num_inference_steps=int(max(8, t2i_steps)),
357
+ guidance_scale=float(max(2.0, guidance_scale)),
358
+ generator=gen,
359
+ width=width, height=height
360
+ ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
361
  else:
362
+ init_image = Image.open(prev_path).convert("RGB")
363
+ out = i2i(
364
+ prompt=prompt,
365
+ negative_prompt=negative or None,
366
+ image=init_image,
367
+ strength=float(min(max(i2i_strength, 0.5), 0.95)),
368
+ num_inference_steps=int(max(10, i2i_steps)),
369
+ guidance_scale=float(max(2.0, guidance_scale)),
370
+ generator=gen
371
+ ).images[0]
 
 
 
 
 
 
 
 
 
 
372
 
373
  saved_path = _save_keyframe(pid, int(shot["id"]), out)
374
  return saved_path
 
402
  # Gradio UI
403
  # =========================
404
  with gr.Blocks() as demo:
405
+ gr.Markdown("# 🎬 Storyboard → Keyframes → (Videos soon) → Export")
406
+ gr.Markdown(
407
+ "Edit storyboard prompts, then generate keyframes. "
408
+ "**Smart chaining**: only reuse the previous image if the new prompt is similar; "
409
+ "otherwise we regenerate from text with the same seed for style consistency. "
410
+ "**Model**: FLUX-only."
411
+ )
412
 
413
  # State
414
  project = gr.State(None)
 
455
  with gr.Row():
456
  gen_btn = gr.Button("Generate / Regenerate", variant="primary")
457
  approve_next_btn = gr.Button("Approve & Next →", variant="secondary")
458
+ # tuning controls (defaults tuned for FLUX)
459
  with gr.Row():
460
+ img_strength = gr.Slider(0.50, 0.95, value=0.85, step=0.05, label="Change vs Consistency (img2img strength)")
461
+ img_steps = gr.Slider(8, 28, value=18, step=1, label="Inference Steps (img2img)")
462
  guidance = gr.Slider(2.0, 4.0, value=3.0, step=0.1, label="Guidance Scale")
463
  with gr.Row():
464
  prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
 
565
  p["meta"]["id"],
566
  int(idx),
567
  shots,
568
+ t2i_steps=16,
569
  i2i_steps=int(i2i_steps_val),
570
  i2i_strength=float(i2i_strength_val),
571
  guidance_scale=float(guidance_val),
 
631
  load_btn.click(on_load, inputs=[load_file], outputs=[project, sb_status, shots_df, proj_seed_box])
632
 
633
  if __name__ == "__main__":
634
+ _flux_healthcheck() # fail fast with clear error if FLUX isn't available
635
  demo.launch()