Shalmoni commited on
Commit
5ac63ce
·
verified ·
1 Parent(s): 58c4d87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -52
app.py CHANGED
@@ -33,7 +33,7 @@ def load_project_file(file_obj):
33
  return proj
34
 
35
  # =========================
36
- # LLM (ZeroGPU) — Storyboard generator (robust JSON)
37
  # =========================
38
  from transformers import AutoTokenizer, AutoModelForCausalLM
39
 
@@ -51,13 +51,12 @@ def _lazy_model_tok():
51
  _model = AutoModelForCausalLM.from_pretrained(
52
  STORYBOARD_MODEL,
53
  device_map="auto",
54
- dtype="auto", # prefer `dtype` (torch_dtype is deprecated)
55
  trust_remote_code=True,
56
  )
57
  return _model, _tokenizer
58
 
59
- def _storyboard_prompt(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
60
- # Force the model to wrap JSON in tags; makes parsing deterministic.
61
  return (
62
  "Return ONLY a JSON array, enclosed between <JSON> and </JSON>.\n"
63
  f"Create a storyboard of {n_shots} shots for this idea:\n\n"
@@ -77,55 +76,38 @@ def _storyboard_prompt(user_prompt: str, n_shots: int, default_fps: int, default
77
  "Output:\n<JSON>\n[ { ... }, ... ]\n</JSON>\n"
78
  )
79
 
80
- def _extract_json_array(text: str) -> str:
81
- """
82
- Prefer <JSON>...</JSON>. Fallback: first balanced top-level JSON array.
83
- """
84
- m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
85
- if m:
86
- return m.group(1).strip()
87
-
88
- start = text.find("[")
89
- if start == -1:
90
- raise ValueError("No JSON array start '[' found in model output.")
91
- depth = 0
92
- for i in range(start, len(text)):
93
- ch = text[i]
94
- if ch == "[":
95
- depth += 1
96
- elif ch == "]":
97
- depth -= 1
98
- if depth == 0:
99
- return text[start:i+1]
100
- raise ValueError("Unbalanced JSON array in model output.")
101
-
102
- @spaces.GPU(duration=180) # ZeroGPU entrypoint
103
- def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
104
- """
105
- Chat-format prompt -> deterministic generation -> robust JSON parse.
106
- """
107
- model, tok = _lazy_model_tok()
108
-
109
- system = (
110
- "You are a film previsualization assistant. "
111
- "Return ONLY JSON inside <JSON>...</JSON>. No extra text."
112
  )
113
- user = _storyboard_prompt(user_prompt, n_shots, default_fps, default_len)
114
 
115
- # Use chat template if available for the model
116
  if hasattr(tok, "apply_chat_template"):
117
- prompt_text = tok.apply_chat_template(
118
- [{"role": "system", "content": system},
119
- {"role": "user", "content": user}],
120
  tokenize=False,
121
  add_generation_prompt=True
122
  )
123
- else:
124
- prompt_text = system + "\n\n" + user
125
 
 
126
  inputs = tok(prompt_text, return_tensors="pt")
127
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
128
-
129
  eos_id = tok.eos_token_id
130
  gen = model.generate(
131
  **inputs,
@@ -136,16 +118,40 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
136
  eos_token_id=eos_id,
137
  pad_token_id=eos_id,
138
  )
 
 
 
 
 
 
 
 
 
 
139
 
140
- out_text = tok.decode(gen[0], skip_special_tokens=True)
141
- # Trim the echoed prompt if present
142
- if out_text.startswith(prompt_text):
143
- out_text = out_text[len(prompt_text):]
144
-
145
- json_text = _extract_json_array(out_text)
146
- shots_raw = json.loads(json_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # Normalize fields
149
  norm = []
150
  for i, s in enumerate(shots_raw, start=1):
151
  norm.append({
@@ -162,6 +168,51 @@ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: in
162
  })
163
  return norm
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # =========================
166
  # Gradio UI
167
  # =========================
 
33
  return proj
34
 
35
  # =========================
36
+ # LLM (ZeroGPU) — Storyboard generator (robust, two-pass)
37
  # =========================
38
  from transformers import AutoTokenizer, AutoModelForCausalLM
39
 
 
51
  _model = AutoModelForCausalLM.from_pretrained(
52
  STORYBOARD_MODEL,
53
  device_map="auto",
54
+ dtype="auto",
55
  trust_remote_code=True,
56
  )
57
  return _model, _tokenizer
58
 
59
+ def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
 
60
  return (
61
  "Return ONLY a JSON array, enclosed between <JSON> and </JSON>.\n"
62
  f"Create a storyboard of {n_shots} shots for this idea:\n\n"
 
76
  "Output:\n<JSON>\n[ { ... }, ... ]\n</JSON>\n"
77
  )
78
 
79
+ def _prompt_minimal(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
80
+ # Second attempt if tags fail: demand ONLY an array, nothing else.
81
+ return (
82
+ "Reply ONLY with a JSON array starting with '[' and ending with ']'. No extra text.\n"
83
+ f"Storyboard: {n_shots} shots for:\n'''{user_prompt}'''\n"
84
+ "Each item:\n"
85
+ "{\n"
86
+ ' \"id\": <int starting at 1>,\n'
87
+ ' \"title\": \"Short title\",\n'
88
+ ' \"description\": \"Visual description\",\n'
89
+ f" \"duration\": {default_len},\n"
90
+ f" \"fps\": {default_fps},\n"
91
+ f" \"video_length\": {default_len},\n"
92
+ " \"steps\": 30,\n"
93
+ " \"seed\": null,\n"
94
+ ' \"negative\": \"\"\n'
95
+ "}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
 
97
 
98
+ def _apply_chat(tok, system_msg: str, user_msg: str) -> str:
99
  if hasattr(tok, "apply_chat_template"):
100
+ return tok.apply_chat_template(
101
+ [{"role": "system", "content": system_msg},
102
+ {"role": "user", "content": user_msg}],
103
  tokenize=False,
104
  add_generation_prompt=True
105
  )
106
+ return system_msg + "\n\n" + user_msg
 
107
 
108
+ def _generate_text(model, tok, prompt_text: str) -> str:
109
  inputs = tok(prompt_text, return_tensors="pt")
110
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
111
  eos_id = tok.eos_token_id
112
  gen = model.generate(
113
  **inputs,
 
118
  eos_token_id=eos_id,
119
  pad_token_id=eos_id,
120
  )
121
+ text = tok.decode(gen[0], skip_special_tokens=True)
122
+ # Trim the echoed prompt if the model included it
123
+ if text.startswith(prompt_text):
124
+ text = text[len(prompt_text):]
125
+ # Strip code fences if any
126
+ text = text.strip()
127
+ if text.startswith("```"):
128
+ # remove ```json ... ```
129
+ text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.IGNORECASE|re.DOTALL).strip()
130
+ return text
131
 
132
+ def _extract_json_array(text: str) -> str:
133
+ # Prefer <JSON>...</JSON>
134
+ m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
135
+ if m:
136
+ inner = m.group(1).strip()
137
+ if inner:
138
+ return inner
139
+ # Fallback: balanced array
140
+ start = text.find("[")
141
+ if start == -1:
142
+ return "" # signal failure to caller
143
+ depth = 0
144
+ for i in range(start, len(text)):
145
+ ch = text[i]
146
+ if ch == "[":
147
+ depth += 1
148
+ elif ch == "]":
149
+ depth -= 1
150
+ if depth == 0:
151
+ return text[start:i+1].strip()
152
+ return "" # unbalanced
153
 
154
+ def _normalize_shots(shots_raw, default_fps: int, default_len: int):
155
  norm = []
156
  for i, s in enumerate(shots_raw, start=1):
157
  norm.append({
 
168
  })
169
  return norm
170
 
171
+ @spaces.GPU(duration=180)
172
+ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
173
+ """
174
+ Two-pass generation for robustness:
175
+ 1) <JSON>...</JSON>
176
+ 2) strict array-only fallback
177
+ """
178
+ model, tok = _lazy_model_tok()
179
+ system = "You are a film previsualization assistant. Output must be valid JSON."
180
+
181
+ # ---- PASS 1: with <JSON> tags
182
+ p1 = _apply_chat(tok, system + " Return ONLY JSON inside <JSON> tags.",
183
+ _prompt_with_tags(user_prompt, n_shots, default_fps, default_len))
184
+ out1 = _generate_text(model, tok, p1)
185
+ json_text = _extract_json_array(out1)
186
+
187
+ # ---- PASS 2: strict array (if needed)
188
+ if not json_text:
189
+ p2 = _apply_chat(tok, system + " Reply ONLY with a JSON array.",
190
+ _prompt_minimal(user_prompt, n_shots, default_fps, default_len))
191
+ out2 = _generate_text(model, tok, p2)
192
+ json_text = _extract_json_array(out2)
193
+
194
+ # As a last ditch, try bracket slice only
195
+ if not json_text:
196
+ start = out2.find("["); end = out2.rfind("]")
197
+ if start != -1 and end != -1 and end > start:
198
+ json_text = out2[start:end+1].strip()
199
+
200
+ if not json_text:
201
+ # Show a short preview so you can see what the model returned
202
+ preview = (out2[:400] + "...") if len(out2) > 400 else out2
203
+ raise ValueError(f"LLM did not return parseable JSON.\nPreview:\n{preview}")
204
+
205
+ # Parse & normalize
206
+ try:
207
+ shots_raw = json.loads(json_text)
208
+ except Exception as e:
209
+ # Attempt a tiny cleanup: remove trailing commas
210
+ json_text_clean = re.sub(r",\s*([\]\}])", r"\1", json_text)
211
+ shots_raw = json.loads(json_text_clean)
212
+
213
+ return _normalize_shots(shots_raw, default_fps, default_len)
214
+
215
+
216
  # =========================
217
  # Gradio UI
218
  # =========================