Prompt-2-Video / app.py
Shalmoni's picture
Update app.py
3d81823 verified
raw
history blame
20.1 kB
import os, json, uuid, re
from datetime import datetime
import gradio as gr
import spaces # ZeroGPU decorator
import torch
from PIL import Image
# =========================
# Storage helpers
# =========================
# Root folder for all projects; each project gets ROOT/<id>/ with
# project.json, keyframes/ and clips/ (see project_dir). The path is relative
# to the process's working directory and is created at import time.
ROOT = "outputs"
os.makedirs(ROOT, exist_ok=True)
def now_iso():
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix.

    Seconds precision, e.g. ``"2024-05-01T12:34:56Z"``. An aware
    ``datetime.now(timezone.utc)`` is used because ``datetime.utcnow()`` is
    deprecated since Python 3.12. tzinfo is dropped before formatting so the
    string ends in ``Z`` rather than ``+00:00``.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat() + "Z"
def new_id():
    """Return a fresh 8-character lowercase hex id (the first 8 hex digits of a UUID4)."""
    token = uuid.uuid4().hex
    return token[:8]
def project_dir(pid):
    """Return ``ROOT/<pid>``, creating it plus its ``keyframes`` and ``clips`` sub-folders if missing."""
    base = os.path.join(ROOT, pid)
    os.makedirs(base, exist_ok=True)
    for sub in ("keyframes", "clips"):
        os.makedirs(os.path.join(base, sub), exist_ok=True)
    return base
def save_project(proj):
    """Write *proj* as indented JSON to ``<project dir>/project.json`` and return that path.

    The project directory is derived from ``proj["meta"]["id"]`` and is
    created if needed.
    """
    target = os.path.join(project_dir(proj["meta"]["id"]), "project.json")
    with open(target, "w") as fh:
        json.dump(proj, fh, indent=2)
    return target
def load_project_file(file_obj):
    """Load an uploaded ``project.json`` and make sure its folders exist.

    Args:
        file_obj: Either a filesystem path (``str`` / ``os.PathLike``), which is
            what ``gr.File(type="filepath")`` delivers, or a file wrapper that
            exposes its path as ``.name``.

    Returns:
        The parsed project dict. ``meta.id`` is normalized to ``str``.

    Raises:
        ValueError: if no file was given, the JSON is invalid, or ``meta.id``
            is not a plain folder name. The id is joined onto ROOT, so a value
            such as ``"../x"`` would otherwise escape the outputs directory.
        KeyError: if the JSON lacks ``meta`` / ``id``.
    """
    if file_obj is None:
        raise ValueError("No project file was provided.")
    path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
    with open(path, "r", encoding="utf-8") as f:
        proj = json.load(f)
    pid = str(proj["meta"]["id"])
    # Uploaded content is untrusted: only accept a single path component.
    if pid in ("", ".", "..") or os.path.basename(pid) != pid:
        raise ValueError(f"Invalid project id in file: {pid!r}")
    proj["meta"]["id"] = pid
    project_dir(pid)  # ensure dirs
    return proj
def ensure_project(p, suggested_name="Project"):
    """Return *p* if it is not None; otherwise create, save and return a new project.

    A new project has:
    - a fresh 8-hex-char id;
    - a name of the form ``"<suggested_name>-<first 4 id chars>"``;
    - identical ``created`` / ``updated`` timestamps;
    - empty ``shots`` / ``clips`` lists.

    It is written to disk immediately via ``save_project``.
    """
    if p is not None:
        return p
    pid = new_id()
    stamp = now_iso()  # single timestamp so created == updated for a fresh project
    proj = {
        "meta": {"id": pid, "name": f"{suggested_name}-{pid[:4]}", "created": stamp, "updated": stamp},
        "shots": [],  # each: id,title,description,duration,fps,steps,seed,negative, image_path?(on approval)
        "clips": [],
    }
    save_project(proj)
    return proj
# =========================
# LLM (ZeroGPU) — Storyboard generator (robust, two-pass + empty fallback)
# =========================
from transformers import AutoTokenizer, AutoModelForCausalLM
STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "1200"))
_tokenizer = None
_model = None
def _lazy_model_tok():
    """Load (once) and return the storyboard ``(model, tokenizer)`` pair.

    Both objects are cached in module globals and reused on later calls.
    When the freshly loaded tokenizer has no pad id, its EOS id is copied
    into ``pad_token_id`` so generation can be given an explicit pad id.
    """
    global _tokenizer, _model
    already_loaded = _tokenizer is not None and _model is not None
    if not already_loaded:
        # NOTE(review): trust_remote_code=True runs Python shipped with the model
        # repo; confirm it is needed if STORYBOARD_MODEL may point at other repos.
        _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
        # NOTE(review): `dtype=` is the newer spelling of `torch_dtype` in
        # transformers; confirm the pinned version honours it.
        _model = AutoModelForCausalLM.from_pretrained(
            STORYBOARD_MODEL,
            device_map="auto",
            dtype="auto",
            trust_remote_code=True,
        )
        missing_pad = _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None
        if missing_pad:
            _tokenizer.pad_token_id = _tokenizer.eos_token_id
    return _model, _tokenizer
def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
return (
"Return ONLY a JSON array, enclosed between <JSON> and </JSON>.\n"
f"Create a storyboard of {n_shots} shots for this idea:\n\n"
f"'''{user_prompt}'''\n\n"
"Each item schema:\n"
"{\n"
' \"id\": <int starting at 1>,\n'
' \"title\": \"Short title\",\n'
' \"description\": \"Visual description for keyframe generation\",\n'
f" \"duration\": {default_len},\n"
f" \"fps\": {default_fps},\n"
" \"steps\": 30,\n"
" \"seed\": null,\n"
' \"negative\": \"\"\n'
"}\n\n"
"Output:\n<JSON>\n[ { ... }, ... ]\n</JSON>\n"
)
def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
return (
"Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
f"Storyboard: {n_shots} shots for:\n'''{user_prompt}'''\n"
"Item schema:\n"
"{\n"
' \"id\": <int starting at 1>,\n'
' \"title\": \"Short title\",\n'
' \"description\": \"Visual description\",\n'
f" \"duration\": {default_len},\n"
f" \"fps\": {default_fps},\n"
" \"steps\": 30,\n"
" \"seed\": null,\n"
' \"negative\": \"\"\n'
"}\n"
)
def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
if hasattr(tok, "apply_chat_template"):
return tok.apply_chat_template(
[{"role": "system", "content": system_msg},
{"role": "user", "content": user_msg}],
tokenize=False,
add_generation_prompt=True
)
return system_msg + "\n\n" + user_msg
def _generate_text(model, tok, prompt_text: str) -> str:
inputs = tok(prompt_text, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
eos_id = tok.eos_token_id or tok.pad_token_id
gen = model.generate(
**inputs,
max_new_tokens=HF_TASK_MAX_TOKENS,
do_sample=False,
temperature=0.0,
repetition_penalty=1.05,
eos_token_id=eos_id,
pad_token_id=eos_id,
)
# decode only continuation
prompt_len = inputs["input_ids"].shape[1]
continuation_ids = gen[0][prompt_len:]
text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE|re.DOTALL).strip()
return text
def _extract_json_array(text: str) -> str:
m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
if m:
inner = m.group(1).strip()
if inner:
return inner
# Fallback: first balanced array
start = text.find("[")
if start == -1:
return ""
depth = 0
for i in range(start, len(text)):
ch = text[i]
if ch == "[":
depth += 1
elif ch == "]":
depth -= 1
if depth == 0:
return text[start:i+1].strip()
return ""
def _normalize_shots(shots_raw, default_fps: int, default_len: int):
norm = []
for i, s in enumerate(shots_raw, start=1):
norm.append({
"id": int(s.get("id", i)),
"title": s.get("title", f"Shot {i}"),
"description": s.get("description", ""),
"duration": int(s.get("duration", default_len)),
"fps": int(s.get("fps", default_fps)),
"steps": int(s.get("steps", 30)),
"seed": s.get("seed", None),
"negative": s.get("negative", ""),
"image_path": s.get("image_path", None) # will be set after approval
})
return norm
def _placeholder_shots(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
    """Return ``n_shots`` simple shots (ids 1..n) so the UI always has something editable."""
    return [
        {
            "id": i,
            "title": f"Shot {i}",
            "description": f"Simple placeholder for: {user_prompt[:80]}",
            "duration": default_len,
            "fps": default_fps,
            "steps": 30,
            "seed": None,
            "negative": "",
            "image_path": None,
        }
        for i in range(1, int(n_shots) + 1)
    ]


def _parse_shot_list(json_text: str):
    """Parse *json_text*, retrying once with trailing commas removed; return None on failure."""
    if not json_text or not json_text.strip():
        return None
    for candidate in (json_text, re.sub(r",\s*([\]\}])", r"\1", json_text)):
        try:
            return json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError
            continue
    return None


@spaces.GPU(duration=180)
def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
    """Draft a storyboard of about *n_shots* shots for *user_prompt* with the local LLM.

    Up to two generation passes run:
    1. A prompt asking for a ``<JSON>``-tagged array.
    2. Only if pass 1 gives no usable shots, a terser "array only" prompt.

    A pass is usable when its reply contains a JSON payload that parses
    (trailing commas tolerated) and normalizes to at least one shot. Invalid
    model output never raises. If both passes fail, placeholder shots are
    returned.

    Returns:
        list of shot dicts (see ``_normalize_shots``) or placeholders.
    """
    model, tok = _lazy_model_tok()
    system = "You are a film previsualization assistant. Output must be valid JSON."
    passes = (
        (system + " Return ONLY JSON inside <JSON> tags.", _prompt_with_tags),
        (system + " Reply ONLY with a JSON array.", _prompt_minimal),
    )
    for system_msg, build_prompt in passes:
        prompt = _apply_chat(tok, system_msg, build_prompt(user_prompt, n_shots, default_fps, default_len))
        reply = _generate_text(model, tok, prompt)
        json_text = _extract_json_array(reply)
        if not json_text:
            # Last resort: outermost [...] span even if brackets look unbalanced.
            start, end = reply.find("["), reply.rfind("]")
            if start != -1 and end > start:
                json_text = reply[start:end + 1].strip()
        parsed = _parse_shot_list(json_text)
        if parsed is None:
            continue
        try:
            shots = _normalize_shots(parsed, default_fps, default_len)
        except (AttributeError, TypeError, ValueError):
            continue  # structurally odd JSON: try the next pass / placeholders
        if shots:
            return shots
    # EMPTY FALLBACK: simple storyboard so UI never crashes
    return _placeholder_shots(user_prompt, n_shots, default_fps, default_len)
# =========================
# IMAGE GEN (ZeroGPU) — SD1.5 text2img + img2img chaining
# =========================
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
SD_MODEL = os.getenv("SD_MODEL", "runwayml/stable-diffusion-v1-5")
_sd_t2i = None
_sd_i2i = None
def _lazy_sd_pipes():
    """Load (once) and return the ``(text2img, img2img)`` Stable Diffusion pipelines.

    The img2img pipeline is built from the text2img pipeline's components,
    so UNet, VAE and text-encoder weights are downloaded and held in memory
    once and shared by both. Weights are fp16 on CUDA when available and
    fp32 otherwise. The safety checker is disabled and not loaded at all.
    """
    global _sd_t2i, _sd_i2i
    if _sd_t2i is not None and _sd_i2i is not None:
        return _sd_t2i, _sd_i2i
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    t2i = StableDiffusionPipeline.from_pretrained(
        SD_MODEL,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )
    if use_cuda:
        t2i = t2i.to("cuda")
    # Reuse the already-loaded (and already-moved) modules instead of a second from_pretrained.
    parts = {**t2i.components, "requires_safety_checker": False}
    i2i = StableDiffusionImg2ImgPipeline(**parts)
    _sd_t2i, _sd_i2i = t2i, i2i
    return _sd_t2i, _sd_i2i
def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
    """Save *img* as ``keyframes/shot_<NN>.png`` inside the project folder and return the path.

    The file name depends only on the shot id, so regenerating a shot overwrites it.
    """
    filename = f"shot_{shot_id:02d}.png"
    destination = os.path.join(project_dir(pid), "keyframes", filename)
    img.save(destination)
    return destination
@spaces.GPU(duration=180)
def generate_keyframe_image(
    pid: str,
    shot_idx: int,
    shots: list,
    guidance_scale: float = 7.5,
    strength: float = 0.35
):
    """
    Generate the image for shots[shot_idx] and save it as that shot's keyframe.

    - text2img for the first shot. Also used when the previous shot has no
      approved ``image_path``, or that file no longer exists on disk.
    - Otherwise img2img, with the previous approved keyframe as the init image.
      ``strength`` controls how far the result may drift from it.

    Uses the (possibly edited) shot fields description, negative, steps and
    seed. A seed that is missing or not int-convertible draws a fresh random
    seed, so "Regenerate" produces a new image. An explicit seed is reproducible.

    Returns:
        path of the saved PNG (see ``_save_keyframe``).
    """
    t2i, i2i = _lazy_sd_pipes()
    shot = shots[shot_idx]
    prompt = shot.get("description") or ""
    negative = shot.get("negative") or ""
    steps = int(shot.get("steps") or 30)
    gen = torch.Generator("cuda" if torch.cuda.is_available() else "cpu")
    try:
        gen.manual_seed(int(shot.get("seed")))
    except (TypeError, ValueError, OverflowError, RuntimeError):
        # A new torch.Generator starts from a fixed default seed; reseed
        # non-deterministically so unseeded regenerations actually differ.
        gen.seed()
    prev_path = shots[shot_idx - 1].get("image_path") if shot_idx > 0 else None
    if prev_path and os.path.exists(prev_path):
        # img2img: previous approved keyframe as conditioning
        with Image.open(prev_path) as src:
            init_image = src.convert("RGB")
        out = i2i(prompt=prompt, negative_prompt=negative, image=init_image,
                  guidance_scale=guidance_scale, strength=strength,
                  num_inference_steps=steps, generator=gen).images[0]
    else:
        # text2img
        out = t2i(prompt=prompt, negative_prompt=negative, guidance_scale=guidance_scale,
                  num_inference_steps=steps, generator=gen).images[0]
    return _save_keyframe(pid, int(shot["id"]), out)
# =========================
# Shots <-> Dataframe utils
# =========================
import pandas as pd
# Column order of the storyboard Dataframe; shots_to_df copies exactly these
# keys from each shot dict (missing keys become None).
SHOT_COLUMNS = ["id", "title", "description", "duration", "fps", "steps", "seed", "negative", "image_path"]
def shots_to_df(shots: list) -> pd.DataFrame:
    """Convert shot dicts into a DataFrame with exactly SHOT_COLUMNS, in that order.

    Keys missing from a shot become None; extra keys are dropped.
    """
    records = [{col: shot.get(col, None) for col in SHOT_COLUMNS} for shot in shots]
    return pd.DataFrame(records, columns=SHOT_COLUMNS)
def df_to_shots(df: pd.DataFrame) -> list:
    """Convert the edited storyboard DataFrame back into shot dicts, sorted by id.

    Blank cells (None, NaN, empty or whitespace-only strings) count as
    missing and fall back as follows:
    - ``title``: ``"Shot <id>"``;
    - ``description`` / ``negative``: ``""``;
    - ``duration``: 4, ``fps``: 24, ``steps``: 30;
    - ``seed`` / ``image_path``: None;
    - ``id``: the 1-based row position.

    Numeric cells given as text ("5", "4.0") are parsed; unparseable ones use
    the same defaults. Rows whose cells are all blank are skipped, e.g. the
    empty starter row of a dynamic Gradio Dataframe. ``None`` gives ``[]``.
    """
    if df is None:
        return []

    def _blank(value):
        if value is None:
            return True
        if isinstance(value, str):
            return not value.strip()
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            return False  # list-like cell: pd.isna returns an array

    def _as_int(value, default):
        if _blank(value):
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            pass
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return default

    def _as_text(value):
        return "" if _blank(value) else str(value)

    out = []
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        if all(_blank(v) for v in row.tolist()):
            continue
        shot_id = _as_int(row.get("id"), position)
        image_path = row.get("image_path")
        out.append({
            "id": shot_id,
            "title": _as_text(row.get("title")) or f"Shot {shot_id}",
            "description": _as_text(row.get("description")),
            "duration": _as_int(row.get("duration"), 4),
            "fps": _as_int(row.get("fps"), 24),
            "steps": _as_int(row.get("steps"), 30),
            "seed": _as_int(row.get("seed"), None),
            "negative": _as_text(row.get("negative")),
            "image_path": None if _blank(image_path) else image_path,
        })
    # keep sorted by id
    return sorted(out, key=lambda shot: shot["id"])
# =========================
# Gradio UI
# =========================
# Gradio UI: header (new/save/load project) plus four tabs. Storyboard and
# Keyframes are functional; Videos and Export are placeholders.
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Storyboard → Keyframes → Videos → Export")
    gr.Markdown("**Step 3**: Edit storyboard, then generate keyframes. Shot 2..N use the previous approved image as reference (img2img).")

    # Global state
    project = gr.State(None)  # dict with meta/shots/clips
    current_idx = gr.State(0)  # index of current shot in Keyframes tab

    # Header row
    with gr.Row():
        with gr.Column(scale=2):
            proj_name = gr.Textbox(label="Project name", placeholder="e.g., Desert Chase")
        with gr.Column(scale=1):
            new_btn = gr.Button("New Project", variant="primary")
        with gr.Column(scale=1):
            save_btn = gr.Button("Save Project")
        with gr.Column(scale=1):
            load_file = gr.File(label="Load Project (project.json)", file_count="single", type="filepath")
            load_btn = gr.Button("Load")

    # Tabs
    with gr.Tabs():
        with gr.Tab("Storyboard"):
            gr.Markdown("### 1) Storyboard")
            sb_prompt = gr.Textbox(label="High-level prompt", lines=4, placeholder="Describe the story you want to create…")
            with gr.Row():
                sb_target_shots = gr.Slider(1, 12, value=3, step=1, label="Target # of shots")
                sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
                sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
            propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
            shots_df = gr.Dataframe(
                headers=SHOT_COLUMNS,
                datatype=["number", "str", "str", "number", "number", "number", "number", "str", "str"],
                row_count=(1, "dynamic"),
                col_count=len(SHOT_COLUMNS),
                label="Edit shots below",
                wrap=True,
            )
            save_edits_btn = gr.Button("Save Edits ✓", variant="primary")
            to_keyframes_btn = gr.Button("Start Keyframes →", variant="secondary")
            sb_status = gr.Markdown("")

        with gr.Tab("Keyframes"):
            gr.Markdown("### 2) Keyframes")
            with gr.Row():
                shot_info_md = gr.Markdown("")
            with gr.Row():
                prompt_box = gr.Textbox(label="Shot description (editable before generating)", lines=4)
            with gr.Row():
                gen_btn = gr.Button("Generate / Regenerate (uses previous approved image if available)", variant="primary")
                approve_next_btn = gr.Button("Approve & Next →", variant="secondary")
            with gr.Row():
                prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
                out_img = gr.Image(label="Generated image", type="filepath")
            kf_status = gr.Markdown("")

        with gr.Tab("Videos"):
            gr.Markdown("### 3) Videos (coming next)")
            vd_table = gr.JSON(label="Planned clip edges (read-only for now)")

        with gr.Tab("Export"):
            gr.Markdown("### 4) Export (coming next)")
            export_info = gr.Markdown("Nothing to export yet.")

    # -------- Handlers --------
    def _shot_header(shot):
        """Markdown header shown above the current shot in the Keyframes tab."""
        return f"**Shot {shot['id']} — {shot['title']}** \nDuration: {shot['duration']}s @ {shot['fps']} fps"

    def on_new(name):
        p = ensure_project(None, suggested_name=(name or "Project"))
        return p, gr.update(value=f"**New project created** `{p['meta']['name']}` (id: `{p['meta']['id']}`)")

    new_btn.click(on_new, inputs=[proj_name], outputs=[project, sb_status])

    def on_propose(p, name, prompt, target_shots, fps, vlen):
        # Validate first so an empty prompt does not leave a stray project folder.
        if not prompt or not str(prompt).strip():
            raise gr.Error("Please enter a high-level prompt.")
        # `name` is the live textbox value (`proj_name.value` is only its initial value).
        p = ensure_project(p, suggested_name=(name or "Project"))
        shots = generate_storyboard_with_llm(str(prompt).strip(), int(target_shots), int(fps), int(vlen))
        p = dict(p)
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable).")

    propose_btn.click(
        on_propose,
        inputs=[project, proj_name, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
        outputs=[project, shots_df, sb_status],
    )

    def on_save_edits(p, df):
        if p is None:
            raise gr.Error("No project in memory.")
        shots = df_to_shots(df)
        p = dict(p)
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        return p, gr.update(value="Edits saved.")

    save_edits_btn.click(on_save_edits, inputs=[project, shots_df], outputs=[project, sb_status])

    def on_start_keyframes(p, df):
        if p is None:
            raise gr.Error("No project.")
        shots = df_to_shots(df)
        if not shots:
            raise gr.Error("Storyboard is empty.")
        p = dict(p)
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        first = shots[0]
        return (
            p,
            0,
            gr.update(value=_shot_header(first)),
            gr.update(value=first["description"]),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value="Ready to generate shot 1."),
        )

    to_keyframes_btn.click(
        on_start_keyframes,
        inputs=[project, shots_df],
        outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status],
    )

    def on_generate_img(p, idx, current_prompt):
        if p is None:
            raise gr.Error("No project.")
        shots = p["shots"]
        idx = int(idx)
        if idx < 0 or idx >= len(shots):
            raise gr.Error("Invalid shot index.")
        # Allow in-place prompt tweak before generation
        shots[idx]["description"] = current_prompt
        prev_path = shots[idx - 1].get("image_path") if idx > 0 else None
        img_path = generate_keyframe_image(p["meta"]["id"], idx, shots)
        return img_path, prev_path, gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")

    gen_btn.click(on_generate_img, inputs=[project, current_idx, prompt_box], outputs=[out_img, prev_img, kf_status])

    def on_approve_next(p, idx, current_prompt, latest_img_path):
        if p is None:
            raise gr.Error("No project.")
        shots = p["shots"]
        i = int(idx)
        if i < 0 or i >= len(shots):
            raise gr.Error("Invalid shot index.")
        if not latest_img_path:
            raise gr.Error("Generate an image first.")
        # The image component's path may point into Gradio's temporary cache;
        # store the approved image in the project's keyframes folder so
        # project.json references a file that persists with the project.
        try:
            with Image.open(latest_img_path) as approved:
                approved_rgb = approved.convert("RGB")
        except OSError as err:
            raise gr.Error("The generated image is no longer available; please regenerate.") from err
        kept_path = _save_keyframe(p["meta"]["id"], int(shots[i]["id"]), approved_rgb)
        # commit prompt and image path
        shots[i]["description"] = current_prompt
        shots[i]["image_path"] = kept_path
        p["shots"] = shots
        p["meta"]["updated"] = now_iso()
        save_project(p)
        # Move to next
        if i + 1 < len(shots):
            ni = i + 1
            return (
                p,
                ni,
                gr.update(value=_shot_header(shots[ni])),
                gr.update(value=shots[ni]["description"]),
                gr.update(value=shots[i]["image_path"]),
                gr.update(value=None),
                gr.update(value=f"Approved shot {shots[i]['id']}. On to shot {shots[ni]['id']}."),
            )
        # finished all keyframes
        return (
            p,
            i,
            gr.update(value="**All keyframes approved.** Proceed to Videos tab."),
            gr.update(value=""),
            gr.update(value=shots[i]["image_path"]),
            gr.update(value=None),
            gr.update(value="All shots approved ✅"),
        )

    approve_next_btn.click(
        on_approve_next,
        inputs=[project, current_idx, prompt_box, out_img],
        outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status],
    )

    def on_save(p):
        if p is None:
            raise gr.Error("No project in memory.")
        path = save_project(p)
        return gr.update(value=f"Saved to `{path}`")

    # outputs must be components; gr.Markdown.update(...) is not one (and was removed in Gradio 4).
    save_btn.click(on_save, inputs=[project], outputs=[sb_status])

    def on_load(file_obj):
        if not file_obj:
            raise gr.Error("Choose a project.json file first.")
        try:
            p = load_project_file(file_obj)
        except (OSError, ValueError, KeyError, TypeError) as err:
            raise gr.Error(f"Could not load project: {err}") from err
        return (
            p,
            gr.update(value=f"Loaded project `{p['meta']['name']}` (id: `{p['meta']['id']}`)"),
            shots_to_df(p.get("shots", [])),
        )

    load_btn.click(on_load, inputs=[load_file], outputs=[project, sb_status, shots_df])
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()