Shalmoni commited on
Commit
96406a7
Β·
verified Β·
1 Parent(s): d494c1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -60
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import os, json, uuid, re
2
  from datetime import datetime
3
  import gradio as gr
4
- import spaces # ZeroGPU decorator
5
  import torch
6
  from PIL import Image
 
7
 
8
  # =========================
9
  # Storage helpers
@@ -30,7 +31,7 @@ def save_project(proj):
30
  def load_project_file(file_obj):
31
  with open(file_obj.name, "r") as f:
32
  proj = json.load(f)
33
- project_dir(proj["meta"]["id"]) # ensure dirs
34
  return proj
35
 
36
  def ensure_project(p, suggested_name="Project"):
@@ -40,14 +41,14 @@ def ensure_project(p, suggested_name="Project"):
40
  name = f"{suggested_name}-{pid[:4]}"
41
  proj = {
42
  "meta": {"id": pid, "name": name, "created": now_iso(), "updated": now_iso()},
43
- "shots": [], # each: id,title,description,duration,fps,steps,seed,negative, image_path?(on approval)
44
  "clips": []
45
  }
46
  save_project(proj)
47
  return proj
48
 
49
  # =========================
50
- # LLM (ZeroGPU) β€” Storyboard generator (robust, two-pass + empty fallback)
51
  # =========================
52
  from transformers import AutoTokenizer, AutoModelForCausalLM
53
 
@@ -119,6 +120,7 @@ def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
119
  return system_msg + "\n\n" + user_msg
120
 
121
  def _generate_text(model, tok, prompt_text: str) -> str:
 
122
  inputs = tok(prompt_text, return_tensors="pt")
123
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
124
  eos_id = tok.eos_token_id or tok.pad_token_id
@@ -132,7 +134,6 @@ def _generate_text(model, tok, prompt_text: str) -> str:
132
  eos_token_id=eos_id,
133
  pad_token_id=eos_id,
134
  )
135
- # decode only continuation
136
  prompt_len = inputs["input_ids"].shape[1]
137
  continuation_ids = gen[0][prompt_len:]
138
  text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
@@ -146,7 +147,6 @@ def _extract_json_array(text: str) -> str:
146
  inner = m.group(1).strip()
147
  if inner:
148
  return inner
149
- # Fallback: first balanced array
150
  start = text.find("[")
151
  if start == -1:
152
  return ""
@@ -173,7 +173,7 @@ def _normalize_shots(shots_raw, default_fps: int, default_len: int):
173
  "steps": int(s.get("steps", 30)),
174
  "seed": s.get("seed", None),
175
  "negative": s.get("negative", ""),
176
- "image_path": s.get("image_path", None) # will be set after approval
177
  })
178
  return norm
179
 
@@ -182,13 +182,13 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
182
  model, tok = _lazy_model_tok()
183
  system = "You are a film previsualization assistant. Output must be valid JSON."
184
 
185
- # PASS 1
186
  p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
187
  _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
188
  out1 = _generate_text(model, tok, p1)
189
  json_text = _extract_json_array(out1)
190
 
191
- # PASS 2 fallback
192
  if not json_text:
193
  p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
194
  _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
@@ -199,7 +199,7 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
199
  if start != -1 and end != -1 and end > start:
200
  json_text = out2[start:end+1].strip()
201
 
202
- # EMPTY FALLBACK: simple storyboard so UI never crashes
203
  if not json_text or not json_text.strip():
204
  fallback = []
205
  for i in range(1, int(n_shots) + 1):
@@ -234,20 +234,35 @@ _sd_t2i = None
234
  _sd_i2i = None
235
 
236
  def _lazy_sd_pipes():
 
237
  global _sd_t2i, _sd_i2i
238
  if _sd_t2i is not None and _sd_i2i is not None:
239
  return _sd_t2i, _sd_i2i
 
 
 
240
  _sd_t2i = StableDiffusionPipeline.from_pretrained(
241
- SD_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
242
- )
243
- _sd_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(
244
- SD_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 
245
  )
246
  if torch.cuda.is_available():
247
  _sd_t2i = _sd_t2i.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
248
  _sd_i2i = _sd_i2i.to("cuda")
249
- _sd_t2i.safety_checker = None
250
- _sd_i2i.safety_checker = None
251
  return _sd_t2i, _sd_i2i
252
 
253
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
@@ -266,9 +281,8 @@ def generate_keyframe_image(
266
  ):
267
  """
268
  Generate image for shots[shot_idx].
269
- - If shot_idx == 0: text2img
270
- - Else: img2img with previous shot's approved image_path as init image
271
- Uses edited fields in shots: description, negative, steps, seed.
272
  """
273
  t2i, i2i = _lazy_sd_pipes()
274
  shot = shots[shot_idx]
@@ -276,21 +290,31 @@ def generate_keyframe_image(
276
  negative = shot.get("negative") or ""
277
  steps = int(shot.get("steps", 30))
278
  seed = shot.get("seed", None)
 
279
  gen = torch.Generator("cuda" if torch.cuda.is_available() else "cpu")
280
  if isinstance(seed, int):
281
  gen = gen.manual_seed(seed)
282
 
283
  if shot_idx == 0 or not shots[shot_idx - 1].get("image_path"):
284
- # text2img
285
- out = t2i(prompt=prompt, negative_prompt=negative, guidance_scale=guidance_scale,
286
- num_inference_steps=steps, generator=gen).images[0]
 
 
 
 
287
  else:
288
- # img2img: previous approved keyframe as conditioning
289
  prev_path = shots[shot_idx - 1]["image_path"]
290
  init_image = Image.open(prev_path).convert("RGB")
291
- out = i2i(prompt=prompt, negative_prompt=negative, image=init_image,
292
- guidance_scale=guidance_scale, strength=strength,
293
- num_inference_steps=steps, generator=gen).images[0]
 
 
 
 
 
 
294
 
295
  saved_path = _save_keyframe(pid, int(shot["id"]), out)
296
  return saved_path
@@ -298,23 +322,18 @@ def generate_keyframe_image(
298
  # =========================
299
  # Shots <-> Dataframe utils
300
  # =========================
301
- import pandas as pd
302
-
303
  SHOT_COLUMNS = ["id", "title", "description", "duration", "fps", "steps", "seed", "negative", "image_path"]
304
 
305
  def shots_to_df(shots: list) -> pd.DataFrame:
306
- rows = []
307
- for s in shots:
308
- rows.append({k: s.get(k, None) for k in SHOT_COLUMNS})
309
- df = pd.DataFrame(rows, columns=SHOT_COLUMNS)
310
- return df
311
 
312
  def df_to_shots(df: pd.DataFrame) -> list:
313
  out = []
314
  for _, row in df.iterrows():
315
  out.append({
316
  "id": int(row["id"]),
317
- "title": row["title"] or f"Shot {int(row['id'])}",
318
  "description": row["description"] or "",
319
  "duration": int(row["duration"]) if pd.notna(row["duration"]) else 4,
320
  "fps": int(row["fps"]) if pd.notna(row["fps"]) else 24,
@@ -323,22 +342,20 @@ def df_to_shots(df: pd.DataFrame) -> list:
323
  "negative": row["negative"] or "",
324
  "image_path": row["image_path"] if pd.notna(row["image_path"]) else None
325
  })
326
- # keep sorted by id
327
- out = sorted(out, key=lambda x: x["id"])
328
- return out
329
 
330
  # =========================
331
  # Gradio UI
332
  # =========================
333
  with gr.Blocks() as demo:
334
  gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
335
- gr.Markdown("**Step 3**: Edit storyboard, then generate keyframes. Shot 2..N use the previous approved image as reference (img2img).")
336
 
337
- # Global state
338
- project = gr.State(None) # dict with meta/shots/clips
339
- current_idx = gr.State(0) # index of current shot in Keyframes tab
340
 
341
- # Header row
342
  with gr.Row():
343
  with gr.Column(scale=2):
344
  proj_name = gr.Textbox(label="Project name", placeholder="e.g., Desert Chase")
@@ -349,6 +366,7 @@ with gr.Blocks() as demo:
349
  with gr.Column(scale=1):
350
  load_file = gr.File(label="Load Project (project.json)", file_count="single", type="filepath")
351
  load_btn = gr.Button("Load")
 
352
 
353
  # Tabs
354
  with gr.Tabs():
@@ -360,19 +378,21 @@ with gr.Blocks() as demo:
360
  sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
361
  sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
362
  propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
363
- shots_df = gr.Dataframe(headers=SHOT_COLUMNS, datatype=["number","str","str","number","number","number","number","str","str"], row_count=(1,"dynamic"), col_count=len(SHOT_COLUMNS), label="Edit shots below", wrap=True)
364
- save_edits_btn = gr.Button("Save Edits βœ“", variant="primary")
 
 
 
 
 
365
  to_keyframes_btn = gr.Button("Start Keyframes β†’", variant="secondary")
366
- sb_status = gr.Markdown("")
367
 
368
  with gr.Tab("Keyframes"):
369
  gr.Markdown("### 2) Keyframes")
 
 
370
  with gr.Row():
371
- shot_info_md = gr.Markdown("")
372
- with gr.Row():
373
- prompt_box = gr.Textbox(label="Shot description (editable before generating)", lines=4)
374
- with gr.Row():
375
- gen_btn = gr.Button("Generate / Regenerate (uses previous approved image if available)", variant="primary")
376
  approve_next_btn = gr.Button("Approve & Next β†’", variant="secondary")
377
  with gr.Row():
378
  prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
@@ -387,7 +407,7 @@ with gr.Blocks() as demo:
387
  gr.Markdown("### 4) Export (coming next)")
388
  export_info = gr.Markdown("Nothing to export yet.")
389
 
390
- # -------- Handlers --------
391
  def on_new(name):
392
  p = ensure_project(None, suggested_name=(name or "Project"))
393
  return p, gr.update(value=f"**New project created** `{p['meta']['name']}` (id: `{p['meta']['id']}`)")
@@ -403,17 +423,23 @@ with gr.Blocks() as demo:
403
  p["shots"] = shots
404
  p["meta"]["updated"] = now_iso()
405
  save_project(p)
406
- return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable).")
 
407
 
408
  propose_btn.click(
409
  on_propose,
410
  inputs=[project, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
411
- outputs=[project, shots_df, sb_status]
412
  )
413
 
414
- def on_save_edits(p, df):
 
 
 
415
  if p is None:
416
- raise gr.Error("No project in memory.")
 
 
417
  shots = df_to_shots(df)
418
  p = dict(p)
419
  p["shots"] = shots
@@ -439,8 +465,7 @@ with gr.Blocks() as demo:
439
  if p is None: raise gr.Error("No project.")
440
  shots = p["shots"]
441
  if idx < 0 or idx >= len(shots): raise gr.Error("Invalid shot index.")
442
- # Allow in-place prompt tweak before generation
443
- shots[idx]["description"] = current_prompt
444
  prev_path = shots[idx-1]["image_path"] if idx > 0 else None
445
  img_path = generate_keyframe_image(p["meta"]["id"], int(idx), shots)
446
  return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
@@ -453,21 +478,20 @@ with gr.Blocks() as demo:
453
  i = int(idx)
454
  if i < 0 or i >= len(shots): raise gr.Error("Invalid shot index.")
455
  if not latest_img_path: raise gr.Error("Generate an image first.")
456
- # commit prompt and image path
457
  shots[i]["description"] = current_prompt
458
  shots[i]["image_path"] = latest_img_path
459
  p["shots"] = shots
460
  p["meta"]["updated"] = now_iso()
461
  save_project(p)
462
 
463
- # Move to next
464
  if i + 1 < len(shots):
465
  ni = i + 1
466
  info = f"**Shot {shots[ni]['id']} β€” {shots[ni]['title']}** \nDuration: {shots[ni]['duration']}s @ {shots[ni]['fps']} fps"
467
  prev_path = shots[ni-1]["image_path"]
468
  return p, ni, gr.update(value=info), gr.update(value=shots[ni]["description"]), gr.update(value=prev_path), gr.update(value=None), gr.update(value=f"Approved shot {shots[i]['id']}. On to shot {shots[ni]['id']}.")
469
  else:
470
- # finished all keyframes
471
  return p, i, gr.update(value="**All keyframes approved.** Proceed to Videos tab."), gr.update(value=""), gr.update(value=shots[i]["image_path"]), gr.update(value=None), gr.update(value="All shots approved βœ…")
472
 
473
  approve_next_btn.click(on_approve_next, inputs=[project, current_idx, prompt_box, out_img], outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status])
 
1
  import os, json, uuid, re
2
  from datetime import datetime
3
  import gradio as gr
4
+ import spaces
5
  import torch
6
  from PIL import Image
7
+ import pandas as pd
8
 
9
  # =========================
10
  # Storage helpers
 
31
  def load_project_file(file_obj):
32
  with open(file_obj.name, "r") as f:
33
  proj = json.load(f)
34
+ project_dir(proj["meta"]["id"])
35
  return proj
36
 
37
  def ensure_project(p, suggested_name="Project"):
 
41
  name = f"{suggested_name}-{pid[:4]}"
42
  proj = {
43
  "meta": {"id": pid, "name": name, "created": now_iso(), "updated": now_iso()},
44
+ "shots": [], # each shot: id,title,description,duration,fps,steps,seed,negative,image_path?
45
  "clips": []
46
  }
47
  save_project(proj)
48
  return proj
49
 
50
  # =========================
51
+ # LLM (ZeroGPU) β€” Storyboard generator (robust)
52
  # =========================
53
  from transformers import AutoTokenizer, AutoModelForCausalLM
54
 
 
120
  return system_msg + "\n\n" + user_msg
121
 
122
  def _generate_text(model, tok, prompt_text: str) -> str:
123
+ """Decode only the continuation (avoid prompt echo)."""
124
  inputs = tok(prompt_text, return_tensors="pt")
125
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
126
  eos_id = tok.eos_token_id or tok.pad_token_id
 
134
  eos_token_id=eos_id,
135
  pad_token_id=eos_id,
136
  )
 
137
  prompt_len = inputs["input_ids"].shape[1]
138
  continuation_ids = gen[0][prompt_len:]
139
  text = tok.decode(continuation_ids, skip_special_tokens=True).strip()
 
147
  inner = m.group(1).strip()
148
  if inner:
149
  return inner
 
150
  start = text.find("[")
151
  if start == -1:
152
  return ""
 
173
  "steps": int(s.get("steps", 30)),
174
  "seed": s.get("seed", None),
175
  "negative": s.get("negative", ""),
176
+ "image_path": s.get("image_path", None)
177
  })
178
  return norm
179
 
 
182
  model, tok = _lazy_model_tok()
183
  system = "You are a film previsualization assistant. Output must be valid JSON."
184
 
185
+ # Pass 1
186
  p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
187
  _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
188
  out1 = _generate_text(model, tok, p1)
189
  json_text = _extract_json_array(out1)
190
 
191
+ # Pass 2
192
  if not json_text:
193
  p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
194
  _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
 
199
  if start != -1 and end != -1 and end > start:
200
  json_text = out2[start:end+1].strip()
201
 
202
+ # Empty fallback
203
  if not json_text or not json_text.strip():
204
  fallback = []
205
  for i in range(1, int(n_shots) + 1):
 
234
  _sd_i2i = None
235
 
236
  def _lazy_sd_pipes():
237
+ """Load SD once, disable safety checker to avoid offload_state_dict issues; reuse modules for img2img."""
238
  global _sd_t2i, _sd_i2i
239
  if _sd_t2i is not None and _sd_i2i is not None:
240
  return _sd_t2i, _sd_i2i
241
+
242
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
243
+
244
  _sd_t2i = StableDiffusionPipeline.from_pretrained(
245
+ SD_MODEL,
246
+ torch_dtype=dtype,
247
+ safety_checker=None,
248
+ feature_extractor=None,
249
+ use_safetensors=True
250
  )
251
  if torch.cuda.is_available():
252
  _sd_t2i = _sd_t2i.to("cuda")
253
+
254
+ _sd_i2i = StableDiffusionImg2ImgPipeline(
255
+ vae=_sd_t2i.vae,
256
+ text_encoder=_sd_t2i.text_encoder,
257
+ tokenizer=_sd_t2i.tokenizer,
258
+ unet=_sd_t2i.unet,
259
+ scheduler=_sd_t2i.scheduler,
260
+ safety_checker=None,
261
+ feature_extractor=None
262
+ )
263
+ if torch.cuda.is_available():
264
  _sd_i2i = _sd_i2i.to("cuda")
265
+
 
266
  return _sd_t2i, _sd_i2i
267
 
268
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
 
281
  ):
282
  """
283
  Generate image for shots[shot_idx].
284
+ - shot 0: text2img
285
+ - shot k>0: img2img using previous approved image as conditioning (if available)
 
286
  """
287
  t2i, i2i = _lazy_sd_pipes()
288
  shot = shots[shot_idx]
 
290
  negative = shot.get("negative") or ""
291
  steps = int(shot.get("steps", 30))
292
  seed = shot.get("seed", None)
293
+
294
  gen = torch.Generator("cuda" if torch.cuda.is_available() else "cpu")
295
  if isinstance(seed, int):
296
  gen = gen.manual_seed(seed)
297
 
298
  if shot_idx == 0 or not shots[shot_idx - 1].get("image_path"):
299
+ out = t2i(
300
+ prompt=prompt,
301
+ negative_prompt=negative,
302
+ guidance_scale=guidance_scale,
303
+ num_inference_steps=steps,
304
+ generator=gen
305
+ ).images[0]
306
  else:
 
307
  prev_path = shots[shot_idx - 1]["image_path"]
308
  init_image = Image.open(prev_path).convert("RGB")
309
+ out = i2i(
310
+ prompt=prompt,
311
+ negative_prompt=negative,
312
+ image=init_image,
313
+ guidance_scale=guidance_scale,
314
+ strength=strength,
315
+ num_inference_steps=steps,
316
+ generator=gen
317
+ ).images[0]
318
 
319
  saved_path = _save_keyframe(pid, int(shot["id"]), out)
320
  return saved_path
 
322
  # =========================
323
  # Shots <-> Dataframe utils
324
  # =========================
 
 
325
  SHOT_COLUMNS = ["id", "title", "description", "duration", "fps", "steps", "seed", "negative", "image_path"]
326
 
327
  def shots_to_df(shots: list) -> pd.DataFrame:
328
+ rows = [{k: s.get(k, None) for k in SHOT_COLUMNS} for s in shots]
329
+ return pd.DataFrame(rows, columns=SHOT_COLUMNS)
 
 
 
330
 
331
  def df_to_shots(df: pd.DataFrame) -> list:
332
  out = []
333
  for _, row in df.iterrows():
334
  out.append({
335
  "id": int(row["id"]),
336
+ "title": (row["title"] or f"Shot {int(row['id'])}"),
337
  "description": row["description"] or "",
338
  "duration": int(row["duration"]) if pd.notna(row["duration"]) else 4,
339
  "fps": int(row["fps"]) if pd.notna(row["fps"]) else 24,
 
342
  "negative": row["negative"] or "",
343
  "image_path": row["image_path"] if pd.notna(row["image_path"]) else None
344
  })
345
+ return sorted(out, key=lambda x: x["id"])
 
 
346
 
347
  # =========================
348
  # Gradio UI
349
  # =========================
350
  with gr.Blocks() as demo:
351
  gr.Markdown("# 🎬 Storyboard β†’ Keyframes β†’ Videos β†’ Export")
352
+ gr.Markdown("**Edit storyboard prompts**, then generate keyframes. Each next shot uses the **previous approved image** as reference.")
353
 
354
+ # State
355
+ project = gr.State(None)
356
+ current_idx = gr.State(0)
357
 
358
+ # Header
359
  with gr.Row():
360
  with gr.Column(scale=2):
361
  proj_name = gr.Textbox(label="Project name", placeholder="e.g., Desert Chase")
 
366
  with gr.Column(scale=1):
367
  load_file = gr.File(label="Load Project (project.json)", file_count="single", type="filepath")
368
  load_btn = gr.Button("Load")
369
+ sb_status = gr.Markdown("")
370
 
371
  # Tabs
372
  with gr.Tabs():
 
378
  sb_default_fps = gr.Slider(8, 60, value=24, step=1, label="Default FPS")
379
  sb_default_len = gr.Slider(1, 12, value=4, step=1, label="Default seconds per shot")
380
  propose_btn = gr.Button("Propose Storyboard (LLM on ZeroGPU)")
381
+ shots_df = gr.Dataframe(
382
+ headers=SHOT_COLUMNS,
383
+ datatype=["number","str","str","number","number","number","number","str","str"],
384
+ row_count=(1,"dynamic"), col_count=len(SHOT_COLUMNS),
385
+ label="Edit shots below (prompts & params)", wrap=True
386
+ )
387
+ save_edits_btn = gr.Button("Save Edits βœ“", variant="primary", interactive=False)
388
  to_keyframes_btn = gr.Button("Start Keyframes β†’", variant="secondary")
 
389
 
390
  with gr.Tab("Keyframes"):
391
  gr.Markdown("### 2) Keyframes")
392
+ shot_info_md = gr.Markdown("")
393
+ prompt_box = gr.Textbox(label="Shot description (editable before generating)", lines=4)
394
  with gr.Row():
395
+ gen_btn = gr.Button("Generate / Regenerate", variant="primary")
 
 
 
 
396
  approve_next_btn = gr.Button("Approve & Next β†’", variant="secondary")
397
  with gr.Row():
398
  prev_img = gr.Image(label="Previous approved image (conditioning)", type="filepath")
 
407
  gr.Markdown("### 4) Export (coming next)")
408
  export_info = gr.Markdown("Nothing to export yet.")
409
 
410
+ # ---------- Handlers ----------
411
  def on_new(name):
412
  p = ensure_project(None, suggested_name=(name or "Project"))
413
  return p, gr.update(value=f"**New project created** `{p['meta']['name']}` (id: `{p['meta']['id']}`)")
 
423
  p["shots"] = shots
424
  p["meta"]["updated"] = now_iso()
425
  save_project(p)
426
+ # Enable Save Edits after storyboard exists
427
+ return p, shots_to_df(shots), gr.update(value="Storyboard generated (editable)."), gr.update(interactive=True)
428
 
429
  propose_btn.click(
430
  on_propose,
431
  inputs=[project, sb_prompt, sb_target_shots, sb_default_fps, sb_default_len],
432
+ outputs=[project, shots_df, sb_status, save_edits_btn]
433
  )
434
 
435
+ # Defensive save handler (works even if user clicks too early)
436
+ def on_save_edits(*args):
437
+ p = args[0] if len(args) > 0 else None
438
+ df = args[1] if len(args) > 1 else None
439
  if p is None:
440
+ raise gr.Error("No project in memory. Click New Project, then generate a storyboard.")
441
+ if df is None:
442
+ raise gr.Error("No storyboard table to save. Generate a storyboard first, then edit it.")
443
  shots = df_to_shots(df)
444
  p = dict(p)
445
  p["shots"] = shots
 
465
  if p is None: raise gr.Error("No project.")
466
  shots = p["shots"]
467
  if idx < 0 or idx >= len(shots): raise gr.Error("Invalid shot index.")
468
+ shots[idx]["description"] = current_prompt # allow tweaking before generation
 
469
  prev_path = shots[idx-1]["image_path"] if idx > 0 else None
470
  img_path = generate_keyframe_image(p["meta"]["id"], int(idx), shots)
471
  return img_path, (prev_path or None), gr.update(value=f"Generated candidate for shot {shots[idx]['id']}.")
 
478
  i = int(idx)
479
  if i < 0 or i >= len(shots): raise gr.Error("Invalid shot index.")
480
  if not latest_img_path: raise gr.Error("Generate an image first.")
481
+ # commit
482
  shots[i]["description"] = current_prompt
483
  shots[i]["image_path"] = latest_img_path
484
  p["shots"] = shots
485
  p["meta"]["updated"] = now_iso()
486
  save_project(p)
487
 
488
+ # next
489
  if i + 1 < len(shots):
490
  ni = i + 1
491
  info = f"**Shot {shots[ni]['id']} β€” {shots[ni]['title']}** \nDuration: {shots[ni]['duration']}s @ {shots[ni]['fps']} fps"
492
  prev_path = shots[ni-1]["image_path"]
493
  return p, ni, gr.update(value=info), gr.update(value=shots[ni]["description"]), gr.update(value=prev_path), gr.update(value=None), gr.update(value=f"Approved shot {shots[i]['id']}. On to shot {shots[ni]['id']}.")
494
  else:
 
495
  return p, i, gr.update(value="**All keyframes approved.** Proceed to Videos tab."), gr.update(value=""), gr.update(value=shots[i]["image_path"]), gr.update(value=None), gr.update(value="All shots approved βœ…")
496
 
497
  approve_next_btn.click(on_approve_next, inputs=[project, current_idx, prompt_box, out_img], outputs=[project, current_idx, shot_info_md, prompt_box, prev_img, out_img, kf_status])