eeshaAI commited on
Commit
7eab64f
Β·
verified Β·
1 Parent(s): 83a2068

Add constrained decoding: force visual tokens after <video_start>

Browse files
Files changed (1) hide show
  1. app.py +86 -62
app.py CHANGED
@@ -4,6 +4,8 @@ Gradio App for EeshaAI/Zeeb β€” Video Generation
4
  ================================================
5
  Uses the trained OLMo 2 1B + LoRA model to generate video tokens,
6
  then decodes them via VQ-VAE into a video file.
 
 
7
  """
8
 
9
  import os
@@ -18,10 +20,17 @@ _tokenizer = None
18
  _vq_vae = None
19
  _loading_lock = threading.Lock()
20
 
 
 
 
 
 
 
21
 
22
  def load_models():
23
  """Load the trained LLM and VQ-VAE decoder (lazy, cached)."""
24
  global _model, _tokenizer, _vq_vae
 
25
 
26
  with _loading_lock:
27
  if _model is not None and _tokenizer is not None:
@@ -91,18 +100,27 @@ def load_models():
91
  _model.eval()
92
  print(f"βœ… Model loaded. Vocab size: {len(_tokenizer)}")
93
 
 
 
 
 
 
 
 
 
94
  return _model, _tokenizer, _vq_vae
95
 
96
 
97
  def generate_video(prompt: str, max_tokens: int = 128):
98
- """Generate video from a text prompt using the trained LLM + VQ-VAE."""
99
  import torch
 
100
 
101
  log_lines = []
102
  log_lines.append(f"🎬 Generating video for: '{prompt}'\n\n")
103
 
104
  try:
105
- log_lines.append("πŸ“¦ Loading trained model + VQ-VAE (first run takes ~3 min)...\n")
106
  model, tokenizer, vq_vae = load_models()
107
  log_lines.append("βœ… Models loaded.\n\n")
108
  except Exception as e:
@@ -115,60 +133,73 @@ def generate_video(prompt: str, max_tokens: int = 128):
115
  text = f"Create a video of: {prompt} <video_start>"
116
  log_lines.append(f"πŸ“ Prompt: {text}\n\n")
117
 
118
- # ── Generate tokens ────────────────────────────────────────────────
119
- log_lines.append("πŸ”₯ Generating visual tokens...\n")
 
 
 
120
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  with torch.no_grad():
123
- output_ids = model.generate(
124
- **inputs,
125
- max_new_tokens=max_tokens,
126
- do_sample=True,
127
- temperature=0.8,
128
- top_p=0.9,
129
- pad_token_id=tokenizer.pad_token_id,
130
- )
131
 
132
- # Decode the full output
133
- full_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
134
- log_lines.append(f"πŸ“€ Raw output length: {len(full_text)} chars\n")
135
 
136
- # Extract visual tokens between <video_start> and <video_end>
137
- visual_token_ids = []
138
- in_video = False
139
-
140
- for token_id in output_ids[0].tolist():
141
- decoded = tokenizer.decode([token_id])
142
- if "<video_start>" in decoded:
143
- in_video = True
144
- continue
145
- if "<video_end>" in decoded:
146
- in_video = False
147
- break
148
- if in_video:
149
- match = re.match(r"<v_(\d+)>", decoded.strip())
150
- if match:
151
- visual_token_ids.append(int(match.group(1)))
152
-
153
- log_lines.append(f"🎨 Extracted {len(visual_token_ids)} visual tokens\n")
 
 
 
 
 
154
 
155
  if not visual_token_ids:
156
- log_lines.append("⚠️ No visual tokens in structured format. Trying regex on full output...\n")
157
- all_v_tokens = re.findall(r"<v_(\d+)>", full_text)
158
- if all_v_tokens:
159
- visual_token_ids = [int(t) for t in all_v_tokens]
160
- log_lines.append(f"πŸ”„ Regex found {len(visual_token_ids)} tokens\n")
161
- else:
162
- log_lines.append("⚠️ No visual tokens at all. Showing raw output:\n")
163
- log_lines.append(f"\n{full_text[:1000]}\n")
164
- return None, "\n".join(log_lines)
165
 
166
  sample_tokens = visual_token_ids[:20]
167
  log_lines.append(f" Sample tokens: {sample_tokens}\n")
168
- log_lines.append(f" Unique tokens: {len(set(visual_token_ids))}\n\n")
 
169
 
170
  # ── Decode to video frames ──────────────────────────────────────────
171
- log_lines.append("🎞️ Decoding tokens β†’ video frames...\n")
172
 
173
  grid_h, grid_w = 8, 8
174
  tokens_per_frame = grid_h * grid_w # 64
@@ -178,26 +209,20 @@ def generate_video(prompt: str, max_tokens: int = 128):
178
 
179
  frames = []
180
 
181
- if vq_vae is not None:
182
- for frame_idx in range(num_frames):
183
- start = frame_idx * tokens_per_frame
184
- end = start + tokens_per_frame
185
- frame_tokens = visual_token_ids[start:end]
 
186
  try:
187
  frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
188
  frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
189
  frames.append(frame_np)
190
  except Exception as e:
191
- log_lines.append(f" ⚠️ Frame {frame_idx} decode error: {e}\n")
192
- # Fallback: color blocks
193
- frame_np = _tokens_to_color_blocks(frame_tokens, grid_h, grid_w)
194
- frames.append(frame_np)
195
- else:
196
- log_lines.append(" ⚠️ No VQ-VAE, using tokenβ†’color mapping\n")
197
- for frame_idx in range(num_frames):
198
- start = frame_idx * tokens_per_frame
199
- end = start + tokens_per_frame
200
- frame_tokens = visual_token_ids[start:end]
201
  frames.append(_tokens_to_color_blocks(frame_tokens, grid_h, grid_w))
202
 
203
  if not frames:
@@ -224,7 +249,6 @@ def generate_video(prompt: str, max_tokens: int = 128):
224
  imageio.mimsave(output_path, upscaled, fps=2)
225
  log_lines.append(f"βœ… Video saved as MP4: {output_path}\n")
226
  except Exception:
227
- # Fallback to GIF
228
  output_path = "/tmp/generated_video.gif"
229
  pil_frames = [Image.fromarray(f) for f in upscaled]
230
  pil_frames[0].save(
@@ -285,7 +309,7 @@ with gr.Blocks(
285
  **OLMo 2 1B Instruct** fine-tuned with **LoRA (r=4)** to generate video tokens.
286
  Model: [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
287
 
288
- Type a description and click Generate!
289
  """
290
  )
291
 
@@ -303,7 +327,7 @@ with gr.Blocks(
303
  video_output = gr.Video(label="Generated Video")
304
  gen_log = gr.Textbox(
305
  label="Generation Log",
306
- lines=20,
307
  interactive=False,
308
  show_copy_button=True,
309
  )
 
4
  ================================================
5
  Uses the trained OLMo 2 1B + LoRA model to generate video tokens,
6
  then decodes them via VQ-VAE into a video file.
7
+
8
+ Uses constrained decoding: after <video_start>, only <v_N> tokens are allowed.
9
  """
10
 
11
  import os
 
20
  _vq_vae = None
21
  _loading_lock = threading.Lock()
22
 
23
+ # Visual token ID range (from tokenizer: <v_0>=100281, <v_1023>=101304)
24
+ VIDEO_START_ID = None # Will be set after tokenizer loads
25
+ VIDEO_END_ID = None
26
+ V_TOKEN_START_ID = None
27
+ V_TOKEN_END_ID = None
28
+
29
 
30
  def load_models():
31
  """Load the trained LLM and VQ-VAE decoder (lazy, cached)."""
32
  global _model, _tokenizer, _vq_vae
33
+ global VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID
34
 
35
  with _loading_lock:
36
  if _model is not None and _tokenizer is not None:
 
100
  _model.eval()
101
  print(f"βœ… Model loaded. Vocab size: {len(_tokenizer)}")
102
 
103
+ # Set visual token ID ranges
104
+ VIDEO_START_ID = _tokenizer.convert_tokens_to_ids("<video_start>")
105
+ VIDEO_END_ID = _tokenizer.convert_tokens_to_ids("<video_end>")
106
+ V_TOKEN_START_ID = _tokenizer.convert_tokens_to_ids("<v_0>")
107
+ V_TOKEN_END_ID = _tokenizer.convert_tokens_to_ids("<v_1023>")
108
+ print(f" <video_start>={VIDEO_START_ID}, <video_end>={VIDEO_END_ID}")
109
+ print(f" <v_0>={V_TOKEN_START_ID}, <v_1023>={V_TOKEN_END_ID}")
110
+
111
  return _model, _tokenizer, _vq_vae
112
 
113
 
114
  def generate_video(prompt: str, max_tokens: int = 128):
115
+ """Generate video from a text prompt using constrained decoding + VQ-VAE."""
116
  import torch
117
+ import torch.nn.functional as F
118
 
119
  log_lines = []
120
  log_lines.append(f"🎬 Generating video for: '{prompt}'\n\n")
121
 
122
  try:
123
+ log_lines.append("πŸ“¦ Loading trained model + VQ-VAE...\n")
124
  model, tokenizer, vq_vae = load_models()
125
  log_lines.append("βœ… Models loaded.\n\n")
126
  except Exception as e:
 
133
  text = f"Create a video of: {prompt} <video_start>"
134
  log_lines.append(f"πŸ“ Prompt: {text}\n\n")
135
 
136
+ # ── Constrained token generation ────────────────────────────────────
137
+ # After <video_start>, we FORCE the model to only pick from <v_0>...<v_1023>
138
+ # This is done by masking the logits at each step
139
+ log_lines.append("πŸ”₯ Generating visual tokens (constrained decoding)...\n")
140
+
141
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
142
+ input_ids = inputs["input_ids"]
143
+
144
+ visual_token_ids = []
145
+ current_ids = input_ids.clone()
146
+
147
+ # Create a mask that only allows visual token IDs
148
+ vocab_size = len(tokenizer)
149
+ visual_mask = torch.zeros(vocab_size, dtype=torch.bool)
150
+ visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
151
+ # Also allow <video_end> so the model can stop
152
+ visual_mask[VIDEO_END_ID] = True
153
 
154
  with torch.no_grad():
155
+ for step in range(max_tokens):
156
+ # Forward pass
157
+ outputs = model(input_ids=current_ids)
158
+ next_token_logits = outputs.logits[:, -1, :] # [1, vocab_size]
 
 
 
 
159
 
160
+ # Apply constraint: only allow visual tokens + <video_end>
161
+ masked_logits = next_token_logits.clone()
162
+ masked_logits[0, ~visual_mask] = float('-inf')
163
 
164
+ # Sample from the constrained distribution
165
+ probs = F.softmax(masked_logits / 0.8, dim=-1) # temperature=0.8
166
+
167
+ # Check if <video_end> has high probability
168
+ end_prob = probs[0, VIDEO_END_ID].item()
169
+
170
+ # Sample
171
+ next_token = torch.multinomial(probs, num_samples=1) # [1, 1]
172
+ next_id = next_token.item()
173
+
174
+ # If the model chose <video_end>, stop
175
+ if next_id == VIDEO_END_ID:
176
+ log_lines.append(f" Model chose <video_end> at step {step} (end_prob={end_prob:.4f})\n")
177
+ break
178
+
179
+ # Convert token ID to visual token index
180
+ visual_idx = next_id - V_TOKEN_START_ID
181
+ visual_token_ids.append(visual_idx)
182
+
183
+ # Append to sequence
184
+ current_ids = torch.cat([current_ids, next_token], dim=-1)
185
+
186
+ log_lines.append(f"🎨 Generated {len(visual_token_ids)} visual tokens\n")
187
 
188
  if not visual_token_ids:
189
+ log_lines.append("⚠️ No visual tokens generated even with constrained decoding.\n")
190
+ log_lines.append(" Falling back to random token sampling from VQ-VAE codebook.\n")
191
+ # Fallback: generate random visual tokens
192
+ import random
193
+ visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
194
+ log_lines.append(f" Generated {len(visual_token_ids)} random tokens as fallback\n")
 
 
 
195
 
196
  sample_tokens = visual_token_ids[:20]
197
  log_lines.append(f" Sample tokens: {sample_tokens}\n")
198
+ unique = len(set(visual_token_ids))
199
+ log_lines.append(f" Unique tokens: {unique} / {len(visual_token_ids)}\n\n")
200
 
201
  # ── Decode to video frames ──────────────────────────────────────────
202
+ log_lines.append("🎞️ Decoding tokens β†’ video frames via VQ-VAE...\n")
203
 
204
  grid_h, grid_w = 8, 8
205
  tokens_per_frame = grid_h * grid_w # 64
 
209
 
210
  frames = []
211
 
212
+ for frame_idx in range(num_frames):
213
+ start_idx = frame_idx * tokens_per_frame
214
+ end_idx = start_idx + tokens_per_frame
215
+ frame_tokens = visual_token_ids[start_idx:end_idx]
216
+
217
+ if vq_vae is not None:
218
  try:
219
  frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
220
  frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
221
  frames.append(frame_np)
222
  except Exception as e:
223
+ log_lines.append(f" ⚠️ Frame {frame_idx} VQ-VAE error: {e}, using color blocks\n")
224
+ frames.append(_tokens_to_color_blocks(frame_tokens, grid_h, grid_w))
225
+ else:
 
 
 
 
 
 
 
226
  frames.append(_tokens_to_color_blocks(frame_tokens, grid_h, grid_w))
227
 
228
  if not frames:
 
249
  imageio.mimsave(output_path, upscaled, fps=2)
250
  log_lines.append(f"βœ… Video saved as MP4: {output_path}\n")
251
  except Exception:
 
252
  output_path = "/tmp/generated_video.gif"
253
  pil_frames = [Image.fromarray(f) for f in upscaled]
254
  pil_frames[0].save(
 
309
  **OLMo 2 1B Instruct** fine-tuned with **LoRA (r=4)** to generate video tokens.
310
  Model: [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
311
 
312
+ Uses **constrained decoding** β€” after `<video_start>`, only visual tokens are allowed.
313
  """
314
  )
315
 
 
327
  video_output = gr.Video(label="Generated Video")
328
  gen_log = gr.Textbox(
329
  label="Generation Log",
330
+ lines=25,
331
  interactive=False,
332
  show_copy_button=True,
333
  )