Shalmoni committed on
Commit
5904c28
·
verified ·
1 Parent(s): a035fe0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -38
app.py CHANGED
@@ -33,12 +33,12 @@ def load_project_file(file_obj):
33
  return proj
34
 
35
  # =========================
36
- # LLM (ZeroGPU) — Storyboard generator (robust, two-pass)
37
  # =========================
38
  from transformers import AutoTokenizer, AutoModelForCausalLM
39
 
40
  STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
41
- HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "900"))
42
 
43
  _tokenizer = None
44
  _model = None
@@ -54,6 +54,9 @@ def _lazy_model_tok():
54
  dtype="auto",
55
  trust_remote_code=True,
56
  )
 
 
 
57
  return _model, _tokenizer
58
 
59
  def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
@@ -61,7 +64,7 @@ def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_
61
  "Return ONLY a JSON array, enclosed between <JSON> and </JSON>.\n"
62
  f"Create a storyboard of {n_shots} shots for this idea:\n\n"
63
  f"'''{user_prompt}'''\n\n"
64
- "Schema per item:\n"
65
  "{\n"
66
  ' \"id\": <int starting at 1>,\n'
67
  ' \"title\": \"Short title\",\n'
@@ -77,11 +80,10 @@ def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_
77
  )
78
 
79
  def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
80
- # Second attempt if tags fail: demand ONLY an array, nothing else.
81
  return (
82
  "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
83
  f"Storyboard: {n_shots} shots for:\n'''{user_prompt}'''\n"
84
- "Each item:\n"
85
  "{\n"
86
  ' \"id\": <int starting at 1>,\n'
87
  ' \"title\": \"Short title\",\n'
@@ -116,30 +118,27 @@ def _generate_text(model, tok, prompt_text: str) -> str:
116
  temperature=0.0,
117
  repetition_penalty=1.05,
118
  eos_token_id=eos_id,
119
- pad_token_id=eos_id,
120
  )
121
  text = tok.decode(gen[0], skip_special_tokens=True)
122
- # Trim the echoed prompt if the model included it
123
  if text.startswith(prompt_text):
124
  text = text[len(prompt_text):]
125
- # Strip code fences if any
126
  text = text.strip()
127
  if text.startswith("```"):
128
- # remove ```json ... ```
129
  text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE|re.DOTALL).strip()
130
  return text
131
 
132
  def _extract_json_array(text: str) -> str:
133
- # Prefer <JSON>...</JSON>
134
  m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
135
  if m:
136
  inner = m.group(1).strip()
137
  if inner:
138
  return inner
139
- # Fallback: balanced array
140
  start = text.find("[")
141
  if start == -1:
142
- return "" # signal failure to caller
143
  depth = 0
144
  for i in range(start, len(text)):
145
  ch = text[i]
@@ -149,7 +148,7 @@ def _extract_json_array(text: str) -> str:
149
  depth -= 1
150
  if depth == 0:
151
  return text[start:i+1].strip()
152
- return "" # unbalanced
153
 
154
  def _normalize_shots(shots_raw, default_fps: int, default_len: int):
155
  norm = []
@@ -171,45 +170,37 @@ def _normalize_shots(shots_raw, default_fps: int, default_len: int):
171
  @spaces.GPU(duration=180)
172
  def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
173
  """
174
- Two-pass generation for robustness:
175
- 1) <JSON>...</JSON>
176
- 2) strict array-only fallback
177
  """
178
  model, tok = _lazy_model_tok()
179
  system = "You are a film previsualization assistant. Output must be valid JSON."
180
 
181
- # ---- PASS 1: with <JSON> tags
182
- p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
183
  _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
184
  out1 = _generate_text(model, tok, p1)
 
185
  json_text = _extract_json_array(out1)
186
 
187
- # ---- PASS 2: strict array (if needed)
188
  if not json_text:
189
- p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
190
  _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
191
  out2 = _generate_text(model, tok, p2)
 
192
  json_text = _extract_json_array(out2)
193
-
194
- # As a last ditch, try bracket slice only
195
- if not json_text:
196
  start = out2.find("["); end = out2.rfind("]")
197
  if start != -1 and end != -1 and end > start:
198
  json_text = out2[start:end+1].strip()
199
 
200
- if not json_text:
201
- # Show a short preview so you can see what the model returned
202
- preview = (out2[:400] + "...") if len(out2) > 400 else out2
203
- raise ValueError(f"LLM did not return parseable JSON.\nPreview:\n{preview}")
204
-
205
- # Parse & normalize
206
- if not json_text.strip():
207
- # Fallback: model returned nothing. Return a single stub shot.
208
- print("⚠️ LLM returned empty output. Using fallback storyboard.")
209
- fallback = [{
210
  "id": 1,
211
  "title": "Shot 1",
212
- "description": f"Fallback shot for: {user_prompt[:50]}",
213
  "duration": default_len,
214
  "fps": default_fps,
215
  "video_length": default_len,
@@ -218,19 +209,16 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
218
  "negative": "",
219
  "keyframe_path": None
220
  }]
221
- return fallback
222
 
 
223
  try:
224
  shots_raw = json.loads(json_text)
225
  except Exception:
226
- # Attempt a tiny cleanup: remove trailing commas and try again
227
  json_text_clean = re.sub(r",\s*([\]\}])", r"\1", json_text)
228
  shots_raw = json.loads(json_text_clean)
229
 
230
  return _normalize_shots(shots_raw, default_fps, default_len)
231
 
232
-
233
-
234
  # =========================
235
  # Gradio UI
236
  # =========================
 
33
  return proj
34
 
35
  # =========================
36
+ # LLM (ZeroGPU) — Storyboard generator (robust, two-pass + empty fallback)
37
  # =========================
38
  from transformers import AutoTokenizer, AutoModelForCausalLM
39
 
40
  STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
41
+ HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "1200")) # give a bit more room
42
 
43
  _tokenizer = None
44
  _model = None
 
54
  dtype="auto",
55
  trust_remote_code=True,
56
  )
57
+ # Ensure pad token exists to avoid warnings
58
+ if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
59
+ _tokenizer.pad_token_id = _tokenizer.eos_token_id
60
  return _model, _tokenizer
61
 
62
  def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
 
64
  "Return ONLY a JSON array, enclosed between <JSON> and </JSON>.\n"
65
  f"Create a storyboard of {n_shots} shots for this idea:\n\n"
66
  f"'''{user_prompt}'''\n\n"
67
+ "Each item schema:\n"
68
  "{\n"
69
  ' \"id\": <int starting at 1>,\n'
70
  ' \"title\": \"Short title\",\n'
 
80
  )
81
 
82
  def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
 
83
  return (
84
  "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
85
  f"Storyboard: {n_shots} shots for:\n'''{user_prompt}'''\n"
86
+ "Item schema:\n"
87
  "{\n"
88
  ' \"id\": <int starting at 1>,\n'
89
  ' \"title\": \"Short title\",\n'
 
118
  temperature=0.0,
119
  repetition_penalty=1.05,
120
  eos_token_id=eos_id,
121
+ pad_token_id=tok.pad_token_id if tok.pad_token_id is not None else eos_id,
122
  )
123
  text = tok.decode(gen[0], skip_special_tokens=True)
 
124
  if text.startswith(prompt_text):
125
  text = text[len(prompt_text):]
126
+ # strip code fences if present
127
  text = text.strip()
128
  if text.startswith("```"):
 
129
  text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE|re.DOTALL).strip()
130
  return text
131
 
132
  def _extract_json_array(text: str) -> str:
 
133
  m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
134
  if m:
135
  inner = m.group(1).strip()
136
  if inner:
137
  return inner
138
+ # Fallback: first balanced array
139
  start = text.find("[")
140
  if start == -1:
141
+ return ""
142
  depth = 0
143
  for i in range(start, len(text)):
144
  ch = text[i]
 
148
  depth -= 1
149
  if depth == 0:
150
  return text[start:i+1].strip()
151
+ return ""
152
 
153
  def _normalize_shots(shots_raw, default_fps: int, default_len: int):
154
  norm = []
 
170
  @spaces.GPU(duration=180)
171
  def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
172
  """
173
+ Two-pass generation with robust parsing and empty-output fallback.
 
 
174
  """
175
  model, tok = _lazy_model_tok()
176
  system = "You are a film previsualization assistant. Output must be valid JSON."
177
 
178
+ # PASS 1: with <JSON> tags
179
+ p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
180
  _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
181
  out1 = _generate_text(model, tok, p1)
182
+ print(f"[DEBUG] LLM raw out1 (first 240 chars): {out1[:240]}")
183
  json_text = _extract_json_array(out1)
184
 
185
+ # PASS 2: strict array fallback
186
  if not json_text:
187
+ p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
188
  _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
189
  out2 = _generate_text(model, tok, p2)
190
+ print(f"[DEBUG] LLM raw out2 (first 240 chars): {out2[:240]}")
191
  json_text = _extract_json_array(out2)
192
+ if not json_text and "[" in out2 and "]" in out2:
 
 
193
  start = out2.find("["); end = out2.rfind("]")
194
  if start != -1 and end != -1 and end > start:
195
  json_text = out2[start:end+1].strip()
196
 
197
+ # EMPTY FALLBACK → return a single stub so the app does not crash
198
+ if not json_text or not json_text.strip():
199
+ print("⚠️ LLM returned empty or unparsable JSON. Using fallback storyboard.")
200
+ return [{
 
 
 
 
 
 
201
  "id": 1,
202
  "title": "Shot 1",
203
+ "description": f"Fallback shot for: {user_prompt[:80]}",
204
  "duration": default_len,
205
  "fps": default_fps,
206
  "video_length": default_len,
 
209
  "negative": "",
210
  "keyframe_path": None
211
  }]
 
212
 
213
+ # Parse & normalize (with tiny trailing-comma cleanup)
214
  try:
215
  shots_raw = json.loads(json_text)
216
  except Exception:
 
217
  json_text_clean = re.sub(r",\s*([\]\}])", r"\1", json_text)
218
  shots_raw = json.loads(json_text_clean)
219
 
220
  return _normalize_shots(shots_raw, default_fps, default_len)
221
 
 
 
222
  # =========================
223
  # Gradio UI
224
  # =========================