# Prompt-2-Video / app.py
# Shalmoni's picture
# Update app.py
# ff07852 verified
# app.py — FLUX-only with temporal chaining + aggressive follow + video stitching (backend + robust ffmpeg concat)
import os, json, uuid, re, tempfile, subprocess, shlex, shutil, json as _json
from datetime import datetime
import gradio as gr
import spaces
import torch
from PIL import Image
import pandas as pd
import requests
import imageio_ffmpeg
# =========================
# Storage helpers
# =========================
# Root folder under which every project gets its own <ROOT>/<project id>/ directory (see project_dir).
ROOT = "outputs"
# Created eagerly at import time so every later os.path.join(ROOT, ...) has an existing parent.
os.makedirs(ROOT, exist_ok=True)
def now_iso():
    """Return the current UTC time as ISO-8601 with a trailing "Z", e.g. "2024-01-31T12:00:00Z".

    Uses a timezone-aware clock (``datetime.utcnow()`` is deprecated since Python 3.12),
    then drops the tzinfo so ``isoformat()`` does not append "+00:00" and the "Z"
    suffix stays the only zone marker. Microseconds are truncated.
    """
    from datetime import timezone

    stamp = datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None)
    return stamp.isoformat() + "Z"
def new_id():
    """Return a short random project id: the first 8 hex chars of a UUID4."""
    token = uuid.uuid4().hex
    return token[:8]
def project_dir(pid):
    """Return ``<ROOT>/<pid>``, creating it and its ``keyframes/`` and ``clips/`` subfolders if needed."""
    base = os.path.join(ROOT, pid)
    os.makedirs(base, exist_ok=True)
    for child in ("keyframes", "clips"):
        os.makedirs(os.path.join(base, child), exist_ok=True)
    return base
def save_project(proj):
    """Persist ``proj`` as ``<project dir>/project.json`` and return that path.

    The JSON is first written to a temporary file in the same directory and then
    moved into place with ``os.replace``. A crash or exception mid-write therefore
    leaves the previous project.json intact instead of a truncated file.
    """
    pid = proj["meta"]["id"]
    folder = project_dir(pid)
    path = os.path.join(folder, "project.json")
    fd, tmp_path = tempfile.mkstemp(prefix=".project-", suffix=".json.tmp", dir=folder)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(proj, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave stray temp files behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return path
def load_project_file(file_obj):
    """Load a project dict from an uploaded project.json and make sure its folders exist.

    Args:
        file_obj: a filesystem path (str / os.PathLike, which is what
            ``gr.File(type="filepath")`` delivers) or any object exposing the path
            as ``.name`` (older Gradio temp-file objects).

    Returns:
        The parsed project dict.

    Raises:
        ValueError: nothing was provided, or the file lacks a safe ``meta.id``.
    """
    if file_obj is None:
        raise ValueError("No project file provided.")
    path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
    with open(path, "r", encoding="utf-8") as f:
        proj = json.load(f)
    meta = proj.get("meta") if isinstance(proj, dict) else None
    pid = meta.get("id") if isinstance(meta, dict) else None
    # The id becomes a directory name under ROOT; reject anything that could
    # escape it (e.g. "../x") since the file comes from an untrusted upload.
    if not isinstance(pid, str) or not re.fullmatch(r"[A-Za-z0-9_-]+", pid):
        raise ValueError("Invalid project file: meta.id must be a simple identifier.")
    project_dir(pid)
    return proj
def ensure_project(p, suggested_name="Project"):
    """Return ``p`` unchanged when it exists; otherwise create, save and return a new empty project.

    A new project is named ``<suggested_name>-<first 4 chars of its id>`` and starts
    with empty "shots" and "clips" lists. Each shot later holds id, title,
    description, duration, fps, steps, seed, negative and image_path. The meta
    "seed" key may be filled in later.
    """
    if p is not None:
        return p
    project_id = new_id()
    meta = {
        "id": project_id,
        "name": f"{suggested_name}-{project_id[:4]}",
        "created": now_iso(),
        "updated": now_iso(),
    }
    fresh = {"meta": meta, "shots": [], "clips": []}
    save_project(fresh)
    return fresh
# =========================
# LLM (ZeroGPU) — storyboard generator (robust)
# =========================
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face model id of the causal LM that writes the storyboard JSON (override via env).
STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
# Cap on newly generated tokens per LLM call (used as max_new_tokens in _generate_text).
HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "1200"))
# Lazily initialised singletons, populated on first use by _lazy_model_tok().
_tokenizer = None
_model = None
def _lazy_model_tok():
    """Load the storyboard LLM and tokenizer once, cache them in module globals, return (model, tokenizer).

    The model uses fp16 when CUDA is available and fp32 otherwise. If the tokenizer
    has no pad token but has an EOS token, EOS is reused as the pad token.
    """
    global _tokenizer, _model
    already_loaded = _tokenizer is not None and _model is not None
    if already_loaded:
        return _model, _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
    weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    _model = AutoModelForCausalLM.from_pretrained(
        STORYBOARD_MODEL,
        device_map="auto",
        torch_dtype=weights_dtype,
        trust_remote_code=True,
        use_safetensors=True,
    )
    needs_pad = _tokenizer.pad_token_id is None
    if needs_pad and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id
    return _model, _tokenizer
def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
    """Build the primary storyboard prompt, which asks for a JSON array wrapped in <JSON>...</JSON>.

    Used for the first generation attempt in generate_storyboard_with_llm. The
    tags let _extract_json_array locate the payload. The per-shot schema shown
    to the model embeds ``default_len`` and ``default_fps`` as the example
    duration/fps values. ``user_prompt`` is inserted between ''' quotes.
    """
    return (
        "You are a cinematographer and storyboard artist. "
        "Given a story idea, break it into a sequence of visually DISTINCT, DETAILED shots. "
        "For each shot, provide the objects in the scene, very specific camera placement, angle, subject position, lighting, and background details. "
        "Imagine you're describing frames for a film storyboard, not vague events.\n\n"
        "Return ONLY a JSON array enclosed between <JSON> and </JSON> tags.\n"
        f"Create a storyboard of {n_shots} shots for this idea:\n\n"
        f"'''{user_prompt}'''\n\n"
        "Each item schema:\n"
        "{\n"
        ' "id": <int starting at 1>,\n'
        ' "title": "Short shot title",\n'
        ' "description": "Highly specific visual description for image generation. Include camera angle, framing, time of day, subject position, lighting, mood, and background details.",\n'
        f' "duration": {default_len},\n'
        f' "fps": {default_fps},\n'
        ' "steps": 30,\n'
        ' "seed": null,\n'
        ' "negative": ""\n'
        "}\n\n"
        "Output must start with <JSON> and end with </JSON>.\n"
    )
def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
    """Build the shorter fallback storyboard prompt, which asks for a bare JSON array (no tags).

    generate_storyboard_with_llm uses this for its second attempt, when no JSON
    could be extracted from the tagged prompt's output. The schema embeds
    ``default_len`` / ``default_fps`` as the example duration/fps.
    """
    return (
        "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
        f"Storyboard: {n_shots} shots for:\n'''{user_prompt}'''\n"
        "Item schema:\n"
        "{\n"
        ' "id": <int starting at 1>,\n'
        ' "title": "Short title",\n'
        ' "description": "Visual description",\n'
        f' "duration": {default_len},\n'
        f' "fps": {default_fps},\n'
        ' "steps": 30,\n'
        ' "seed": null,\n'
        ' "negative": ""\n'
        "}\n"
    )
def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
if hasattr(tok, "apply_chat_template"):
return tok.apply_chat_template(
[{"role": "system", "content": system_msg},
{"role": "user", "content": user_msg}],
tokenize=False,
add_generation_prompt=True
)
return system_msg + "\n\n" + user_msg
def _generate_text(model, tok, prompt_text: str) -> str:
    """Greedy-decode a continuation of ``prompt_text`` and return only the new text.

    Up to HF_TASK_MAX_TOKENS new tokens are generated. The prompt tokens are
    sliced off before decoding. A surrounding ``` / ```json fence is stripped
    from the result.
    """
    inputs = tok(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Compare against None: a legitimate token id of 0 is falsy and would be
    # silently replaced by the pad id if `or` were used.
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else tok.pad_token_id
    gen = model.generate(
        **inputs,
        max_new_tokens=HF_TASK_MAX_TOKENS,
        # Greedy decoding; sampling knobs such as temperature are ignored in this
        # mode (transformers warns if they are passed), so none are set.
        do_sample=False,
        repetition_penalty=1.05,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
    )
    prompt_len = inputs["input_ids"].shape[1]
    continuation_ids = gen[0][prompt_len:]
    text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE | re.DOTALL).strip()
    return text
def _extract_json_array(text: str) -> str:
m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
if m:
inner = m.group(1).strip()
if inner:
return inner
start = text.find("[")
if start == -1:
return ""
depth = 0
in_str = False
prev = ""
for i in range(start, len(text)):
ch = text[i]
if ch == '"' and prev != '\\':
in_str = not in_str
if not in_str:
if ch == "[":
depth += 1
elif ch == "]":
depth -= 1
if depth == 0:
return text[start:i+1].strip()
prev = ch
return ""
def _normalize_shots(shots_raw, default_fps: int, default_len: int):
norm = []
for i, s in enumerate(shots_raw, start=1):
norm.append({
"id": int(s.get("id", i)),
"title": s.get("title", f"Shot {i}"),
"description": s.get("description", ""),
"duration": int(s.get("duration", default_len)),
"fps": int(s.get("fps", default_fps)),
"steps": int(s.get("steps", 30)),
"seed": s.get("seed", None),
"negative": s.get("negative", ""),
"image_path": s.get("image_path", None)
})
return norm
def _placeholder_shots(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
    """Return ``n_shots`` minimal placeholder shots, used when the LLM output is unusable."""
    return [
        {
            "id": i,
            "title": f"Shot {i}",
            "description": f"Simple placeholder for: {user_prompt[:80]}",
            "duration": default_len,
            "fps": default_fps,
            "steps": 30,
            "seed": None,
            "negative": "",
            "image_path": None,
        }
        for i in range(1, int(n_shots) + 1)
    ]


def _parse_shot_list(json_text: str):
    """Parse extracted JSON text into a non-empty list of shot dicts, or return None.

    Tolerates trailing commas (a common LLM mistake). Accepts a bare array, a
    ``{"shots": [...]}`` wrapper, or a single shot object.
    """
    if not json_text or not json_text.strip():
        return None
    for candidate in (json_text, re.sub(r",\s*([\]\}])", r"\1", json_text)):
        try:
            parsed = json.loads(candidate)
        except ValueError:
            continue
        if isinstance(parsed, dict):
            inner = parsed.get("shots")
            parsed = inner if isinstance(inner, list) else [parsed]
        if isinstance(parsed, list):
            shots = [s for s in parsed if isinstance(s, dict)]
            if shots:
                return shots
    return None


@spaces.GPU(duration=180)
def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
    """Ask the LLM for an ``n_shots`` storyboard and return normalized shot dicts.

    Attempt 1 uses the tagged prompt (<JSON>...</JSON>). If nothing parseable
    comes back, attempt 2 uses the minimal bare-array prompt. If that still
    fails, it falls back to the outermost [...] span of the reply. If no attempt
    yields valid JSON, placeholder shots are returned instead of raising.
    """
    model, tok = _lazy_model_tok()
    system = "You are a film previsualization assistant. Output must be valid JSON."
    p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
                     _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
    out1 = _generate_text(model, tok, p1)
    shots_raw = _parse_shot_list(_extract_json_array(out1))

    if shots_raw is None:
        p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
                         _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
        out2 = _generate_text(model, tok, p2)
        json_text = _extract_json_array(out2)
        if not json_text:
            # Last resort: take the outermost [...] span verbatim.
            start, end = out2.find("["), out2.rfind("]")
            if start != -1 and end > start:
                json_text = out2[start:end + 1].strip()
        shots_raw = _parse_shot_list(json_text)

    if shots_raw is None:
        return _placeholder_shots(user_prompt, n_shots, default_fps, default_len)
    return _normalize_shots(shots_raw, default_fps, default_len)
# =========================
# IMAGE GEN — FLUX only (no fallback) + temporal chaining
# =========================
# Device/precision for the FLUX pipelines: fp16 on GPU, fp32 on CPU.
# NOTE(review): diffusers' FLUX examples generally use bfloat16; confirm fp16 gives clean outputs.
USE_CUDA = torch.cuda.is_available()
DTYPE = torch.float16 if USE_CUDA else torch.float32
# Gated repo; accept access and set HF_TOKEN
# NOTE(review): the default FLUX.1-schnell may not be gated (FLUX.1-dev is); confirm whether a token is required.
FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-schnell")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
# I2V backend for video between frames: _call_i2v_backend POSTs start/end keyframes here
# and expects mp4 bytes back. Override with the I2V_ENDPOINT env var.
I2V_ENDPOINT = os.getenv(
    "I2V_ENDPOINT",
    "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
)
# Lazily loaded text-to-image / image-to-image FLUX pipelines, filled by _lazy_flux_pipes().
_flux_t2i = None
_flux_i2i = None
def _lazy_flux_pipes():
    """Load the FLUX text-to-image and image-to-image pipelines once; return (t2i, i2i).

    The img2img pipeline is built with ``FluxImg2ImgPipeline.from_pipe`` from the
    already loaded text-to-image pipeline, so both share the same transformer,
    VAE and text encoders. The original code called ``from_pretrained`` twice,
    which held two full copies of the FLUX weights in memory. If only the t2i
    pipeline is cached (a previous i2i build failed), it is reused rather than
    reloaded.
    """
    global _flux_t2i, _flux_i2i
    if _flux_t2i is not None and _flux_i2i is not None:
        return _flux_t2i, _flux_i2i
    from diffusers import FluxPipeline, FluxImg2ImgPipeline

    if _flux_t2i is None:
        pipe = FluxPipeline.from_pretrained(
            FLUX_MODEL, torch_dtype=DTYPE, use_safetensors=True, token=HF_TOKEN
        )
        if USE_CUDA:
            pipe = pipe.to("cuda")
        _flux_t2i = pipe
    # Shares modules (and therefore their device placement) with _flux_t2i.
    _flux_i2i = FluxImg2ImgPipeline.from_pipe(_flux_t2i)
    return _flux_t2i, _flux_i2i
def _flux_healthcheck():
    """Fail fast at startup: require an HF token, then load (and cache) both FLUX pipelines.

    Raises RuntimeError when neither HF_TOKEN nor HUGGINGFACE_HUB_TOKEN is set.
    Loading errors from _lazy_flux_pipes propagate unchanged.
    """
    token_missing = not HF_TOKEN
    if token_missing:
        message = (
            "HF_TOKEN is not set. FLUX models are gated; set a Hugging Face READ token "
            "and accept the model terms on the repo page."
        )
        raise RuntimeError(message)
    _lazy_flux_pipes()
def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
    """Write ``img`` to ``<project>/keyframes/shot_<NN>.png`` (overwriting) and return that path."""
    keyframes_dir = os.path.join(project_dir(pid), "keyframes")
    target = os.path.join(keyframes_dir, "shot_{:02d}.png".format(shot_id))
    img.save(target)
    return target
# ---- Temporal prompt composer (PRIORITIZE the new shot) ----
def _compose_temporal_prompt(shots: list, idx: int, seconds_forward: int = 5):
"""
Build a prompt that continues the scene N seconds later,
prioritizing the NEW shot description (composition/action),
while keeping only identity/lighting/environment continuity.
Returns (composed_prompt, composed_negative).
"""
curr = shots[idx]
curr_desc = (curr.get("description") or "").strip()
curr_neg = (curr.get("negative") or "").strip()
if idx == 0:
return curr_desc, curr_neg
prev = shots[idx - 1]
prev_desc = (prev.get("description") or "").strip()
composed = (
f"Continue the same scene {seconds_forward} seconds later.\n"
f'PRIORITIZE this new moment and its composition now: "{curr_desc}".\n'
"Keep continuity ONLY for subject identity, lighting palette, time of day, and general environment style.\n"
f'Previous frame (context only, do not copy its framing): "{prev_desc}".\n'
f"Avoid replicating the previous composition; allow camera move / subject reposition consistent with {seconds_forward} seconds of natural progression."
).strip()
negative = (
curr_neg + (
"; identical composition as previous; exact same framing; rigid pose repeat; freeze frame; "
"hard scene reset; different subject identity; wildly different art style; unrelated background"
)
).strip("; ")
return composed, negative
@spaces.GPU(duration=180)
def generate_keyframe_image(
    pid: str,
    shot_idx: int,
    shots: list,
    t2i_steps: int = 18,            # FLUX: 12–22
    i2i_steps: int = 22,            # FLUX: 16–26
    i2i_strength: float = 0.90,     # more change toward new prompt
    guidance_scale: float = 3.4,    # stronger text pull
    width: int = 640,
    height: int = 640,
    seconds_forward: int = 5,       # temporal step
    aggressive: bool = False        # optional push
):
    """Generate and save the keyframe for ``shots[shot_idx]`` with FLUX; return the saved path.

    - First shot, or a missing previous image: text-to-image at width x height
      (each clamped to [256, 1024]).
    - Later shots: img2img from the previous shot's approved image using the
      temporal "N seconds later" prompt. Strength is clamped to [0.70, 0.98].
    ``aggressive`` raises strength, guidance and step floors for a stronger pull
    toward the new prompt. An int ``seed`` on the shot makes output reproducible.
    Otherwise a fresh random seed is drawn on every call.

    Raises:
        gr.Error: if the FLUX pipelines cannot be loaded.
    """
    try:
        t2i, i2i = _lazy_flux_pipes()
    except Exception as e:
        raise gr.Error(
            f"FLUX failed to load: {e}\n"
            "Set FLUX_MODEL (e.g., 'black-forest-labs/FLUX.1-schnell') and ensure HF_TOKEN if required."
        ) from e
    # Build temporal prompt
    composed_prompt, composed_negative = _compose_temporal_prompt(shots, shot_idx, seconds_forward=seconds_forward)
    # RNG / seed
    seed = shots[shot_idx].get("seed", None)
    gen = torch.Generator("cuda" if USE_CUDA else "cpu")
    if isinstance(seed, int):
        gen = gen.manual_seed(int(seed))
    else:
        # A fresh torch.Generator starts from a fixed default seed, which would make
        # every unseeded "Regenerate" identical; draw a non-deterministic seed instead.
        gen.seed()
    # sizes
    width = max(256, min(1024, int(width)))
    height = max(256, min(1024, int(height)))
    # chaining
    prev_path = shots[shot_idx - 1].get("image_path") if shot_idx > 0 else None
    use_prev = bool(shot_idx > 0 and prev_path and os.path.exists(prev_path))
    # Aggressive mode bumps
    if aggressive:
        i2i_strength = min(0.98, max(i2i_strength, 0.92))
        guidance_scale = max(guidance_scale, 3.6)
        i2i_steps = max(i2i_steps, 24)
    # generate
    if not use_prev:
        out = t2i(
            prompt=composed_prompt,
            negative_prompt=composed_negative or None,
            num_inference_steps=int(max(10, t2i_steps)),
            guidance_scale=float(max(2.4, guidance_scale)),
            generator=gen,
            width=width, height=height
        ).images[0]
    else:
        # Previous approved frame (the "init_image"). The context manager closes the
        # underlying file handle once the RGB copy has been made.
        with Image.open(prev_path) as prev_frame:
            init_image = prev_frame.convert("RGB")
        out = i2i(
            prompt=composed_prompt,
            negative_prompt=composed_negative or None,
            image=init_image,
            strength=float(min(max(i2i_strength, 0.70), 0.98)),
            num_inference_steps=int(max(14, i2i_steps)),
            guidance_scale=float(max(2.4, guidance_scale)),
            generator=gen
        ).images[0]
    return _save_keyframe(pid, int(shots[shot_idx]["id"]), out)
# =========================
# Video stitching helpers (backend per pair + robust ffmpeg concat)
# =========================
def _pair_clip_path(pid: str, i: int, j: int) -> str:
    """Path of the transition clip from shot ``i`` to shot ``j``: ``clips/pair_<ii>_to_<jj>.mp4``."""
    clip_name = "pair_{:02d}_to_{:02d}.mp4".format(i, j)
    return os.path.join(project_dir(pid), "clips", clip_name)
def _final_stitched_path(pid: str) -> str:
    """Path of the concatenated output video: ``<project>/clips/final_stitched.mp4``."""
    clips_folder = os.path.join(project_dir(pid), "clips")
    return os.path.join(clips_folder, "final_stitched.mp4")
def _call_i2v_backend(img_a_path: str, img_b_path: str, prompt: str, seed, endpoint: str,
                      *, timeout=None) -> bytes:
    """POST a start and an end frame to the I2V backend and return the transition clip bytes (mp4).

    ``prompt`` and an int-convertible ``seed`` are sent as query parameters. The
    images are sent as multipart fields ``image_bytes`` / ``image_bytes_end``.

    Args:
        timeout: read timeout in seconds. When None, the I2V_TIMEOUT env var is
            used, defaulting to 900. The connect timeout is fixed at 30 s. Without
            a timeout, a stalled backend would block the request forever.

    Raises:
        gr.Error: on network failure, a non-200 status, or an empty response body.
    """
    params = {}
    if prompt:
        params["prompt"] = prompt
    if seed is not None:
        try:
            params["seed"] = str(int(seed))
        except (TypeError, ValueError, OverflowError):
            pass  # non-numeric seed: simply not forwarded
    read_timeout = float(timeout if timeout is not None else os.getenv("I2V_TIMEOUT", "900"))
    with open(img_a_path, "rb") as fa, open(img_b_path, "rb") as fb:
        files = {
            "image_bytes": ("start.png", fa, "application/octet-stream"),
            "image_bytes_end": ("end.png", fb, "application/octet-stream"),
        }
        try:
            # NOTE(review): the backend replies with mp4 bytes despite the JSON accept header;
            # kept as-is to avoid changing what the backend negotiates.
            r = requests.post(
                endpoint,
                params=params,
                files=files,
                headers={"accept": "application/json"},
                timeout=(30, read_timeout),
            )
        except requests.RequestException as e:
            raise gr.Error(f"I2V backend request failed: {e}") from e
    if r.status_code != 200:
        raise gr.Error(f"I2V backend error {r.status_code}: {r.text[:400]}")
    if not r.content:
        raise gr.Error("I2V backend returned an empty response.")
    return r.content
def _build_all_pair_videos_backend(pid: str, shots: list, endpoint: str, prompt: str, seed) -> list[str]:
    """Render one transition clip per consecutive pair of shots; return the written clip paths.

    A pair is skipped when either shot lacks an existing image_path. Each clip
    is fetched from the I2V backend and written to ``_pair_clip_path(pid, id_a, id_b)``.
    """
    written = []
    for left, right in zip(shots, shots[1:]):
        start_img = left.get("image_path")
        end_img = right.get("image_path")
        ready = start_img and end_img and os.path.exists(start_img) and os.path.exists(end_img)
        if not ready:
            continue
        clip_bytes = _call_i2v_backend(start_img, end_img, prompt=prompt, seed=seed, endpoint=endpoint)
        target = _pair_clip_path(pid, left["id"], right["id"])
        with open(target, "wb") as fh:
            fh.write(clip_bytes)
        written.append(target)
    return written
# ---------- robust concat (normalize + concat demuxer; fallback re-encode once)
def _ffprobe_stream(path):
ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
ffprobe = ffmpeg.replace("ffmpeg", "ffprobe")
cmd = [
ffprobe,
"-v", "error",
"-select_streams", "v:0",
"-show_entries", "stream=width,height,avg_frame_rate,pix_fmt,codec_name",
"-of", "json",
path,
]
try:
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
data = _json.loads(out.decode("utf-8"))
return (data.get("streams") or [{}])[0]
except Exception:
return {}
def _run_logged(cmd: list, label: str, log_path: str) -> None:
    """Append ``<label>: <cmd>`` to the concat log, then run ``cmd`` with stdout/stderr sent to that log.

    Raises subprocess.CalledProcessError when the command exits non-zero.
    """
    with open(log_path, "ab") as lg:
        lg.write((label + ": " + " ".join(cmd) + "\n").encode())
        subprocess.check_call(cmd, stdout=lg, stderr=lg)


def _concat_reference_size(mp4_paths, size):
    """Pick the output (width, height) for concat, rounded up to even numbers.

    Uses an explicit 2-item ``size`` when given, else the first clip that can be
    probed. Returns (None, None) when unknown.
    """
    ref_w, ref_h = None, None
    if size and isinstance(size, (tuple, list)) and len(size) == 2:
        ref_w, ref_h = int(size[0]), int(size[1])
    else:
        for p in mp4_paths:
            st = _ffprobe_stream(p)
            if st.get("width") and st.get("height"):
                ref_w, ref_h = int(st["width"]), int(st["height"])
                break
    if ref_w and ref_h:
        # libx264 with yuv420p requires even dimensions.
        if ref_w % 2:
            ref_w += 1
        if ref_h % 2:
            ref_h += 1
    return ref_w, ref_h


def _ffmpeg_safe_concat(mp4_paths: list[str], out_path: str, fps: int = 24, size=None):
    """Concatenate ``mp4_paths`` (in order) into ``out_path`` and return ``out_path``.

    Each existing input is first normalized to a common fps, frame size
    (letterboxed), codec and pixel format, with audio dropped. The results are
    joined with the concat demuxer (stream copy). If that fails, they are joined
    once more with the concat filter (re-encode). Commands and ffmpeg output are
    appended to ``ffmpeg_concat.log`` next to ``out_path``.

    Raises:
        gr.Error: no clips given or none usable, or ffmpeg failed (see the log).
    """
    if not mp4_paths:
        raise gr.Error("No clips to stitch.")
    clips_dir = os.path.dirname(out_path)
    if clips_dir:  # os.makedirs("") would raise for a bare filename
        os.makedirs(clips_dir, exist_ok=True)
    log_path = os.path.join(clips_dir, "ffmpeg_concat.log")
    ffmpeg = imageio_ffmpeg.get_ffmpeg_exe()
    fps = int(fps)
    ref_w, ref_h = _concat_reference_size(mp4_paths, size)

    tmpdir = tempfile.mkdtemp(prefix="norm_")
    try:
        # Normalize each clip to consistent fps/size/codec/pixfmt
        norm_paths = []
        for i, inp in enumerate(mp4_paths):
            if not os.path.exists(inp):
                continue
            norm = os.path.join(tmpdir, f"norm_{i:03d}.mp4")
            vf = []
            if ref_w and ref_h:
                vf.append(f"scale=w={ref_w}:h={ref_h}:force_original_aspect_ratio=decrease")
                vf.append(f"pad={ref_w}:{ref_h}:(ow-iw)/2:(oh-ih)/2:color=black")
            vf.append(f"fps={fps}")
            _run_logged([
                ffmpeg, "-y", "-i", inp,
                "-vf", ",".join(vf),
                "-an",
                "-c:v", "libx264",
                "-pix_fmt", "yuv420p",
                "-profile:v", "main",
                "-preset", "veryfast",
                "-crf", "18",
                "-movflags", "+faststart",
                norm,
            ], "NORMALIZE", log_path)
            norm_paths.append(norm)
        if not norm_paths:
            raise gr.Error("No inputs could be normalized for concat.")

        # concat demuxer (stream copy)
        listfile = os.path.join(tmpdir, "list.txt")
        with open(listfile, "w", encoding="utf-8") as f:
            for p in norm_paths:
                # Concat-demuxer quoting: a literal ' must be written as '\''
                escaped = p.replace("'", "'\\''")
                f.write(f"file '{escaped}'\n")
        try:
            _run_logged([
                ffmpeg, "-y",
                "-f", "concat", "-safe", "0", "-i", listfile,
                "-c", "copy",
                out_path,
            ], "CONCAT_COPY", log_path)
            return out_path
        except subprocess.CalledProcessError:
            pass  # stream copy failed; fall back to a single re-encode below

        # filter_complex concat (re-encode once)
        n = len(norm_paths)
        cmd = [ffmpeg, "-y"]
        for p in norm_paths:
            cmd += ["-i", p]
        inputs = "".join(f"[{i}:v]" for i in range(n))
        cmd += [
            "-filter_complex", f"{inputs}concat=n={n}:v=1:a=0[outv]",
            "-map", "[outv]",
            "-an",
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            "-profile:v", "main",
            "-preset", "veryfast",
            "-crf", "18",
            "-r", str(fps),
            "-movflags", "+faststart",
            out_path,
        ]
        _run_logged(cmd, "CONCAT_REENC", log_path)
        return out_path
    except gr.Error:
        raise  # already user-facing; don't re-wrap as "Concat error: ..."
    except subprocess.CalledProcessError as e:
        raise gr.Error("ffmpeg concat failed (copy and re-encode). See logs at: " + log_path) from e
    except Exception as e:
        raise gr.Error(f"Concat error: {e}") from e
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
# =========================
# Shots <-> DataFrame utils
# =========================
# Column order of the storyboard table: shots_to_df emits exactly these columns and the
# Storyboard tab's Dataframe uses them as headers (datatypes listed there in the same order).
SHOT_COLUMNS = ["id", "title", "description", "duration", "fps", "steps", "seed", "negative", "image_path"]
def shots_to_df(shots: list) -> pd.DataFrame:
    """Build the editable storyboard table: one row per shot, columns in SHOT_COLUMNS order.

    Keys missing from a shot become None; keys not in SHOT_COLUMNS are dropped.
    """
    records = []
    for shot in shots:
        records.append({column: shot.get(column, None) for column in SHOT_COLUMNS})
    return pd.DataFrame(records, columns=SHOT_COLUMNS)
def _df_cell_is_blank(value) -> bool:
    """True for None, NaN/NA and empty/whitespace-only strings."""
    if value is None:
        return True
    if isinstance(value, str):
        return not value.strip()
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        return False


def _df_cell_int(value, default):
    """Parse a table cell as int ("4", 4.0, "4.0" all work); ``default`` for blank/garbage."""
    if _df_cell_is_blank(value):
        return default
    if isinstance(value, str):
        value = value.strip()
    try:
        return int(value)
    except (TypeError, ValueError):
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def _df_cell_str(value, default: str) -> str:
    """Return the cell as str, or ``default`` when blank (so NaN never becomes the text "nan")."""
    if value is None:
        return default
    if not isinstance(value, str):
        try:
            if pd.isna(value):
                return default
        except (TypeError, ValueError):
            pass
    text = str(value)
    return text if text else default


def df_to_shots(df: pd.DataFrame) -> list:
    """Convert the edited storyboard table back into shot dicts, sorted by id.

    - Completely blank rows (e.g. the initial empty row of a dynamic Dataframe) are skipped.
    - Blank or unparseable numbers fall back to duration=4, fps=24, steps=30, seed=None.
    - Rows without an id get fresh ids above the current maximum.
    - Blank titles become "Shot <id>"; a blank image_path becomes None.
    A None or empty table yields [].
    """
    if df is None or len(df) == 0:
        return []
    shots = []
    for _, row in df.iterrows():
        if all(_df_cell_is_blank(v) for v in row.tolist()):
            continue
        shots.append({
            "id": _df_cell_int(row.get("id"), None),
            "title": _df_cell_str(row.get("title"), ""),
            "description": _df_cell_str(row.get("description"), ""),
            "duration": _df_cell_int(row.get("duration"), 4),
            "fps": _df_cell_int(row.get("fps"), 24),
            "steps": _df_cell_int(row.get("steps"), 30),
            "seed": _df_cell_int(row.get("seed"), None),
            "negative": _df_cell_str(row.get("negative"), ""),
            "image_path": _df_cell_str(row.get("image_path"), "") or None,
        })
    next_id = max((s["id"] for s in shots if s["id"] is not None), default=0) + 1
    for s in shots:
        if s["id"] is None:
            s["id"] = next_id
            next_id += 1
        if not s["title"]:
            s["title"] = f"Shot {s['id']}"
    return sorted(shots, key=lambda x: x["id"])
# =========================
# Gradio UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
    gr.Markdown(
        "Edit storyboard prompts, then generate keyframes.\n"
        "**Temporal chaining**: each new shot is generated N seconds later from the previous approved frame, "
        "while the current shot description drives composition & action. **Model**: FLUX-only."
    )
    # State
    project = gr.State(None)
    current_idx = gr.State(0)
    # Header
    with gr.Row():
        with gr.Column(scale=2):
            proj_name = gr.Textbox(label="Project name", placeholder="e.g., Desert Chase")
        with gr.Column(scale=1):
            new_btn = gr.Button("New Project", variant="primary")
        with gr.Column(scale=1):
            save_btn = gr.Button("Save Project")
        with gr.Column(scale=1):
            load_file = gr.File(label="Load Project (project.json)", file_count="single", type="filepath")
            load_btn = gr.Button("Load")
    sb_status = gr.Markdown("")
    # Tabs
    with gr.Tabs():
        with gr.Tab("Storyboard"):
            gr.Markdown("### 1) Storyboard")
            sb_prompt = gr.Textbox(label="High-level prompt", lines=4, placeholder="Describe the story you want to create…")
            with gr.Row():
                sb_target_shots = gr.Slider(1, 12, value=3, step=1, label="Target # of shots")
                sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
                sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
            propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
            shots_df = gr.Dataframe(
                headers=SHOT_COLUMNS,
                datatype=["number", "str", "str", "number", "number", "number", "number", "str", "str"],
                row_count=(1, "dynamic"), col_count=len(SHOT_COLUMNS),
                label="Edit shots below (prompts & params)", wrap=True
            )
            save_edits_btn = gr.Button("Save Edits βœ“", variant="primary", interactive=False)
            with gr.Row():
                proj_seed_box = gr.Number(label="Project Seed (locked across shots)", precision=0)
                to_keyframes_btn = gr.Button("Start Keyframes β†’", variant="secondary")
        with gr.Tab("Keyframes"):
            gr.Markdown("### 2) Keyframes")
            shot_info_md = gr.Markdown("")
            prompt_box = gr.Textbox(label="Shot description (editable before generating)", lines=4)
            with gr.Row():
                gen_btn = gr.Button("Generate / Regenerate", variant="primary")
                approve_next_btn = gr.Button("Approve & Next β†’", variant="secondary")
            with gr.Row():
                img_strength = gr.Slider(0.50, 0.98, value=0.90, step=0.02, label="Change vs Consistency (img2img strength)")
                img_steps = gr.Slider(12, 28, value=22, step=1, label="Inference Steps (img2img)")
                guidance = gr.Slider(2.4, 4.0, value=3.4, step=0.1, label="Guidance Scale")
                temporal_secs = gr.Slider(1, 10, value=5, step=1, label="Temporal step (seconds later)")
                aggressive_follow = gr.Checkbox(value=False, label="Aggressive follow prompt (more change)")
            with gr.Row():
                prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
                out_img = gr.Image(label="Generated image", type="filepath")
            kf_status = gr.Markdown("")
        with gr.Tab("Videos"):
            gr.Markdown("### 3) Videos")
            with gr.Row():
                v_fps = gr.Slider(8, 60, value=24, step=1, label="FPS (display only)")
                v_hold = gr.Slider(0.0, 2.0, value=0.5, step=0.1, label="Hold per still (UI only)")
                v_xfade = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Crossfade (UI only)")
            with gr.Row():
                build_pairs_btn = gr.Button("Build pair clips (A→B, B→C, ...)", variant="primary")
                build_final_btn = gr.Button("Build final stitched video", variant="secondary")
            vd_table = gr.JSON(label="Rendered outputs (paths)")
        with gr.Tab("Export"):
            gr.Markdown("### 4) Export (coming next)")
            export_info = gr.Markdown("Nothing to export yet.")

    # ---------- Shared helpers ----------
    def _shot_info_md(shot, seed):
        """Markdown header shown above the keyframe editor for ``shot``."""
        return (
            f"**Shot {shot['id']} β€” {shot['title']}** \n"
            f"Duration: {shot['duration']}s @ {shot['fps']} fps \n"
            f"Locked project seed: `{seed}`"
        )

    def _parse_seed_override(value):
        """Return the Project Seed box value as a non-negative int, or None if empty/invalid.

        gr.Number can deliver an int, a float such as 123.0, or a string. Only
        whole, non-negative values are accepted.
        """
        if value is None or isinstance(value, bool):
            return None
        if isinstance(value, float):
            return int(value) if value.is_integer() and value >= 0 else None
        text = str(value).strip()
        return int(text) if text.isdigit() else None

    def _persist_approved_image(pid, shot_id, src_path):
        """Copy an approved image into ``<project>/keyframes/`` and return the new path.

        Gradio hands back a copy in its temporary cache. Storing that path would
        make the saved project depend on temp files. Falls back to ``src_path``
        if the copy fails.
        """
        ext = os.path.splitext(src_path)[1] or ".png"
        dst = os.path.join(project_dir(pid), "keyframes", f"shot_{int(shot_id):02d}{ext}")
        try:
            if not (os.path.exists(dst) and os.path.samefile(src_path, dst)):
                shutil.copyfile(src_path, dst)
        except OSError:
            return src_path
        return dst

    # ---------- Handlers ----------
    def on_new(name):
        p = ensure_project(None, suggested_name=(name or "Project"))
        return p, gr.update(value=f"**New project created** `{p['meta']['name']}` (id: `{p['meta']['id']}`)")
    new_btn.click(on_new, inputs=[proj_name], outputs=[project, sb_status])

    def on_propose(p, name, prompt, target_shots, fps, vlen):
        # Validate before creating anything on disk. The project name comes from the
        # live textbox input, not the component's initial `.value`.
        if not prompt or not str(prompt).strip():
            raise gr.Error("Please enter a high-level prompt.")
        p = ensure_project(p, suggested_name=(name or "Project"))
        shots = generate_storyboard_with_llm(str(prompt).strip(), int(target_shots), int(fps), int(vlen))
        p = dict(p)
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable)."), gr.update(interactive=True)
    propose_btn.click(
        on_propose,
        inputs=[project, proj_name, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
        outputs=[project, shots_df, sb_status, save_edits_btn]
    )

    def on_save_edits(p, df):
        if p is None:
            raise gr.Error("No project in memory. Click New Project, then generate a storyboard.")
        if df is None:
            raise gr.Error("No storyboard table to save. Generate a storyboard first, then edit it.")
        shots = df_to_shots(df)
        p = dict(p)
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        return p, gr.update(value="Edits saved.")
    save_edits_btn.click(on_save_edits, inputs=[project, shots_df], outputs=[project, sb_status])

    def on_start_keyframes(p, df, proj_seed_override):
        if p is None:
            raise gr.Error("No project.")
        shots = df_to_shots(df)
        if not shots:
            raise gr.Error("Storyboard is empty.")
        # Lock a single seed for the project. Precedence: the Project Seed box, then the
        # seed already stored on the project, then the first per-shot int seed, then random.
        proj_seed = _parse_seed_override(proj_seed_override)
        if proj_seed is None:
            proj_seed = p.get("meta", {}).get("seed", None)
        if proj_seed is None:
            proj_seed = next((int(s["seed"]) for s in shots if isinstance(s.get("seed"), int)), None)
        if proj_seed is None:
            proj_seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        for s in shots:
            if not isinstance(s.get("seed"), int):
                s["seed"] = proj_seed
        p = dict(p)
        p["shots"] = shots
        p["meta"]["seed"] = proj_seed
        p["meta"]["updated"] = now_iso()
        save_project(p)
        return (
            p,
            0,
            gr.update(value=_shot_info_md(shots[0], proj_seed)),
            gr.update(value=shots[0]["description"]),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value="Ready to generate shot 1."),
            gr.update(value=proj_seed),
        )
    to_keyframes_btn.click(
        on_start_keyframes,
        inputs=[project, shots_df, proj_seed_box],
        outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status, proj_seed_box]
    )

    def on_generate_img(p, idx, current_prompt, i2i_strength_val, i2i_steps_val, guidance_val, seconds_forward_val, aggressive_val):
        if p is None:
            raise gr.Error("No project.")
        shots = p["shots"]
        idx = int(idx)
        if idx < 0 or idx >= len(shots):
            raise gr.Error("Invalid shot index.")
        shots[idx]["description"] = current_prompt  # allow tweaking
        img_path = generate_keyframe_image(
            p["meta"]["id"],
            idx,
            shots,
            t2i_steps=18,
            i2i_steps=int(i2i_steps_val),
            i2i_strength=float(i2i_strength_val),
            guidance_scale=float(guidance_val),
            width=640,
            height=640,
            seconds_forward=int(seconds_forward_val),
            aggressive=bool(aggressive_val)
        )
        prev_path = shots[idx - 1]["image_path"] if idx > 0 else None
        return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
    gen_btn.click(
        on_generate_img,
        inputs=[project, current_idx, prompt_box, img_strength, img_steps, guidance, temporal_secs, aggressive_follow],
        outputs=[out_img, prev_img, kf_status]
    )

    def on_approve_next(p, idx, current_prompt, latest_img_path):
        if p is None:
            raise gr.Error("No project.")
        shots = p["shots"]
        i = int(idx)
        if i < 0 or i >= len(shots):
            raise gr.Error("Invalid shot index.")
        if not latest_img_path:
            raise gr.Error("Generate an image first.")
        # commit (the image is copied out of Gradio's temp cache into the project)
        shots[i]["description"] = current_prompt
        shots[i]["image_path"] = _persist_approved_image(p["meta"]["id"], shots[i]["id"], latest_img_path)
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        # next
        if i + 1 < len(shots):
            ni = i + 1
            return (
                p,
                ni,
                gr.update(value=_shot_info_md(shots[ni], p["meta"].get("seed"))),
                gr.update(value=shots[ni]["description"]),
                gr.update(value=shots[ni - 1]["image_path"]),
                gr.update(value=None),
                gr.update(value=f"Approved shot {shots[i]['id']}. On to shot {shots[ni]['id']}."),
            )
        return (
            p,
            i,
            gr.update(value="**All keyframes approved.** Proceed to Videos tab."),
            gr.update(value=""),
            gr.update(value=shots[i]["image_path"]),
            gr.update(value=None),
            gr.update(value="All shots approved βœ…"),
        )
    approve_next_btn.click(
        on_approve_next,
        inputs=[project, current_idx, prompt_box, out_img],
        outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status]
    )

    # ---- Videos tab handlers (backend + robust ffmpeg)
    def on_build_pairs(p, fps, hold, xfade):
        if p is None:
            raise gr.Error("No project.")
        shots = p.get("shots", [])
        if len(shots) < 2:
            raise gr.Error("Need at least 2 approved images to build pair clips.")
        approved = sum(1 for s in shots if s.get("image_path"))
        if approved == 0:
            raise gr.Error("No approved images yet. Approve keyframes first.")
        if approved < 2:
            raise gr.Error("Need at least 2 approved images to build pair clips.")
        seed = p.get("meta", {}).get("seed", None)
        titles = " -> ".join([s.get("title") or f"Shot {s.get('id')}" for s in shots])
        context_prompt = f"Transition between consecutive storyboard frames. Sequence: {titles}"
        pair_paths = _build_all_pair_videos_backend(
            p["meta"]["id"], shots,
            endpoint=I2V_ENDPOINT,
            prompt=context_prompt,
            seed=seed
        )
        if not pair_paths:
            raise gr.Error("Could not create any pair clips (missing consecutive images).")
        return {"pair_clips": pair_paths, "final": None}
    build_pairs_btn.click(
        on_build_pairs,
        inputs=[project, v_fps, v_hold, v_xfade],
        outputs=[vd_table]
    )

    def on_build_final(p, fps):
        if p is None:
            raise gr.Error("No project.")
        pid = p["meta"]["id"]
        shots = p.get("shots", [])
        # Use the clips expected for the current consecutive shots, in storyboard order.
        # Globbing clips/pair_*.mp4 would also pick up stale clips from earlier storyboards.
        expected = (_pair_clip_path(pid, a["id"], b["id"]) for a, b in zip(shots, shots[1:]))
        pair_paths = [path for path in expected if os.path.exists(path)]
        if not pair_paths:
            raise gr.Error("No pair clips found. Click 'Build pair clips' first.")
        outp = _final_stitched_path(pid)
        _ffmpeg_safe_concat(pair_paths, outp, fps=int(fps), size=None)  # set size=(640,640) to force letterbox
        return {"pair_clips": pair_paths, "final": outp}
    build_final_btn.click(
        on_build_final,
        inputs=[project, v_fps],
        outputs=[vd_table]
    )

    def on_save(p):
        if p is None:
            raise gr.Error("No project in memory.")
        path = save_project(p)
        return gr.update(value=f"Saved to `{path}`")
    save_btn.click(on_save, inputs=[project], outputs=[sb_status])

    def on_load(file_obj):
        if not file_obj:
            raise gr.Error("Choose a project.json file to load first.")
        p = load_project_file(file_obj)
        seed_val = p.get("meta", {}).get("seed", None)
        return (
            p,
            gr.update(value=f"Loaded project `{p['meta']['name']}` (id: `{p['meta']['id']}`)"),
            shots_to_df(p.get("shots", [])),
            gr.update(value=seed_val)
        )
    load_btn.click(on_load, inputs=[load_file], outputs=[project, sb_status, shots_df, proj_seed_box])
if __name__ == "__main__":
    # Fail fast with a clear error if FLUX isn't available. This raises RuntimeError when
    # no HF token is set; otherwise it loads and caches both FLUX pipelines before serving.
    _flux_healthcheck()
    demo.launch()