Shalmoni commited on
Commit
58c4d87
·
verified ·
1 Parent(s): 6b00576

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -65
app.py CHANGED
@@ -1,7 +1,8 @@
1
- import os, json, uuid
2
  from datetime import datetime
3
  import gradio as gr
4
- import spaces # <<< required for ZeroGPU
 
5
 
6
  # =========================
7
  # Storage helpers
@@ -32,81 +33,121 @@ def load_project_file(file_obj):
32
  return proj
33
 
34
  # =========================
35
- # LLM (ZeroGPU) — Storyboard generator
36
  # =========================
37
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
38
 
39
  STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
40
- HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "900")) # keep tidy for JSON
41
- _pipe = None # lazy-loaded global
42
-
43
- def _lazy_pipe():
44
- global _pipe
45
- if _pipe is not None:
46
- return _pipe
47
- tok = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
48
- mdl = AutoModelForCausalLM.from_pretrained(
 
 
49
  STORYBOARD_MODEL,
50
  device_map="auto",
51
- torch_dtype="auto",
52
  trust_remote_code=True,
53
  )
54
- _pipe = pipeline(
55
- "text-generation",
56
- model=mdl,
57
- tokenizer=tok,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  max_new_tokens=HF_TASK_MAX_TOKENS,
59
- do_sample=False, # deterministic JSON
60
  temperature=0.0,
61
  repetition_penalty=1.05,
 
 
62
  )
63
- return _pipe
64
 
65
- def _storyboard_prompt(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
66
- return f"""
67
- You are a film previsualization assistant. Return ONLY valid JSON (no explanations).
68
- Create a storyboard of {n_shots} numbered shots for the following idea:
69
-
70
- \"\"\"{user_prompt}\"\"\"
71
-
72
- Return an array of objects with this exact schema and default values:
73
- [
74
- {{
75
- "id": 1,
76
- "title": "Short title",
77
- "description": "A visual description suitable for keyframe generation",
78
- "duration": {default_len},
79
- "fps": {default_fps},
80
- "video_length": {default_len},
81
- "steps": 30,
82
- "seed": null,
83
- "negative": ""
84
- }}
85
- ]
86
-
87
- Rules:
88
- - IDs must start at 1 and increment by 1.
89
- - Use simple ASCII only. No trailing commas.
90
- - Output must be valid JSON parseable by Python's json.loads.
91
- """.strip()
92
-
93
- @spaces.GPU(duration=180) # <<< ZeroGPU entrypoint: triggers pooled GPU allocation
94
- def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
95
- pipe = _lazy_pipe()
96
- prompt = _storyboard_prompt(user_prompt, n_shots, default_fps, default_len)
97
- out = pipe(prompt)[0]["generated_text"]
98
-
99
- # Extract the JSON array
100
- start = out.find("[")
101
- end = out.rfind("]")
102
- if start == -1 or end == -1 or end <= start:
103
- raise ValueError("LLM did not return valid JSON.")
104
- text = out[start:end+1]
105
- shots = json.loads(text)
106
-
107
- # Normalize & enforce required fields
108
  norm = []
109
- for i, s in enumerate(shots, start=1):
110
  norm.append({
111
  "id": int(s.get("id", i)),
112
  "title": s.get("title", f"Shot {i}"),
@@ -255,5 +296,4 @@ with gr.Blocks() as demo:
255
  )
256
 
257
  if __name__ == "__main__":
258
- # SSR is fine; you can set share=True if you want a public link automatically
259
  demo.launch()
 
1
+ import os, json, uuid, re
2
  from datetime import datetime
3
  import gradio as gr
4
+ import spaces # ZeroGPU decorator
5
+ import torch
6
 
7
  # =========================
8
  # Storage helpers
 
33
  return proj
34
 
35
  # =========================
36
+ # LLM (ZeroGPU) — Storyboard generator (robust JSON)
37
  # =========================
38
+ from transformers import AutoTokenizer, AutoModelForCausalLM
39
 
40
  STORYBOARD_MODEL = os.getenv("STORYBOARD_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
41
+ HF_TASK_MAX_TOKENS = int(os.getenv("HF_TASK_MAX_TOKENS", "900"))
42
+
43
+ _tokenizer = None
44
+ _model = None
45
+
46
+ def _lazy_model_tok():
47
+ global _tokenizer, _model
48
+ if _tokenizer is not None and _model is not None:
49
+ return _model, _tokenizer
50
+ _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
51
+ _model = AutoModelForCausalLM.from_pretrained(
52
  STORYBOARD_MODEL,
53
  device_map="auto",
54
+ dtype="auto", # prefer `dtype` (torch_dtype is deprecated)
55
  trust_remote_code=True,
56
  )
57
+ return _model, _tokenizer
58
+
59
+ def _storyboard_prompt(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
60
+ # Force the model to wrap JSON in tags; makes parsing deterministic.
61
+ return (
62
+ "Return ONLY a JSON array, enclosed between <JSON> and </JSON>.\n"
63
+ f"Create a storyboard of {n_shots} shots for this idea:\n\n"
64
+ f"'''{user_prompt}'''\n\n"
65
+ "Schema per item:\n"
66
+ "{\n"
67
+ ' \"id\": <int starting at 1>,\n'
68
+ ' \"title\": \"Short title\",\n'
69
+ ' \"description\": \"Visual description for keyframe generation\",\n'
70
+ f" \"duration\": {default_len},\n"
71
+ f" \"fps\": {default_fps},\n"
72
+ f" \"video_length\": {default_len},\n"
73
+ " \"steps\": 30,\n"
74
+ " \"seed\": null,\n"
75
+ ' \"negative\": \"\"\n'
76
+ "}\n\n"
77
+ "Output:\n<JSON>\n[ { ... }, ... ]\n</JSON>\n"
78
+ )
79
+
80
+ def _extract_json_array(text: str) -> str:
81
+ """
82
+ Prefer <JSON>...</JSON>. Fallback: first balanced top-level JSON array.
83
+ """
84
+ m = re.search(r"<JSON>(.*?)</JSON>", text, flags=re.DOTALL | re.IGNORECASE)
85
+ if m:
86
+ return m.group(1).strip()
87
+
88
+ start = text.find("[")
89
+ if start == -1:
90
+ raise ValueError("No JSON array start '[' found in model output.")
91
+ depth = 0
92
+ for i in range(start, len(text)):
93
+ ch = text[i]
94
+ if ch == "[":
95
+ depth += 1
96
+ elif ch == "]":
97
+ depth -= 1
98
+ if depth == 0:
99
+ return text[start:i+1]
100
+ raise ValueError("Unbalanced JSON array in model output.")
101
+
102
+ @spaces.GPU(duration=180) # ZeroGPU entrypoint
103
+ def generate_storyboard_with_llm(user_prompt: str, n_shots: int, default_fps: int, default_len: int):
104
+ """
105
+ Chat-format prompt -> deterministic generation -> robust JSON parse.
106
+ """
107
+ model, tok = _lazy_model_tok()
108
+
109
+ system = (
110
+ "You are a film previsualization assistant. "
111
+ "Return ONLY JSON inside <JSON>...</JSON>. No extra text."
112
+ )
113
+ user = _storyboard_prompt(user_prompt, n_shots, default_fps, default_len)
114
+
115
+ # Use chat template if available for the model
116
+ if hasattr(tok, "apply_chat_template"):
117
+ prompt_text = tok.apply_chat_template(
118
+ [{"role": "system", "content": system},
119
+ {"role": "user", "content": user}],
120
+ tokenize=False,
121
+ add_generation_prompt=True
122
+ )
123
+ else:
124
+ prompt_text = system + "\n\n" + user
125
+
126
+ inputs = tok(prompt_text, return_tensors="pt")
127
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
128
+
129
+ eos_id = tok.eos_token_id
130
+ gen = model.generate(
131
+ **inputs,
132
  max_new_tokens=HF_TASK_MAX_TOKENS,
133
+ do_sample=False,
134
  temperature=0.0,
135
  repetition_penalty=1.05,
136
+ eos_token_id=eos_id,
137
+ pad_token_id=eos_id,
138
  )
 
139
 
140
+ out_text = tok.decode(gen[0], skip_special_tokens=True)
141
+ # Trim the echoed prompt if present
142
+ if out_text.startswith(prompt_text):
143
+ out_text = out_text[len(prompt_text):]
144
+
145
+ json_text = _extract_json_array(out_text)
146
+ shots_raw = json.loads(json_text)
147
+
148
+ # Normalize fields
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  norm = []
150
+ for i, s in enumerate(shots_raw, start=1):
151
  norm.append({
152
  "id": int(s.get("id", i)),
153
  "title": s.get("title", f"Shot {i}"),
 
296
  )
297
 
298
  if __name__ == "__main__":
 
299
  demo.launch()