Spaces:
Running
on
Zero
Running
on
Zero
| import os, json, uuid, re | |
| from datetime import datetime | |
| import gradio as gr | |
| import spaces # ZeroGPU decorator | |
| import torch | |
| # ========================= | |
| # Storage helpers | |
| # ========================= | |
# Root folder holding one sub-folder per project. Defaults to "outputs" (relative to the
# working directory); set STORYBOARD_OUTPUT_DIR to point it elsewhere, e.g. persistent storage.
ROOT = os.getenv("STORYBOARD_OUTPUT_DIR", "outputs")
os.makedirs(ROOT, exist_ok=True)
def now_iso():
    """Return the current UTC time as an ISO-8601 string with second precision and a "Z" suffix.

    Example shape: ``2024-01-31T12:34:56Z``.
    """
    from datetime import timezone  # local import: the module header only imports `datetime`

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC time, then drop
    # tzinfo so isoformat() does not emit "+00:00" in front of the "Z" we append.
    return datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat() + "Z"
def new_id():
    """Return a short random identifier: the first 8 lowercase hex digits of a fresh UUID4."""
    # The canonical UUID string starts with an 8-hex-digit group, identical to .hex[:8].
    return str(uuid.uuid4())[:8]
def project_dir(pid):
    """Return the folder ``ROOT/<pid>``, creating it plus its ``keyframes`` and ``clips`` sub-folders.

    Idempotent: every directory is created with ``exist_ok=True``.
    """
    base = os.path.join(ROOT, pid)
    for folder in (base, os.path.join(base, "keyframes"), os.path.join(base, "clips")):
        os.makedirs(folder, exist_ok=True)
    return base
def save_project(proj):
    """Write *proj* as indented JSON to ``<project folder>/project.json`` and return that path.

    The folder is derived from ``proj["meta"]["id"]`` and created on demand via project_dir().
    """
    target = os.path.join(project_dir(proj["meta"]["id"]), "project.json")
    with open(target, "w") as fh:
        json.dump(proj, fh, indent=2)
    return target
def load_project_file(file_obj):
    """Load a ``project.json`` (as written by save_project) and ensure its folders exist.

    Args:
        file_obj: A filesystem path (``str`` / ``os.PathLike`` — what ``gr.File(type="filepath")``
            passes to handlers) or an object exposing the path as ``.name`` (tempfile wrapper).

    Returns:
        The parsed project dict.

    Raises:
        ValueError: if no file is given, the JSON is malformed, or ``meta.id`` is missing or not a
            safe directory name.
        OSError: if the file cannot be read.
    """
    if file_obj is None:
        raise ValueError("No project file provided.")
    path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
    with open(path, "r", encoding="utf-8") as f:
        proj = json.load(f)  # json.JSONDecodeError is a ValueError subclass
    meta = proj.get("meta") if isinstance(proj, dict) else None
    pid = meta.get("id") if isinstance(meta, dict) else None
    # The id comes from an uploaded file and is joined into a filesystem path below; only allow
    # plain names so values like "../x" or absolute paths cannot escape the output root.
    if not isinstance(pid, str) or not re.fullmatch(r"[A-Za-z0-9_-]+", pid):
        raise ValueError("Project file has a missing or invalid meta.id.")
    project_dir(pid)  # ensure dirs
    return proj
| # ========================= | |
| # LLM (ZeroGPU) — Storyboard generator (robust, two-pass + empty fallback) | |
| # ========================= | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Hugging Face model id used for storyboard generation (overridable via env).
STORYBOARD_MODEL = os.environ.get("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
# max_new_tokens for each LLM pass; generous so a multi-shot JSON array is not cut off.
HF_TASK_MAX_TOKENS: int = int(os.environ.get("HF_TASK_MAX_TOKENS", "1200"))

# Cache populated on first use by _lazy_model_tok().
_tokenizer = None
_model = None
def _lazy_model_tok():
    """Return ``(model, tokenizer)`` for STORYBOARD_MODEL, loading both on the first call.

    The loaded objects are kept in the module globals ``_model`` / ``_tokenizer`` and reused
    on later calls; if either is still unset, both are (re)loaded.
    """
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            STORYBOARD_MODEL,
            device_map="auto",
            dtype="auto",
            trust_remote_code=True,
        )
        # Fall back to EOS as padding so generate() has a pad id and does not warn.
        missing_pad = _tokenizer.pad_token_id is None
        if missing_pad and _tokenizer.eos_token_id is not None:
            _tokenizer.pad_token_id = _tokenizer.eos_token_id
    return _model, _tokenizer
def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
    """Build the first-pass user prompt, which asks for the array wrapped in <JSON>...</JSON>.

    The per-shot schema pre-fills duration/video_length with *default_len*, fps with
    *default_fps*, and steps=30 / seed=null / negative="" for the model to copy.
    """
    schema_fields = [
        ' "id": <int starting at 1>,',
        ' "title": "Short title",',
        ' "description": "Visual description for keyframe generation",',
        f' "duration": {default_len},',
        f' "fps": {default_fps},',
        f' "video_length": {default_len},',
        ' "steps": 30,',
        ' "seed": null,',
        ' "negative": ""',
    ]
    schema = "{\n" + "\n".join(schema_fields) + "\n}"
    sections = [
        "Return ONLY a JSON array, enclosed between <JSON> and </JSON>.",
        f"Create a storyboard of {n_shots} shots for this idea:",
        "",
        f"'''{user_prompt}'''",
        "",
        "Each item schema:",
        schema,
        "",
        "Output:",
        "<JSON>",
        "[ { ... }, ... ]",
        "</JSON>",
    ]
    return "\n".join(sections) + "\n"
def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
    """Build the stricter second-pass prompt: a bare JSON array, no tags and no extra text.

    Uses the same pre-filled per-shot schema as the first pass, with a shorter description hint.
    """
    lines = [
        "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.",
        f"Storyboard: {n_shots} shots for:",
        f"'''{user_prompt}'''",
        "Item schema:",
        "{",
        ' "id": <int starting at 1>,',
        ' "title": "Short title",',
        ' "description": "Visual description",',
        f' "duration": {default_len},',
        f' "fps": {default_fps},',
        f' "video_length": {default_len},',
        ' "steps": 30,',
        ' "seed": null,',
        ' "negative": ""',
        "}",
    ]
    return "\n".join(lines) + "\n"
def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
    """Render a system + user exchange into one prompt string for *tok*.

    Uses the tokenizer's chat template (with the assistant generation prompt appended) when it
    can. Falls back to ``system_msg + "\\n\\n" + user_msg`` when the tokenizer has no
    ``apply_chat_template`` method, or when calling it raises ``ValueError``. That is what
    transformers raises when no chat template is configured, e.g. for a base model selected via
    STORYBOARD_MODEL (verify against the installed transformers version).
    """
    render = getattr(tok, "apply_chat_template", None)
    if render is not None:
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]
        try:
            return render(messages, tokenize=False, add_generation_prompt=True)
        except ValueError:
            pass  # no chat template available: use the plain-text format below
    return system_msg + "\n\n" + user_msg
def _generate_text(model, tok, prompt_text: str) -> str:
    """Greedy-decode a completion for *prompt_text* and return only the newly generated text.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tok: Matching tokenizer (callable, with ``decode``, ``eos_token_id``, ``pad_token_id``).
        prompt_text: Fully rendered prompt, e.g. the output of _apply_chat().

    Returns:
        The completion, whitespace-stripped, with a wrapping ``` / ```json fence removed.
    """
    inputs = tok(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]
    eos_id = tok.eos_token_id
    gen = model.generate(
        **inputs,
        max_new_tokens=HF_TASK_MAX_TOKENS,
        do_sample=False,  # greedy decoding; temperature is not used here, so it is not passed
        repetition_penalty=1.05,
        eos_token_id=eos_id,
        pad_token_id=tok.pad_token_id if tok.pad_token_id is not None else eos_id,
    )
    # Decode only the tokens generated after the prompt. Stripping the prompt as a string prefix
    # does not work with chat templates: skip_special_tokens removes template markers (e.g.
    # Qwen's <|im_start|>), so the decoded text no longer starts with prompt_text. The echoed
    # prompt, which itself contains "<JSON>" and "[", would then reach the JSON extraction.
    text = tok.decode(gen[0][prompt_len:], skip_special_tokens=True).strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE | re.DOTALL).strip()
    return text
def _balanced_array_end(text: str, start: int) -> int:
    """Return the index of the ``]`` that closes the ``[`` at *start*, or -1 if it never closes.

    Double-quoted string literals (with backslash escapes) are tracked so that brackets inside
    strings, e.g. a description mentioning "[laughs]", do not change the nesting depth.
    """
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
            if depth == 0:
                return i
    return -1


def _extract_json_array(text: str) -> str:
    """Pull the JSON array text out of raw LLM output; return "" when none is found.

    1. If ``<JSON>...</JSON>`` tags (case-insensitive) wrap non-empty content, return it stripped.
    2. Otherwise return the first bracket-balanced ``[...]`` span. When a ``[`` never closes
       (stray prose bracket or truncated output), scanning resumes from the next ``[``.
    """
    m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
    if m:
        inner = m.group(1).strip()
        if inner:
            return inner
    start = text.find("[")
    while start != -1:
        end = _balanced_array_end(text, start)
        if end != -1:
            return text[start:end + 1].strip()
        start = text.find("[", start + 1)
    return ""
def _coerce_int(value, default: int) -> int:
    """Best-effort int conversion of an LLM-supplied field; *default* if missing or non-numeric.

    Accepts ints, floats (truncated) and numeric strings such as "30" or "4.5".
    """
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):  # e.g. "abc", lists, nan, inf
        return default


def _normalize_shots(shots_raw, default_fps: int, default_len: int):
    """Convert parsed LLM shots into the app's shot dicts with every expected key present.

    Args:
        shots_raw: Iterable of parsed JSON items; entries that are not dicts are skipped.
        default_fps: Used when an item's "fps" is missing, null or not numeric.
        default_len: Used likewise for "duration" and "video_length".

    Returns:
        A list of shot dicts. Missing or null ids/titles default to the shot's 1-based position
        among the kept items; "keyframe_path" is always None.
    """
    norm = []
    for s in shots_raw:
        if not isinstance(s, dict):
            continue  # the model occasionally emits stray strings/numbers inside the array
        i = len(norm) + 1
        norm.append({
            "id": _coerce_int(s.get("id"), i),
            "title": s.get("title") or f"Shot {i}",
            "description": s.get("description") or "",
            "duration": _coerce_int(s.get("duration"), default_len),
            "fps": _coerce_int(s.get("fps"), default_fps),
            "video_length": _coerce_int(s.get("video_length"), default_len),
            "steps": _coerce_int(s.get("steps"), 30),
            "seed": s.get("seed", None),
            "negative": s.get("negative") or "",
            "keyframe_path": None,
        })
    return norm
try:
    import spaces as _spaces  # ZeroGPU helpers; only installed on Hugging Face Spaces

    _zero_gpu = _spaces.GPU
except ImportError:
    def _zero_gpu(fn):
        """Identity stand-in for ``spaces.GPU`` when the package is unavailable."""
        return fn


def _fallback_storyboard(user_prompt: str, default_fps: int, default_len: int):
    """Return a single placeholder shot, used when the LLM output cannot be turned into shots."""
    return [{
        "id": 1,
        "title": "Shot 1",
        "description": f"Fallback shot for: {user_prompt[:80]}",
        "duration": default_len,
        "fps": default_fps,
        "video_length": default_len,
        "steps": 30,
        "seed": None,
        "negative": "",
        "keyframe_path": None,
    }]


def _parse_shot_list(json_text: str):
    """Parse *json_text* into a non-empty list of shot dicts, or return None if unusable.

    Retries once after removing trailing commas (``[{...},]``). Accepts a bare array, a single
    shot object, or an object wrapping the array under "shots". Non-dict items are dropped.
    """
    if not json_text or not json_text.strip():
        return None
    try:
        data = json.loads(json_text)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        try:
            data = json.loads(re.sub(r",\s*([\]\}])", r"\1", json_text))
        except ValueError:
            return None
    if isinstance(data, dict):
        data = data.get("shots", [data])
    if not isinstance(data, list):
        return None
    shots = [s for s in data if isinstance(s, dict)]
    return shots or None


@_zero_gpu
def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
    """Ask the LLM for an *n_shots* storyboard of *user_prompt* and return normalized shot dicts.

    Decorated with ``spaces.GPU`` so a ZeroGPU Space attaches a GPU for the call.
    NOTE(review): the model is loaded lazily inside this call; if calls time out, consider
    ``spaces.GPU(duration=...)`` or loading the model at import time.

    Strategy:
      1. Pass 1 requests the array inside <JSON> tags.
      2. If that yields nothing parseable, pass 2 uses a stricter "array only" prompt, and as a
         last resort tries the text between the first "[" and the last "]".
      3. If both passes fail, a single fallback shot is returned so the UI never crashes.
    """
    model, tok = _lazy_model_tok()
    system = "You are a film previsualization assistant. Output must be valid JSON."

    # PASS 1: with <JSON> tags
    p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
                     _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
    out1 = _generate_text(model, tok, p1)
    print(f"[DEBUG] LLM raw out1 (first 240 chars): {out1[:240]}")
    shots_raw = _parse_shot_list(_extract_json_array(out1))

    # PASS 2: strict array prompt, also used when pass 1 produced non-empty but invalid JSON
    if shots_raw is None:
        p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
                         _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
        out2 = _generate_text(model, tok, p2)
        print(f"[DEBUG] LLM raw out2 (first 240 chars): {out2[:240]}")
        shots_raw = _parse_shot_list(_extract_json_array(out2))
        if shots_raw is None:
            start, end = out2.find("["), out2.rfind("]")
            if start != -1 and end > start:
                shots_raw = _parse_shot_list(out2[start:end + 1])

    if shots_raw is None:
        print("⚠️ LLM returned empty or unparsable JSON. Using fallback storyboard.")
        return _fallback_storyboard(user_prompt, default_fps, default_len)
    return _normalize_shots(shots_raw, default_fps, default_len)
| # ========================= | |
| # Gradio UI | |
| # ========================= | |
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Storyboard → Keyframes → Videos → Export")
    gr.Markdown("**Step 2**: Real storyboard generation on **ZeroGPU**. Next steps will add keyframes (img2img) and your Modal videos.")

    # Global state
    project = gr.State(None)  # dict with meta/shots/clips
    current_tab = gr.State("Storyboard")  # NOTE(review): no handler reads or writes this yet

    # Header row
    with gr.Row():
        with gr.Column(scale=2):
            proj_name = gr.Textbox(label="Project name", placeholder="e.g., Desert Chase")
        with gr.Column(scale=1):
            new_btn = gr.Button("New Project", variant="primary")
        with gr.Column(scale=1):
            save_btn = gr.Button("Save Project")
        with gr.Column(scale=1):
            # type="filepath" hands the upload to the handler as a path string.
            load_file = gr.File(label="Load Project (project.json)", file_count="single", type="filepath")
            load_btn = gr.Button("Load")

    # Tabs
    with gr.Tabs():
        with gr.Tab("Storyboard"):
            gr.Markdown("### 1) Storyboard")
            sb_prompt = gr.Textbox(label="High-level prompt", lines=4, placeholder="Describe the story you want to create…")
            with gr.Row():
                sb_target_shots = gr.Slider(1, 12, value=3, step=1, label="Target # of shots")
                sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
                sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
            propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
            shots_json = gr.JSON(label="Storyboard JSON (editable in next step)")
            confirm_btn = gr.Button("Confirm Storyboard ✓", variant="primary")
            sb_status = gr.Markdown("")
        with gr.Tab("Keyframes"):
            gr.Markdown("### 2) Keyframes (coming next)")
            kf_table = gr.JSON(label="Shots (read-only for now)")
            to_videos_btn = gr.Button("Continue to Videos →", interactive=False)
        with gr.Tab("Videos"):
            gr.Markdown("### 3) Videos (coming next)")
            vd_table = gr.JSON(label="Planned clip edges (read-only for now)")
            to_export_btn = gr.Button("Continue to Export →", interactive=False)
        with gr.Tab("Export"):
            gr.Markdown("### 4) Export (coming next)")
            export_info = gr.Markdown("Nothing to export yet.")

    # -------- Handlers --------
    def _with_updates(p, **fields):
        """Return a copy of project *p* with *fields* replaced and ``meta.updated`` refreshed.

        The nested ``meta`` dict is copied as well; a plain ``dict(p)`` would share it, and
        updating the timestamp would then mutate the previous state object in place.
        """
        return {**p, **fields, "meta": {**p["meta"], "updated": now_iso()}}

    def on_new(name):
        """Create and save an empty project; clear the views left over from any previous one."""
        pid = new_id()
        name = (name or "").strip() or f"Project-{pid}"  # default name reuses the project id
        ts = now_iso()
        p = {
            "meta": {"id": pid, "name": name, "created": ts, "updated": ts},
            "shots": [],
            "clips": [],
        }
        save_project(p)
        return (
            p,
            gr.update(value=f"**New project created** `{name}` (id: `{pid}`)"),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(interactive=False),
        )

    new_btn.click(
        on_new,
        inputs=[proj_name],
        outputs=[project, sb_status, shots_json, kf_table, vd_table, to_videos_btn],
    )

    def on_propose(p, prompt, target_shots, fps, vlen):
        """Generate shots for the current project with the LLM and save them.

        A new storyboard invalidates any earlier confirmation, so stored clips are cleared and
        the Keyframes/Videos views reset until the storyboard is confirmed again.
        """
        if p is None:
            raise gr.Error("Create a project first (New Project).")
        prompt = str(prompt or "").strip()
        if not prompt:
            raise gr.Error("Please enter a high-level prompt.")
        shots = generate_storyboard_with_llm(prompt, int(target_shots), int(fps), int(vlen))
        p = _with_updates(p, shots=shots, clips=[])
        save_project(p)
        return (
            p,
            shots,
            gr.update(value="Storyboard generated by LLM (ZeroGPU)."),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(interactive=False),
        )

    propose_btn.click(
        on_propose,
        inputs=[project, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
        outputs=[project, shots_json, sb_status, kf_table, vd_table, to_videos_btn],
    )

    def on_confirm(p):
        """Freeze the storyboard: derive one transition clip per consecutive shot pair and save."""
        if p is None or not p.get("shots"):
            raise gr.Error("No storyboard yet.")
        ids = [shot["id"] for shot in p["shots"]]
        edges = [
            {"from": a, "to": b, "prompt": f"Transition from shot {a} to {b}"}
            for a, b in zip(ids, ids[1:])
        ]
        p = _with_updates(p, clips=edges)
        save_project(p)
        return (
            p,
            gr.update(value=p["shots"]),
            gr.update(value=p["clips"]),
            gr.update(value="Storyboard confirmed. Proceed to Keyframes."),
            gr.update(interactive=True),
        )

    confirm_btn.click(
        on_confirm,
        inputs=[project],
        outputs=[project, kf_table, vd_table, sb_status, to_videos_btn],
    )

    def on_save(p):
        """Write the in-memory project to disk and report the path."""
        if p is None:
            raise gr.Error("No project in memory.")
        path = save_project(p)
        return gr.update(value=f"Saved to `{path}`")

    save_btn.click(on_save, inputs=[project], outputs=[sb_status])

    def on_load(file_obj):
        """Load a project.json into state and refresh every view that shows project data."""
        if file_obj is None:
            raise gr.Error("Choose a project.json file to load first.")
        try:
            p = load_project_file(file_obj)
        except Exception as err:  # handler boundary: surface any load failure in the UI
            raise gr.Error(f"Could not load project: {err}") from err
        shots = p.get("shots") or []
        clips = p.get("clips") or []
        p = {**p, "shots": shots, "clips": clips}  # guarantee keys the other handlers index
        meta = p.get("meta") or {}
        return (
            p,
            gr.update(value=f"Loaded project `{meta.get('name', '')}` (id: `{meta.get('id', '')}`)"),
            gr.update(value=shots),
            gr.update(value=shots),
            gr.update(value=clips),
            gr.update(interactive=bool(shots)),
        )

    load_btn.click(
        on_load,
        inputs=[load_file],
        outputs=[project, sb_status, shots_json, kf_table, vd_table, to_videos_btn],
    )
def main():
    """Start the Gradio app defined above as ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()