# app.py - FLUX-only with temporal chaining + Aggressive follow + Video stitching (lazy MoviePy)
import os, json, uuid, re, sys, subprocess
from datetime import datetime

import gradio as gr
import spaces
import torch
from PIL import Image
import pandas as pd
# =========================
# Storage helpers
# =========================
ROOT = "outputs"
os.makedirs(ROOT, exist_ok=True)

def now_iso(): return datetime.utcnow().replace(microsecond=0).isoformat() + "Z"
def new_id(): return uuid.uuid4().hex[:8]

def project_dir(pid):
    path = os.path.join(ROOT, pid)
    os.makedirs(path, exist_ok=True)
    os.makedirs(os.path.join(path, "keyframes"), exist_ok=True)
    os.makedirs(os.path.join(path, "clips"), exist_ok=True)
    return path
def save_project(proj):
    pid = proj["meta"]["id"]
    path = os.path.join(project_dir(pid), "project.json")
    with open(path, "w") as f:
        json.dump(proj, f, indent=2)
    return path
def load_project_file(file_obj):
    # gr.File(type="filepath") passes a plain path string; older Gradio passed
    # a tempfile wrapper with a .name attribute. Accept both.
    path = file_obj if isinstance(file_obj, str) else file_obj.name
    with open(path, "r") as f:
        proj = json.load(f)
    project_dir(proj["meta"]["id"])
    return proj
def ensure_project(p, suggested_name="Project"):
    if p is not None:
        return p
    pid = new_id()
    name = f"{suggested_name}-{pid[:4]}"
    proj = {
        "meta": {"id": pid, "name": name, "created": now_iso(), "updated": now_iso()},
        "shots": [],  # id, title, description, duration, fps, steps, seed, negative, image_path
        "clips": [],
    }
    save_project(proj)
    return proj
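
# On disk, each project created by the helpers above lands in a layout like
# this (sketch inferred from project_dir/save_project, not an exhaustive listing):
#
#   outputs/<pid>/project.json    # full project state: meta, shots, clips
#   outputs/<pid>/keyframes/      # shot_NN.png approved stills
#   outputs/<pid>/clips/          # pair_*.mp4 dissolves + final_stitched.mp4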
# =========================
# LLM -> Storyboard generator (ZeroGPU friendly)
# =========================
from transformers import AutoTokenizer, AutoModelForCausalLM

STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "1200"))

_tokenizer = None
_model = None

def _lazy_model_tok():
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _model, _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    _model = AutoModelForCausalLM.from_pretrained(
        STORYBOARD_MODEL, device_map="auto", torch_dtype=dtype,
        trust_remote_code=True, use_safetensors=True
    )
    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token_id = _tokenizer.eos_token_id
    return _model, _tokenizer
def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
    return (
        "You are a cinematographer and storyboard artist. "
        "Break the idea into DISTINCT, DETAILED shots with concrete visual info: objects, camera placement/angle, subject position, lighting, background.\n\n"
        "Return ONLY a JSON array enclosed between <JSON> and </JSON>.\n"
        f"Create {n_shots} shots for:\n'''{user_prompt}'''\n\n"
        "Item schema:\n"
        "{\n"
        '  "id": <int starting at 1>,\n'
        '  "title": "Short shot title",\n'
        '  "description": "Highly specific visual description (camera, framing, time of day, subject position, lighting, mood, background).",\n'
        f'  "duration": {default_len},\n'
        f'  "fps": {default_fps},\n'
        '  "steps": 30,\n'
        '  "seed": null,\n'
        '  "negative": ""\n'
        "}\n\n"
        "Output must start with <JSON> and end with </JSON>.\n"
    )
def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
    return (
        "Reply ONLY with a JSON array starting with '[' and ending with ']'.\n"
        f"Storyboard: {n_shots} shots for:\n'''{user_prompt}'''\n"
        "Item schema:\n"
        "{\n"
        '  "id": <int starting at 1>,\n'
        '  "title": "Short title",\n'
        '  "description": "Visual description",\n'
        f'  "duration": {default_len},\n'
        f'  "fps": {default_fps},\n'
        '  "steps": 30,\n'
        '  "seed": null,\n'
        '  "negative": ""\n'
        "}\n"
    )
def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
    if hasattr(tok, "apply_chat_template"):
        return tok.apply_chat_template(
            [{"role": "system", "content": system_msg},
             {"role": "user", "content": user_msg}],
            tokenize=False, add_generation_prompt=True
        )
    return system_msg + "\n\n" + user_msg
def _generate_text(model, tok, prompt_text: str) -> str:
    inputs = tok(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    eos_id = tok.eos_token_id or tok.pad_token_id
    gen = model.generate(
        **inputs, max_new_tokens=HF_TASK_MAX_TOKENS,
        do_sample=False,  # greedy decoding; sampling params like temperature are ignored here
        repetition_penalty=1.05, eos_token_id=eos_id, pad_token_id=eos_id
    )
    # Decode only the continuation, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    continuation_ids = gen[0][prompt_len:]
    text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
    # Strip markdown code fences the model may wrap around the JSON.
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.I | re.S).strip()
    return text
def _extract_json_array(text: str) -> str:
    m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.S | re.I)
    if m and m.group(1).strip():
        return m.group(1).strip()
    # Fall back to string-aware bracket matching from the first '['.
    start = text.find("[")
    if start == -1: return ""
    depth = 0; in_str = False; prev = ""
    for i in range(start, len(text)):
        ch = text[i]
        if ch == '"' and prev != '\\': in_str = not in_str
        if not in_str:
            if ch == "[": depth += 1
            elif ch == "]":
                depth -= 1
                if depth == 0: return text[start:i+1].strip()
        prev = ch
    return ""
def _normalize_shots(shots_raw, default_fps: int, default_len: int):
    norm = []
    for i, s in enumerate(shots_raw, start=1):
        # Tolerate missing or null numeric fields from the LLM.
        norm.append({
            "id": int(s.get("id") or i),
            "title": s.get("title", f"Shot {i}"),
            "description": s.get("description", ""),
            "duration": int(s.get("duration") or default_len),
            "fps": int(s.get("fps") or default_fps),
            "steps": int(s.get("steps") or 30),
            "seed": s.get("seed", None),
            "negative": s.get("negative", ""),
            "image_path": s.get("image_path", None)
        })
    return norm
@spaces.GPU  # run on ZeroGPU hardware (assumed intent of the otherwise-unused `import spaces`)
def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
    model, tok = _lazy_model_tok()
    system = "You are a film previsualization assistant. Output must be valid JSON."
    # Attempt 1: JSON wrapped in <JSON> tags.
    p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
                     _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
    out1 = _generate_text(model, tok, p1)
    json_text = _extract_json_array(out1)
    # Attempt 2: bare JSON array, then a crude bracket slice.
    if not json_text:
        p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
                         _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
        out2 = _generate_text(model, tok, p2)
        json_text = _extract_json_array(out2)
        if not json_text and "[" in out2 and "]" in out2:
            start, end = out2.find("["), out2.rfind("]")
            if start != -1 and end > start: json_text = out2[start:end+1].strip()
    # Last resort: placeholder shots so the UI stays usable.
    if not json_text:
        return [{
            "id": i, "title": f"Shot {i}",
            "description": f"Placeholder for: {user_prompt[:80]}",
            "duration": default_len, "fps": default_fps,
            "steps": 30, "seed": None, "negative": "", "image_path": None
        } for i in range(1, int(n_shots)+1)]
    try:
        shots_raw = json.loads(json_text)
    except Exception:
        # Retry after stripping trailing commas, a common LLM JSON error.
        shots_raw = json.loads(re.sub(r",\s*([\]\}])", r"\1", json_text))
    return _normalize_shots(shots_raw, default_fps, default_len)
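
# Hypothetical call, to show the contract the UI relies on:
#   generate_storyboard_with_llm("A fox crosses a frozen lake at dawn", 3, 24, 4)
# returns 3 dicts carrying exactly the SHOT_COLUMNS keys (id/title/description/
# duration/fps/steps/seed/negative/image_path), falling back through
# tagged JSON -> bare array -> bracket slice -> placeholder shots.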
# =========================
# IMAGE GEN - FLUX-only + Temporal chaining
# =========================
USE_CUDA = torch.cuda.is_available()
DTYPE = torch.float16 if USE_CUDA else torch.float32

# FLUX.1-schnell is openly licensed; gated variants such as FLUX.1-dev
# require accepting the model terms plus a READ token.
FLUX_MODEL = os.getenv("FLUX_MODEL", "black-forest-labs/FLUX.1-schnell")
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

_flux_t2i = None
_flux_i2i = None
def _lazy_flux_pipes():
    from diffusers import FluxPipeline, FluxImg2ImgPipeline
    global _flux_t2i, _flux_i2i
    if _flux_t2i is not None and _flux_i2i is not None:
        return _flux_t2i, _flux_i2i
    _flux_t2i = FluxPipeline.from_pretrained(
        FLUX_MODEL, torch_dtype=DTYPE, use_safetensors=True, token=HF_TOKEN
    )
    if USE_CUDA: _flux_t2i = _flux_t2i.to("cuda")
    _flux_i2i = FluxImg2ImgPipeline.from_pretrained(
        FLUX_MODEL, torch_dtype=DTYPE, use_safetensors=True, token=HF_TOKEN
    )
    if USE_CUDA: _flux_i2i = _flux_i2i.to("cuda")
    return _flux_t2i, _flux_i2i
def _flux_healthcheck():
    # Only gated FLUX variants need a token; don't block startup on schnell.
    if not HF_TOKEN and "schnell" not in FLUX_MODEL:
        raise RuntimeError("HF_TOKEN is not set. Accept the model terms on HF and provide a READ token.")
    _lazy_flux_pipes()
def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
    pdir = project_dir(pid)
    out = os.path.join(pdir, "keyframes", f"shot_{shot_id:02d}.png")
    img.save(out)
    return out
def _compose_temporal_prompt(shots: list, idx: int, seconds_forward: int = 5):
    curr = shots[idx]
    curr_desc = (curr.get("description") or "").strip()
    curr_neg = (curr.get("negative") or "").strip()
    if idx == 0: return curr_desc, curr_neg
    prev_desc = (shots[idx-1].get("description") or "").strip()
    composed = (
        f"Continue the same scene {seconds_forward} seconds later.\n"
        f'PRIORITIZE this new moment & composition: "{curr_desc}".\n'
        "Keep continuity ONLY for subject identity, lighting palette, time of day, environment style.\n"
        f'Previous frame (context only, do not copy its framing): "{prev_desc}".\n'
        f"Avoid replicating the previous composition; allow camera move / subject reposition consistent with {seconds_forward} seconds of progression."
    ).strip()
    negative = (curr_neg + "; identical composition as previous; exact same framing; rigid pose repeat; freeze frame; "
                "hard scene reset; different subject identity; wildly different art style; unrelated background").strip("; ")
    return composed, negative
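
# Sketch of the composed prompt for shot 2 of a two-shot board (hypothetical
# descriptions; only the shape matters):
#   shots = [{"description": "Wide shot, fox on a frozen lake", ...},
#            {"description": "Close-up, fox sniffing the ice", ...}]
#   _compose_temporal_prompt(shots, 1, seconds_forward=5) returns roughly
#   ('Continue the same scene 5 seconds later.\nPRIORITIZE this new moment
#     & composition: "Close-up, fox sniffing the ice". ...',
#    '...; identical composition as previous; exact same framing; ...')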
@spaces.GPU  # ZeroGPU: diffusion runs in a GPU worker (assumed intent of `import spaces`)
def generate_keyframe_image(
    pid: str, shot_idx: int, shots: list,
    t2i_steps: int = 18, i2i_steps: int = 22, i2i_strength: float = 0.90,
    guidance_scale: float = 3.4, width: int = 640, height: int = 640,
    seconds_forward: int = 5, aggressive: bool = False
):
    try:
        t2i, i2i = _lazy_flux_pipes()
    except Exception as e:
        raise gr.Error(f"FLUX failed to load: {e}")
    prompt, negative = _compose_temporal_prompt(shots, shot_idx, seconds_forward=seconds_forward)
    seed = shots[shot_idx].get("seed", None)
    device = "cuda" if USE_CUDA else "cpu"
    gen = torch.Generator(device)
    if isinstance(seed, int): gen = gen.manual_seed(int(seed))
    width = max(256, min(1024, int(width)))
    height = max(256, min(1024, int(height)))
    prev_path = shots[shot_idx - 1].get("image_path") if shot_idx > 0 else None
    use_prev = bool(shot_idx > 0 and prev_path and os.path.exists(prev_path))
    if aggressive:
        # "Aggressive follow": push img2img toward more change, not more consistency.
        i2i_strength = min(0.98, max(i2i_strength, 0.92))
        guidance_scale = max(guidance_scale, 3.6)
        i2i_steps = max(i2i_steps, 24)
    # Note: Flux pipelines only honor negative_prompt on recent diffusers
    # releases (and only with true CFG); on FLUX.1-schnell it is best-effort.
    if not use_prev:
        out = t2i(
            prompt=prompt, negative_prompt=(negative or None),
            num_inference_steps=int(max(10, t2i_steps)),
            guidance_scale=float(max(2.4, guidance_scale)),
            generator=gen, width=width, height=height
        ).images[0]
    else:
        init_image = Image.open(prev_path).convert("RGB")
        out = i2i(
            prompt=prompt, negative_prompt=(negative or None),
            image=init_image, strength=float(min(max(i2i_strength, 0.70), 0.98)),
            num_inference_steps=int(max(14, i2i_steps)),
            guidance_scale=float(max(2.4, guidance_scale)), generator=gen
        ).images[0]
    saved = _save_keyframe(pid, int(shots[shot_idx]["id"]), out)
    return saved
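
# How the chain behaves (a reading of the branches above):
#   shot 0   -> pure text-to-image from its own description.
#   shot N>0 -> img2img from the previously APPROVED frame; strength ~0.90
#     re-runs most of the denoising schedule, so composition can move while
#     identity/lighting tend to persist. Lower strength = more consistency,
#     and the "aggressive" flag pushes strength/steps/guidance toward change.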
# =========================
# MoviePy lazy install/import
# =========================
def _ensure_moviepy():
    """
    Import MoviePy lazily. If unavailable, try a best-effort pip install.
    If that still fails, raise a clear Gradio error telling the user to rebuild.
    Also wires up the bundled ffmpeg from imageio-ffmpeg.
    """
    try:
        from moviepy.editor import ImageClip, CompositeVideoClip, concatenate_videoclips
        from moviepy.video.io.VideoFileClip import VideoFileClip
        return ImageClip, CompositeVideoClip, concatenate_videoclips, VideoFileClip
    except Exception:
        pass  # will try to install below
    # Try to install at runtime (some Spaces block this).
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                               "moviepy==1.0.3", "imageio>=2.34.0", "imageio-ffmpeg>=0.4.9"])
        # Point MoviePy to a known-good ffmpeg.
        try:
            import imageio_ffmpeg
            os.environ["IMAGEIO_FFMPEG_EXE"] = imageio_ffmpeg.get_ffmpeg_exe()
        except Exception:
            pass
        # Try importing again.
        from moviepy.editor import ImageClip, CompositeVideoClip, concatenate_videoclips
        from moviepy.video.io.VideoFileClip import VideoFileClip
        return ImageClip, CompositeVideoClip, concatenate_videoclips, VideoFileClip
    except Exception as e:
        # Final, friendly failure with next steps.
        raise gr.Error(
            "MoviePy is not available. Add `moviepy==1.0.3`, `imageio>=2.34.0`, "
            "`imageio-ffmpeg>=0.4.9` to requirements.txt and restart/rebuild the Space. "
            f"(Runtime install failed with: {type(e).__name__}: {e})"
        )
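
# The durable fix is declaring the same pins in requirements.txt, so the
# runtime-install path above never runs:
#
#   moviepy==1.0.3
#   imageio>=2.34.0
#   imageio-ffmpeg>=0.4.9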
# =========================
# Video stitching (pairwise dissolve + final concat)
# =========================
def _pair_clip_path(pid: str, i: int, j: int) -> str:
    return os.path.join(project_dir(pid), "clips", f"pair_{i:02d}_to_{j:02d}.mp4")

def _final_stitched_path(pid: str) -> str:
    return os.path.join(project_dir(pid), "clips", "final_stitched.mp4")

def _image_size(path: str):
    with Image.open(path) as im:
        return im.width, im.height
def _build_pair_clip(img_a: str, img_b: str, out_path: str, fps: int = 24, hold: float = 0.5, crossfade: float = 0.7, resize_to=None):
    ImageClip, CompositeVideoClip, concatenate_videoclips, VideoFileClip = _ensure_moviepy()
    ca = ImageClip(img_a).set_duration(hold + crossfade)
    cb = ImageClip(img_b).set_duration(hold + crossfade).set_start(hold)
    if resize_to:
        ca = ca.resize(newsize=resize_to)
        cb = cb.resize(newsize=resize_to)
    ca_x = ca.crossfadeout(crossfade)
    cb_x = cb.crossfadein(crossfade)
    total = hold + crossfade + hold
    comp = CompositeVideoClip([ca_x, cb_x]).set_duration(total)
    comp.write_videofile(out_path, fps=fps, codec="libx264", audio=False, preset="medium",
                         threads=os.cpu_count() or 2, verbose=False, logger=None)
    comp.close(); ca.close(); cb.close()
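
# Timeline of one pair clip with the defaults (hold=0.5, crossfade=0.7):
#   0.0         0.5               1.2         1.7   (total = hold*2 + crossfade)
#    |-- A solo --|-- A/B dissolve --|-- B solo --|
# A spans 0.0-1.2 fading out over its last 0.7 s; B starts at 0.5 fading in.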
def _build_all_pair_clips(pid: str, shots: list, fps: int = 24, hold: float = 0.5, crossfade: float = 0.7, force_size=None):
    paths = []
    base_size = None
    if not force_size:
        # Use the first available keyframe's size so all pair clips match.
        for s in shots:
            p = s.get("image_path")
            if p and os.path.exists(p):
                base_size = _image_size(p)
                break
    size = force_size or base_size
    for i in range(len(shots)-1):
        a = shots[i].get("image_path")
        b = shots[i+1].get("image_path")
        if not (a and b and os.path.exists(a) and os.path.exists(b)): continue
        outp = _pair_clip_path(pid, shots[i]["id"], shots[i+1]["id"])
        _build_pair_clip(a, b, outp, fps=fps, hold=hold, crossfade=crossfade, resize_to=size)
        paths.append(outp)
    return paths
def _build_final_stitched_from_pairs(pair_paths: list, out_path: str, fps: int = 24):
    _, _, concatenate_videoclips, VideoFileClip = _ensure_moviepy()
    if not pair_paths: raise RuntimeError("No pair clips to stitch.")
    clips = [VideoFileClip(p) for p in pair_paths if os.path.exists(p)]
    if not clips: raise RuntimeError("No readable pair clips on disk.")
    final = concatenate_videoclips(clips, method="compose")
    final.write_videofile(out_path, fps=fps, codec="libx264", audio=False, preset="medium",
                          threads=os.cpu_count() or 2, verbose=False, logger=None)
    final.close()
    for c in clips: c.close()
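
# method="compose" lets concatenate_videoclips place differently sized pair
# clips on a common canvas; the default "chain" method assumes uniform frame
# sizes and would misrender if force_size/base_size ever diverge.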
# =========================
# Shots <-> DataFrame utils
# =========================
SHOT_COLUMNS = ["id", "title", "description", "duration", "fps", "steps", "seed", "negative", "image_path"]

def shots_to_df(shots: list) -> pd.DataFrame:
    rows = [{k: s.get(k, None) for k in SHOT_COLUMNS} for s in shots]
    return pd.DataFrame(rows, columns=SHOT_COLUMNS)

def df_to_shots(df: pd.DataFrame) -> list:
    out = []
    for _, row in df.iterrows():
        if pd.isna(row["id"]):  # skip blank rows from the dynamic Dataframe editor
            continue
        out.append({
            "id": int(row["id"]),
            "title": (row["title"] or f"Shot {int(row['id'])}"),
            "description": row["description"] or "",
            "duration": int(row["duration"]) if pd.notna(row["duration"]) else 4,
            "fps": int(row["fps"]) if pd.notna(row["fps"]) else 24,
            "steps": int(row["steps"]) if pd.notna(row["steps"]) else 30,
            "seed": (int(row["seed"]) if pd.notna(row["seed"]) else None),
            "negative": row["negative"] or "",
            "image_path": row["image_path"] if pd.notna(row["image_path"]) else None
        })
    return sorted(out, key=lambda x: x["id"])
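
# Round-trip note: df_to_shots(shots_to_df(shots)) should reproduce the shot
# list (both sides use SHOT_COLUMNS); NaN cells map to the same defaults the
# LLM schema uses (duration 4, fps 24, steps 30) and shots return sorted by id.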
# =========================
# Gradio UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Storyboard → Keyframes → Videos → Export")
    gr.Markdown(
        "Temporal chaining: each new shot is generated N seconds later from the previous approved frame, "
        "while the current shot description drives composition & action. **Model**: FLUX-only."
    )
    project = gr.State(None)
    current_idx = gr.State(0)

    with gr.Row():
        with gr.Column(scale=2):
            proj_name = gr.Textbox(label="Project name", placeholder="e.g., Desert Chase")
        with gr.Column(scale=1):
            new_btn = gr.Button("New Project", variant="primary")
        with gr.Column(scale=1):
            save_btn = gr.Button("Save Project")
        with gr.Column(scale=1):
            load_file = gr.File(label="Load Project (project.json)", file_count="single", type="filepath")
            load_btn = gr.Button("Load")
    sb_status = gr.Markdown("")

    with gr.Tabs():
        with gr.Tab("Storyboard"):
            gr.Markdown("### 1) Storyboard")
            sb_prompt = gr.Textbox(label="High-level prompt", lines=4, placeholder="Describe the story…")
            with gr.Row():
                sb_target_shots = gr.Slider(1, 12, value=3, step=1, label="Target # of shots")
                sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
                sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds/shot")
            propose_btn = gr.Button("Propose Storyboard (LLM)")
            shots_df = gr.Dataframe(
                headers=SHOT_COLUMNS,
                datatype=["number","str","str","number","number","number","number","str","str"],
                row_count=(1,"dynamic"), col_count=len(SHOT_COLUMNS),
                label="Edit shots (prompts & params)", wrap=True
            )
            save_edits_btn = gr.Button("Save Edits →", variant="primary", interactive=False)
            with gr.Row():
                proj_seed_box = gr.Number(label="Project Seed (locked across shots)", precision=0)
                to_keyframes_btn = gr.Button("Start Keyframes →", variant="secondary")
        with gr.Tab("Keyframes"):
            gr.Markdown("### 2) Keyframes")
            shot_info_md = gr.Markdown("")
            prompt_box = gr.Textbox(label="Shot description (editable)", lines=4)
            with gr.Row():
                gen_btn = gr.Button("Generate / Regenerate", variant="primary")
                approve_next_btn = gr.Button("Approve & Next →", variant="secondary")
            with gr.Row():
                img_strength = gr.Slider(0.50, 0.98, value=0.90, step=0.02, label="Change vs Consistency (img2img strength)")
                img_steps = gr.Slider(12, 28, value=22, step=1, label="Inference Steps (img2img)")
                guidance = gr.Slider(2.4, 4.0, value=3.4, step=0.1, label="Guidance Scale")
                temporal_secs = gr.Slider(1, 10, value=5, step=1, label="Temporal step (seconds later)")
                aggressive_follow = gr.Checkbox(value=False, label="Aggressive follow prompt (more change)")
            with gr.Row():
                prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
                out_img = gr.Image(label="Generated image", type="filepath")
            kf_status = gr.Markdown("")
        with gr.Tab("Videos"):
            gr.Markdown("### 3) Videos")
            with gr.Row():
                v_fps = gr.Slider(8, 60, value=24, step=1, label="FPS")
                v_hold = gr.Slider(0.0, 2.0, value=0.5, step=0.1, label="Hold per still (s)")
                v_xfade = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Crossfade (s)")
            with gr.Row():
                build_pairs_btn = gr.Button("Build pair clips (A→B, B→C, ...)", variant="primary")
                build_final_btn = gr.Button("Build final stitched video", variant="secondary")
            vd_table = gr.JSON(label="Rendered outputs (paths)")
        with gr.Tab("Export"):
            gr.Markdown("### 4) Export (coming next)")
            export_info = gr.Markdown("Nothing to export yet.")
    # ---------- Handlers ----------
    def on_new(name):
        p = ensure_project(None, suggested_name=(name or "Project"))
        return p, gr.update(value=f"**New project created** `{p['meta']['name']}` (id: `{p['meta']['id']}`)")
    new_btn.click(on_new, inputs=[proj_name], outputs=[project, sb_status])
    def on_propose(p, name, prompt, target_shots, fps, vlen):
        # Take the project name as a real input; reading proj_name.value here
        # would only ever see the component's initial (empty) value.
        p = ensure_project(p, suggested_name=(name or "Project"))
        if not str(prompt or "").strip():
            raise gr.Error("Please enter a high-level prompt.")
        shots = generate_storyboard_with_llm(str(prompt).strip(), int(target_shots), int(fps), int(vlen))
        p = dict(p); p["shots"] = shots; p["meta"]["updated"] = now_iso(); save_project(p)
        return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable)."), gr.update(interactive=True)
    propose_btn.click(on_propose,
        inputs=[project, proj_name, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
        outputs=[project, shots_df, sb_status, save_edits_btn]
    )
    def on_save_edits(p, df):
        if p is None: raise gr.Error("No project in memory.")
        if df is None: raise gr.Error("No storyboard table to save.")
        shots = df_to_shots(df)
        p = dict(p); p["shots"] = shots; p["meta"]["updated"] = now_iso(); save_project(p)
        return p, gr.update(value="Edits saved.")
    save_edits_btn.click(on_save_edits, inputs=[project, shots_df], outputs=[project, sb_status])
    def on_start_keyframes(p, df, proj_seed_override):
        if p is None: raise gr.Error("No project.")
        shots = df_to_shots(df)
        if not shots: raise gr.Error("Storyboard is empty.")
        # Seed resolution: explicit override > project meta > first shot seed > random.
        proj_seed = None
        if proj_seed_override is not None and str(proj_seed_override).strip() != "":
            proj_seed = int(proj_seed_override)
        if proj_seed is None: proj_seed = p.get("meta", {}).get("seed")
        if proj_seed is None:
            for s in shots:
                if isinstance(s.get("seed"), int): proj_seed = int(s["seed"]); break
        if proj_seed is None: proj_seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        # Lock the seed across all shots for identity consistency.
        for s in shots:
            if not isinstance(s.get("seed"), int): s["seed"] = proj_seed
        p = dict(p); p["shots"] = shots; p["meta"]["seed"] = proj_seed; p["meta"]["updated"] = now_iso(); save_project(p)
        idx = 0; prev_path = None
        info = (f"**Shot {shots[idx]['id']} - {shots[idx]['title']}** \n"
                f"Duration: {shots[idx]['duration']}s @ {shots[idx]['fps']} fps \n"
                f"Locked project seed: `{proj_seed}`")
        return p, 0, gr.update(value=info), gr.update(value=shots[idx]["description"]), gr.update(value=prev_path), gr.update(value=None), gr.update(value="Ready for shot 1."), gr.update(value=proj_seed)
    to_keyframes_btn.click(on_start_keyframes,
        inputs=[project, shots_df, proj_seed_box],
        outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status, proj_seed_box]
    )
    def on_generate_img(p, idx, current_prompt, i2i_strength_val, i2i_steps_val, guidance_val, seconds_forward_val, aggressive_val):
        if p is None: raise gr.Error("No project.")
        shots = p["shots"]
        if idx < 0 or idx >= len(shots): raise gr.Error("Invalid shot index.")
        shots[idx]["description"] = current_prompt
        img_path = generate_keyframe_image(
            p["meta"]["id"], int(idx), shots,
            t2i_steps=18, i2i_steps=int(i2i_steps_val),
            i2i_strength=float(i2i_strength_val),
            guidance_scale=float(guidance_val),
            width=640, height=640,
            seconds_forward=int(seconds_forward_val),
            aggressive=bool(aggressive_val)
        )
        prev_path = shots[idx-1]["image_path"] if idx > 0 else None
        return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
    gen_btn.click(on_generate_img,
        inputs=[project, current_idx, prompt_box, img_strength, img_steps, guidance, temporal_secs, aggressive_follow],
        outputs=[out_img, prev_img, kf_status]
    )
    def on_approve_next(p, idx, current_prompt, latest_img_path):
        if p is None: raise gr.Error("No project.")
        shots = p["shots"]; i = int(idx)
        if i < 0 or i >= len(shots): raise gr.Error("Invalid shot index.")
        if not latest_img_path: raise gr.Error("Generate an image first.")
        shots[i]["description"] = current_prompt
        shots[i]["image_path"] = latest_img_path
        p["shots"] = shots; p["meta"]["updated"] = now_iso(); save_project(p)
        if i + 1 < len(shots):
            ni = i + 1
            info = (f"**Shot {shots[ni]['id']} - {shots[ni]['title']}** \n"
                    f"Duration: {shots[ni]['duration']}s @ {shots[ni]['fps']} fps \n"
                    f"Locked project seed: `{p['meta'].get('seed')}`")
            prev_path = shots[ni-1]["image_path"]
            return p, ni, gr.update(value=info), gr.update(value=shots[ni]["description"]), gr.update(value=prev_path), gr.update(value=None), gr.update(value=f"Approved shot {shots[i]['id']}. On to shot {shots[ni]['id']}.")
        else:
            return p, i, gr.update(value="**All keyframes approved.** Proceed to Videos tab."), gr.update(value=""), gr.update(value=shots[i]["image_path"]), gr.update(value=None), gr.update(value="All shots approved ✅")
    approve_next_btn.click(on_approve_next,
        inputs=[project, current_idx, prompt_box, out_img],
        outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status]
    )
    # ---- Videos tab
    def on_build_pairs(p, fps, hold, xfade):
        if p is None: raise gr.Error("No project.")
        shots = p.get("shots", [])
        if len(shots) < 2: raise gr.Error("Need at least 2 approved images.")
        if not any(s.get("image_path") for s in shots): raise gr.Error("No approved images yet.")
        pair_paths = _build_all_pair_clips(
            p["meta"]["id"], shots,
            fps=int(fps), hold=float(hold), crossfade=float(xfade),
            force_size=None
        )
        if not pair_paths: raise gr.Error("No consecutive pairs with images found.")
        return {"pair_clips": pair_paths, "final": None}
    build_pairs_btn.click(on_build_pairs, inputs=[project, v_fps, v_hold, v_xfade], outputs=[vd_table])
    def on_build_final(p, fps):
        if p is None: raise gr.Error("No project.")
        pid = p["meta"]["id"]
        clips_dir = os.path.join(project_dir(pid), "clips")
        pair_paths = sorted([os.path.join(clips_dir, f) for f in os.listdir(clips_dir)
                             if f.startswith("pair_") and f.endswith(".mp4")])
        if not pair_paths: raise gr.Error("No pair clips found. Build pair clips first.")
        outp = _final_stitched_path(pid)
        _build_final_stitched_from_pairs(pair_paths, outp, fps=int(fps))
        return {"pair_clips": pair_paths, "final": outp}
    build_final_btn.click(on_build_final, inputs=[project, v_fps], outputs=[vd_table])
    # ---- Save / load
    def on_save(p):
        if p is None: raise gr.Error("No project in memory.")
        path = save_project(p)
        return gr.update(value=f"Saved to `{path}`")
    save_btn.click(on_save, inputs=[project], outputs=[sb_status])

    def on_load(file_obj):
        p = load_project_file(file_obj)
        seed_val = p.get("meta", {}).get("seed", None)
        return (p,
                gr.update(value=f"Loaded `{p['meta']['name']}` (id: `{p['meta']['id']}`)"),
                shots_to_df(p.get("shots", [])),
                gr.update(value=seed_val))
    load_btn.click(on_load, inputs=[load_file], outputs=[project, sb_status, shots_df, proj_seed_box])
if __name__ == "__main__":
    _flux_healthcheck()
    demo.launch()