Shalmoni commited on
Commit
3d81823
Β·
verified Β·
1 Parent(s): a8b7bac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -74
app.py CHANGED
@@ -3,6 +3,7 @@ from datetime import datetime
3
  import gradio as gr
4
  import spaces # ZeroGPU decorator
5
  import torch
 
6
 
7
  # =========================
8
  # Storage helpers
@@ -33,14 +34,13 @@ def load_project_file(file_obj):
33
  return proj
34
 
35
  def ensure_project(p, suggested_name="Project"):
36
- """Create a fresh project dict if None."""
37
  if p is not None:
38
  return p
39
  pid = new_id()
40
  name = f"{suggested_name}-{pid[:4]}"
41
  proj = {
42
  "meta": {"id": pid, "name": name, "created": now_iso(), "updated": now_iso()},
43
- "shots": [],
44
  "clips": []
45
  }
46
  save_project(proj)
@@ -52,7 +52,7 @@ def ensure_project(p, suggested_name="Project"):
52
  from transformers import AutoTokenizer, AutoModelForCausalLM
53
 
54
  STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
55
- HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "1200")) # give a bit more room
56
 
57
  _tokenizer = None
58
  _model = None
@@ -68,7 +68,6 @@ def _lazy_model_tok():
68
  dtype="auto",
69
  trust_remote_code=True,
70
  )
71
- # Ensure pad token to avoid warnings
72
  if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
73
  _tokenizer.pad_token_id = _tokenizer.eos_token_id
74
  return _model, _tokenizer
@@ -85,7 +84,6 @@ def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_
85
  ' \"description\": \"Visual description for keyframe generation\",\n'
86
  f" \"duration\": {default_len},\n"
87
  f" \"fps\": {default_fps},\n"
88
- f" \"video_length\": {default_len},\n"
89
  " \"steps\": 30,\n"
90
  " \"seed\": null,\n"
91
  ' \"negative\": \"\"\n'
@@ -104,7 +102,6 @@ def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_le
104
  ' \"description\": \"Visual description\",\n'
105
  f" \"duration\": {default_len},\n"
106
  f" \"fps\": {default_fps},\n"
107
- f" \"video_length\": {default_len},\n"
108
  " \"steps\": 30,\n"
109
  " \"seed\": null,\n"
110
  ' \"negative\": \"\"\n'
@@ -122,7 +119,6 @@ def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
122
  return system_msg + "\n\n" + user_msg
123
 
124
  def _generate_text(model, tok, prompt_text: str) -> str:
125
- """Generate and decode only the continuation (no prompt echo)."""
126
  inputs = tok(prompt_text, return_tensors="pt")
127
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
128
  eos_id = tok.eos_token_id or tok.pad_token_id
@@ -136,13 +132,10 @@ def _generate_text(model, tok, prompt_text: str) -> str:
136
  eos_token_id=eos_id,
137
  pad_token_id=eos_id,
138
  )
139
-
140
  # decode only continuation
141
  prompt_len = inputs["input_ids"].shape[1]
142
  continuation_ids = gen[0][prompt_len:]
143
  text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
144
-
145
- # strip code fences if present
146
  if text.startswith("```"):
147
  text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE|re.DOTALL).strip()
148
  return text
@@ -177,44 +170,37 @@ def _normalize_shots(shots_raw, default_fps: int, default_len: int):
177
  "description": s.get("description", ""),
178
  "duration": int(s.get("duration", default_len)),
179
  "fps": int(s.get("fps", default_fps)),
180
- "video_length": int(s.get("video_length", default_len)),
181
  "steps": int(s.get("steps", 30)),
182
  "seed": s.get("seed", None),
183
  "negative": s.get("negative", ""),
184
- "keyframe_path": None
185
  })
186
  return norm
187
 
188
  @spaces.GPU(duration=180)
189
  def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
190
- """
191
- Two-pass generation with robust parsing and empty-output fallback.
192
- """
193
  model, tok = _lazy_model_tok()
194
  system = "You are a film previsualization assistant. Output must be valid JSON."
195
 
196
- # PASS 1: with <JSON> tags
197
  p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
198
  _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
199
  out1 = _generate_text(model, tok, p1)
200
- print(f"[DEBUG] LLM raw out1 (first 240 chars): {out1[:240]}")
201
  json_text = _extract_json_array(out1)
202
 
203
- # PASS 2: strict array fallback
204
  if not json_text:
205
  p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
206
  _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
207
  out2 = _generate_text(model, tok, p2)
208
- print(f"[DEBUG] LLM raw out2 (first 240 chars): {out2[:240]}")
209
  json_text = _extract_json_array(out2)
210
  if not json_text and "[" in out2 and "]" in out2:
211
  start = out2.find("["); end = out2.rfind("]")
212
  if start != -1 and end != -1 and end > start:
213
  json_text = out2[start:end+1].strip()
214
 
215
- # EMPTY FALLBACK β†’ return a simple storyboard so the app does not crash
216
  if not json_text or not json_text.strip():
217
- print("⚠️ LLM returned empty or unparsable JSON. Using fallback storyboard.")
218
  fallback = []
219
  for i in range(1, int(n_shots) + 1):
220
  fallback.append({
@@ -223,15 +209,13 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
223
  "description": f"Simple placeholder for: {user_prompt[:80]}",
224
  "duration": default_len,
225
  "fps": default_fps,
226
- "video_length": default_len,
227
  "steps": 30,
228
  "seed": None,
229
  "negative": "",
230
- "keyframe_path": None
231
  })
232
  return fallback
233
 
234
- # Parse & normalize (with tiny trailing-comma cleanup)
235
  try:
236
  shots_raw = json.loads(json_text)
237
  except Exception:
@@ -240,16 +224,119 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
240
 
241
  return _normalize_shots(shots_raw, default_fps, default_len)
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # =========================
244
  # Gradio UI
245
  # =========================
246
  with gr.Blocks() as demo:
247
  gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
248
- gr.Markdown("**Step 2**: Real storyboard generation on **ZeroGPU**. Next we’ll add keyframes (img2img) and your Modal videos.")
249
 
250
  # Global state
251
  project = gr.State(None) # dict with meta/shots/clips
252
- current_tab = gr.State("Storyboard")
253
 
254
  # Header row
255
  with gr.Row():
@@ -273,19 +360,28 @@ with gr.Blocks() as demo:
273
  sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
274
  sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
275
  propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
276
- shots_json = gr.JSON(label="Storyboard JSON (editable in next step)")
277
- confirm_btn = gr.Button("Confirm Storyboard βœ“", variant="primary")
 
278
  sb_status = gr.Markdown("")
279
 
280
  with gr.Tab("Keyframes"):
281
- gr.Markdown("### 2) Keyframes (coming next)")
282
- kf_table = gr.JSON(label="Shots (read-only for now)")
283
- to_videos_btn = gr.Button("Continue to Videos β†’", interactive=False)
 
 
 
 
 
 
 
 
 
284
 
285
  with gr.Tab("Videos"):
286
  gr.Markdown("### 3) Videos (coming next)")
287
  vd_table = gr.JSON(label="Planned clip edges (read-only for now)")
288
- to_export_btn = gr.Button("Continue to Export β†’", interactive=False)
289
 
290
  with gr.Tab("Export"):
291
  gr.Markdown("### 4) Export (coming next)")
@@ -293,20 +389,12 @@ with gr.Blocks() as demo:
293
 
294
  # -------- Handlers --------
295
  def on_new(name):
296
- name = (name or "").strip() or f"Project-{new_id()}"
297
- pid = new_id()
298
- p = {
299
- "meta": {"id": pid, "name": name, "created": now_iso(), "updated": now_iso()},
300
- "shots": [],
301
- "clips": []
302
- }
303
- save_project(p)
304
- return p, gr.update(value=f"**New project created** `{name}` (id: `{pid}`)")
305
 
306
  new_btn.click(on_new, inputs=[proj_name], outputs=[project, sb_status])
307
 
308
  def on_propose(p, prompt, target_shots, fps, vlen):
309
- # Auto-create project if user forgot
310
  p = ensure_project(p, suggested_name=(proj_name.value if hasattr(proj_name, "value") else "Project"))
311
  if not prompt or not str(prompt).strip():
312
  raise gr.Error("Please enter a high-level prompt.")
@@ -315,39 +403,74 @@ with gr.Blocks() as demo:
315
  p["shots"] = shots
316
  p["meta"]["updated"] = now_iso()
317
  save_project(p)
318
- return p, shots, gr.update(value="Storyboard generated (or fallback) via LLM on ZeroGPU.")
319
 
320
  propose_btn.click(
321
  on_propose,
322
  inputs=[project, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
323
- outputs=[project, shots_json, sb_status]
324
  )
325
 
326
- def on_confirm(p):
327
- if p is None or not p.get("shots"):
328
- raise gr.Error("No storyboard yet.")
329
- edges = []
330
- for i in range(len(p["shots"]) - 1):
331
- a = p["shots"][i]["id"]
332
- b = p["shots"][i+1]["id"]
333
- edges.append({"from": a, "to": b, "prompt": f"Transition from shot {a} to {b}"})
334
  p = dict(p)
335
- p["clips"] = edges
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  p["meta"]["updated"] = now_iso()
337
  save_project(p)
338
- return (
339
- p,
340
- gr.update(value=p["shots"]),
341
- gr.update(value=p["clips"]),
342
- gr.update(value="Storyboard confirmed. Proceed to Keyframes."),
343
- gr.update(interactive=True)
344
- )
345
 
346
- confirm_btn.click(
347
- on_confirm,
348
- inputs=[project],
349
- outputs=[project, kf_table, vd_table, sb_status, to_videos_btn]
350
- )
 
 
 
 
 
 
351
 
352
  def on_save(p):
353
  if p is None:
@@ -355,23 +478,17 @@ with gr.Blocks() as demo:
355
  path = save_project(p)
356
  return gr.update(value=f"Saved to `{path}`")
357
 
358
- save_btn.click(on_save, inputs=[project], outputs=[sb_status])
359
 
360
  def on_load(file_obj):
361
  p = load_project_file(file_obj)
362
  return (
363
  p,
364
  gr.update(value=f"Loaded project `{p['meta']['name']}` (id: `{p['meta']['id']}`)"),
365
- gr.update(value=p["shots"]),
366
- gr.update(value=p["clips"]),
367
- gr.update(interactive=bool(p.get("shots")))
368
  )
369
 
370
- load_btn.click(
371
- on_load,
372
- inputs=[load_file],
373
- outputs=[project, sb_status, kf_table, vd_table, to_videos_btn]
374
- )
375
 
376
  if __name__ == "__main__":
377
  demo.launch()
 
3
  import gradio as gr
4
  import spaces # ZeroGPU decorator
5
  import torch
6
+ from PIL import Image
7
 
8
  # =========================
9
  # Storage helpers
 
34
  return proj
35
 
36
  def ensure_project(p, suggested_name="Project"):
 
37
  if p is not None:
38
  return p
39
  pid = new_id()
40
  name = f"{suggested_name}-{pid[:4]}"
41
  proj = {
42
  "meta": {"id": pid, "name": name, "created": now_iso(), "updated": now_iso()},
43
+ "shots": [], # each: id,title,description,duration,fps,steps,seed,negative, image_path?(on approval)
44
  "clips": []
45
  }
46
  save_project(proj)
 
52
  from transformers import AutoTokenizer, AutoModelForCausalLM
53
 
54
  STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
55
+ HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "1200"))
56
 
57
  _tokenizer = None
58
  _model = None
 
68
  dtype="auto",
69
  trust_remote_code=True,
70
  )
 
71
  if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
72
  _tokenizer.pad_token_id = _tokenizer.eos_token_id
73
  return _model, _tokenizer
 
84
  ' \"description\": \"Visual description for keyframe generation\",\n'
85
  f" \"duration\": {default_len},\n"
86
  f" \"fps\": {default_fps},\n"
 
87
  " \"steps\": 30,\n"
88
  " \"seed\": null,\n"
89
  ' \"negative\": \"\"\n'
 
102
  ' \"description\": \"Visual description\",\n'
103
  f" \"duration\": {default_len},\n"
104
  f" \"fps\": {default_fps},\n"
 
105
  " \"steps\": 30,\n"
106
  " \"seed\": null,\n"
107
  ' \"negative\": \"\"\n'
 
119
  return system_msg + "\n\n" + user_msg
120
 
121
  def _generate_text(model, tok, prompt_text: str) -> str:
 
122
  inputs = tok(prompt_text, return_tensors="pt")
123
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
124
  eos_id = tok.eos_token_id or tok.pad_token_id
 
132
  eos_token_id=eos_id,
133
  pad_token_id=eos_id,
134
  )
 
135
  # decode only continuation
136
  prompt_len = inputs["input_ids"].shape[1]
137
  continuation_ids = gen[0][prompt_len:]
138
  text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
 
 
139
  if text.startswith("```"):
140
  text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE|re.DOTALL).strip()
141
  return text
 
170
  "description": s.get("description", ""),
171
  "duration": int(s.get("duration", default_len)),
172
  "fps": int(s.get("fps", default_fps)),
 
173
  "steps": int(s.get("steps", 30)),
174
  "seed": s.get("seed", None),
175
  "negative": s.get("negative", ""),
176
+ "image_path": s.get("image_path", None) # will be set after approval
177
  })
178
  return norm
179
 
180
  @spaces.GPU(duration=180)
181
  def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
 
 
 
182
  model, tok = _lazy_model_tok()
183
  system = "You are a film previsualization assistant. Output must be valid JSON."
184
 
185
+ # PASS 1
186
  p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
187
  _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
188
  out1 = _generate_text(model, tok, p1)
 
189
  json_text = _extract_json_array(out1)
190
 
191
+ # PASS 2 fallback
192
  if not json_text:
193
  p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
194
  _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
195
  out2 = _generate_text(model, tok, p2)
 
196
  json_text = _extract_json_array(out2)
197
  if not json_text and "[" in out2 and "]" in out2:
198
  start = out2.find("["); end = out2.rfind("]")
199
  if start != -1 and end != -1 and end > start:
200
  json_text = out2[start:end+1].strip()
201
 
202
+ # EMPTY FALLBACK: simple storyboard so UI never crashes
203
  if not json_text or not json_text.strip():
 
204
  fallback = []
205
  for i in range(1, int(n_shots) + 1):
206
  fallback.append({
 
209
  "description": f"Simple placeholder for: {user_prompt[:80]}",
210
  "duration": default_len,
211
  "fps": default_fps,
 
212
  "steps": 30,
213
  "seed": None,
214
  "negative": "",
215
+ "image_path": None
216
  })
217
  return fallback
218
 
 
219
  try:
220
  shots_raw = json.loads(json_text)
221
  except Exception:
 
224
 
225
  return _normalize_shots(shots_raw, default_fps, default_len)
226
 
227
+ # =========================
228
+ # IMAGE GEN (ZeroGPU) β€” SD1.5 text2img + img2img chaining
229
+ # =========================
230
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
231
+
232
+ SD_MODEL = os.getenv("SD_MODEL", "runwayml/stable-diffusion-v1-5")
233
+ _sd_t2i = None
234
+ _sd_i2i = None
235
+
236
+ def _lazy_sd_pipes():
237
+ global _sd_t2i, _sd_i2i
238
+ if _sd_t2i is not None and _sd_i2i is not None:
239
+ return _sd_t2i, _sd_i2i
240
+ _sd_t2i = StableDiffusionPipeline.from_pretrained(
241
+ SD_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
242
+ )
243
+ _sd_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(
244
+ SD_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
245
+ )
246
+ if torch.cuda.is_available():
247
+ _sd_t2i = _sd_t2i.to("cuda")
248
+ _sd_i2i = _sd_i2i.to("cuda")
249
+ _sd_t2i.safety_checker = None
250
+ _sd_i2i.safety_checker = None
251
+ return _sd_t2i, _sd_i2i
252
+
253
+ def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
254
+ pdir = project_dir(pid)
255
+ out = os.path.join(pdir, "keyframes", f"shot_{shot_id:02d}.png")
256
+ img.save(out)
257
+ return out
258
+
259
+ @spaces.GPU(duration=180)
260
+ def generate_keyframe_image(
261
+ pid: str,
262
+ shot_idx: int,
263
+ shots: list,
264
+ guidance_scale: float = 7.5,
265
+ strength: float = 0.35
266
+ ):
267
+ """
268
+ Generate image for shots[shot_idx].
269
+ - If shot_idx == 0: text2img
270
+ - Else: img2img with previous shot's approved image_path as init image
271
+ Uses edited fields in shots: description, negative, steps, seed.
272
+ """
273
+ t2i, i2i = _lazy_sd_pipes()
274
+ shot = shots[shot_idx]
275
+ prompt = shot.get("description", "")
276
+ negative = shot.get("negative") or ""
277
+ steps = int(shot.get("steps", 30))
278
+ seed = shot.get("seed", None)
279
+ gen = torch.Generator("cuda" if torch.cuda.is_available() else "cpu")
280
+ if isinstance(seed, int):
281
+ gen = gen.manual_seed(seed)
282
+
283
+ if shot_idx == 0 or not shots[shot_idx - 1].get("image_path"):
284
+ # text2img
285
+ out = t2i(prompt=prompt, negative_prompt=negative, guidance_scale=guidance_scale,
286
+ num_inference_steps=steps, generator=gen).images[0]
287
+ else:
288
+ # img2img: previous approved keyframe as conditioning
289
+ prev_path = shots[shot_idx - 1]["image_path"]
290
+ init_image = Image.open(prev_path).convert("RGB")
291
+ out = i2i(prompt=prompt, negative_prompt=negative, image=init_image,
292
+ guidance_scale=guidance_scale, strength=strength,
293
+ num_inference_steps=steps, generator=gen).images[0]
294
+
295
+ saved_path = _save_keyframe(pid, int(shot["id"]), out)
296
+ return saved_path
297
+
298
+ # =========================
299
+ # Shots <-> Dataframe utils
300
+ # =========================
301
+ import pandas as pd
302
+
303
+ SHOT_COLUMNS = ["id", "title", "description", "duration", "fps", "steps", "seed", "negative", "image_path"]
304
+
305
+ def shots_to_df(shots: list) -> pd.DataFrame:
306
+ rows = []
307
+ for s in shots:
308
+ rows.append({k: s.get(k, None) for k in SHOT_COLUMNS})
309
+ df = pd.DataFrame(rows, columns=SHOT_COLUMNS)
310
+ return df
311
+
312
+ def df_to_shots(df: pd.DataFrame) -> list:
313
+ out = []
314
+ for _, row in df.iterrows():
315
+ out.append({
316
+ "id": int(row["id"]),
317
+ "title": row["title"] or f"Shot {int(row['id'])}",
318
+ "description": row["description"] or "",
319
+ "duration": int(row["duration"]) if pd.notna(row["duration"]) else 4,
320
+ "fps": int(row["fps"]) if pd.notna(row["fps"]) else 24,
321
+ "steps": int(row["steps"]) if pd.notna(row["steps"]) else 30,
322
+ "seed": (int(row["seed"]) if pd.notna(row["seed"]) else None),
323
+ "negative": row["negative"] or "",
324
+ "image_path": row["image_path"] if pd.notna(row["image_path"]) else None
325
+ })
326
+ # keep sorted by id
327
+ out = sorted(out, key=lambda x: x["id"])
328
+ return out
329
+
330
  # =========================
331
  # Gradio UI
332
  # =========================
333
  with gr.Blocks() as demo:
334
  gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
335
+ gr.Markdown("**Step 3**: Edit storyboard, then generate keyframes. Shot 2..N use the previous approved image as reference (img2img).")
336
 
337
  # Global state
338
  project = gr.State(None) # dict with meta/shots/clips
339
+ current_idx = gr.State(0) # index of current shot in Keyframes tab
340
 
341
  # Header row
342
  with gr.Row():
 
360
  sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
361
  sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
362
  propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
363
+ shots_df = gr.Dataframe(headers=SHOT_COLUMNS, datatype=["number","str","str","number","number","number","number","str","str"], row_count=(1,"dynamic"), col_count=len(SHOT_COLUMNS), label="Edit shots below", wrap=True)
364
+ save_edits_btn = gr.Button("Save Edits βœ“", variant="primary")
365
+ to_keyframes_btn = gr.Button("Start Keyframes β†’", variant="secondary")
366
  sb_status = gr.Markdown("")
367
 
368
  with gr.Tab("Keyframes"):
369
+ gr.Markdown("### 2) Keyframes")
370
+ with gr.Row():
371
+ shot_info_md = gr.Markdown("")
372
+ with gr.Row():
373
+ prompt_box = gr.Textbox(label="Shot description (editable before generating)", lines=4)
374
+ with gr.Row():
375
+ gen_btn = gr.Button("Generate / Regenerate (uses previous approved image if available)", variant="primary")
376
+ approve_next_btn = gr.Button("Approve & Next β†’", variant="secondary")
377
+ with gr.Row():
378
+ prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
379
+ out_img = gr.Image(label="Generated image", type="filepath")
380
+ kf_status = gr.Markdown("")
381
 
382
  with gr.Tab("Videos"):
383
  gr.Markdown("### 3) Videos (coming next)")
384
  vd_table = gr.JSON(label="Planned clip edges (read-only for now)")
 
385
 
386
  with gr.Tab("Export"):
387
  gr.Markdown("### 4) Export (coming next)")
 
389
 
390
  # -------- Handlers --------
391
  def on_new(name):
392
+ p = ensure_project(None, suggested_name=(name or "Project"))
393
+ return p, gr.update(value=f"**New project created** `{p['meta']['name']}` (id: `{p['meta']['id']}`)")
 
 
 
 
 
 
 
394
 
395
  new_btn.click(on_new, inputs=[proj_name], outputs=[project, sb_status])
396
 
397
  def on_propose(p, prompt, target_shots, fps, vlen):
 
398
  p = ensure_project(p, suggested_name=(proj_name.value if hasattr(proj_name, "value") else "Project"))
399
  if not prompt or not str(prompt).strip():
400
  raise gr.Error("Please enter a high-level prompt.")
 
403
  p["shots"] = shots
404
  p["meta"]["updated"] = now_iso()
405
  save_project(p)
406
+ return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable).")
407
 
408
  propose_btn.click(
409
  on_propose,
410
  inputs=[project, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
411
+ outputs=[project, shots_df, sb_status]
412
  )
413
 
414
+ def on_save_edits(p, df):
415
+ if p is None:
416
+ raise gr.Error("No project in memory.")
417
+ shots = df_to_shots(df)
 
 
 
 
418
  p = dict(p)
419
+ p["shots"] = shots
420
+ p["meta"]["updated"] = now_iso()
421
+ save_project(p)
422
+ return p, gr.update(value="Edits saved.")
423
+
424
+ save_edits_btn.click(on_save_edits, inputs=[project, shots_df], outputs=[project, sb_status])
425
+
426
+ def on_start_keyframes(p, df):
427
+ if p is None: raise gr.Error("No project.")
428
+ shots = df_to_shots(df)
429
+ if not shots: raise gr.Error("Storyboard is empty.")
430
+ p = dict(p); p["shots"] = shots; p["meta"]["updated"] = now_iso(); save_project(p)
431
+ idx = 0
432
+ prev_path = None
433
+ info = f"**Shot {shots[idx]['id']} β€” {shots[idx]['title']}** \nDuration: {shots[idx]['duration']}s @ {shots[idx]['fps']} fps"
434
+ return p, 0, gr.update(value=info), gr.update(value=shots[idx]["description"]), gr.update(value=prev_path), gr.update(value=None), gr.update(value="Ready to generate shot 1.")
435
+
436
+ to_keyframes_btn.click(on_start_keyframes, inputs=[project, shots_df], outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status])
437
+
438
+ def on_generate_img(p, idx, current_prompt):
439
+ if p is None: raise gr.Error("No project.")
440
+ shots = p["shots"]
441
+ if idx < 0 or idx >= len(shots): raise gr.Error("Invalid shot index.")
442
+ # Allow in-place prompt tweak before generation
443
+ shots[idx]["description"] = current_prompt
444
+ prev_path = shots[idx-1]["image_path"] if idx > 0 else None
445
+ img_path = generate_keyframe_image(p["meta"]["id"], int(idx), shots)
446
+ return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
447
+
448
+ gen_btn.click(on_generate_img, inputs=[project, current_idx, prompt_box], outputs=[out_img, prev_img, kf_status])
449
+
450
+ def on_approve_next(p, idx, current_prompt, latest_img_path):
451
+ if p is None: raise gr.Error("No project.")
452
+ shots = p["shots"]
453
+ i = int(idx)
454
+ if i < 0 or i >= len(shots): raise gr.Error("Invalid shot index.")
455
+ if not latest_img_path: raise gr.Error("Generate an image first.")
456
+ # commit prompt and image path
457
+ shots[i]["description"] = current_prompt
458
+ shots[i]["image_path"] = latest_img_path
459
+ p["shots"] = shots
460
  p["meta"]["updated"] = now_iso()
461
  save_project(p)
 
 
 
 
 
 
 
462
 
463
+ # Move to next
464
+ if i + 1 < len(shots):
465
+ ni = i + 1
466
+ info = f"**Shot {shots[ni]['id']} β€” {shots[ni]['title']}** \nDuration: {shots[ni]['duration']}s @ {shots[ni]['fps']} fps"
467
+ prev_path = shots[ni-1]["image_path"]
468
+ return p, ni, gr.update(value=info), gr.update(value=shots[ni]["description"]), gr.update(value=prev_path), gr.update(value=None), gr.update(value=f"Approved shot {shots[i]['id']}. On to shot {shots[ni]['id']}.")
469
+ else:
470
+ # finished all keyframes
471
+ return p, i, gr.update(value="**All keyframes approved.** Proceed to Videos tab."), gr.update(value=""), gr.update(value=shots[i]["image_path"]), gr.update(value=None), gr.update(value="All shots approved βœ…")
472
+
473
+ approve_next_btn.click(on_approve_next, inputs=[project, current_idx, prompt_box, out_img], outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status])
474
 
475
  def on_save(p):
476
  if p is None:
 
478
  path = save_project(p)
479
  return gr.update(value=f"Saved to `{path}`")
480
 
481
+ save_btn.click(on_save, inputs=[project], outputs=[gr.Markdown.update(value="Project saved.")])
482
 
483
  def on_load(file_obj):
484
  p = load_project_file(file_obj)
485
  return (
486
  p,
487
  gr.update(value=f"Loaded project `{p['meta']['name']}` (id: `{p['meta']['id']}`)"),
488
+ shots_to_df(p.get("shots", [])),
 
 
489
  )
490
 
491
+ load_btn.click(on_load, inputs=[load_file], outputs=[project, sb_status, shots_df])
 
 
 
 
492
 
493
  if __name__ == "__main__":
494
  demo.launch()