eeshaAI commited on
Commit
5ab6307
Β·
verified Β·
1 Parent(s): 893985c

Update app.py: add video generation tab

Browse files
Files changed (1) hide show
  1. app.py +326 -24
app.py CHANGED
@@ -1,17 +1,25 @@
1
  #!/usr/bin/env python3
2
  """
3
- Gradio App for EeshaAI/Zeeb Training Space
4
- ==========================================
5
- Auto-starts LoRA fine-tuning on Space boot.
6
- The UI shows real-time training progress from the log file.
7
  """
8
 
9
  import os
10
  import time
 
11
  import threading
 
12
  import gradio as gr
13
 
14
  LOG_FILE = "/tmp/training_log.txt"
 
 
 
 
 
 
15
 
16
 
17
  def start_training_background():
@@ -30,43 +38,337 @@ def get_log():
30
 
31
 
32
  def refresh_log():
33
- """Refresh button callback."""
34
  return get_log()
35
 
36
 
37
- # Auto-start training on Space boot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  training_thread = threading.Thread(target=start_training_background, daemon=True)
39
  training_thread.start()
40
 
41
 
 
42
  with gr.Blocks(
43
- title="Zeeb β€” Video-LLM Trainer",
44
  theme=gr.themes.Soft(),
45
  ) as demo:
46
 
47
  gr.Markdown(
48
  """
49
- # 🎬 Zeeb β€” Video-LLM Trainer
50
- Fine-tuning **OLMo 2 1B Instruct** with **LoRA (r=4)** to generate video tokens.
51
- Trained model will be pushed to [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb).
52
-
53
- Training **starts automatically** when this Space boots.
54
- Click **Refresh Log** to see progress.
55
  """
56
  )
57
 
58
- refresh_btn = gr.Button("πŸ”„ Refresh Log", variant="primary")
59
-
60
- logbox = gr.Textbox(
61
- label="Training Log",
62
- value=lambda: get_log(),
63
- lines=30,
64
- max_lines=200,
65
- interactive=False,
66
- show_copy_button=True,
67
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- refresh_btn.click(fn=refresh_log, outputs=logbox)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ Gradio App for EeshaAI/Zeeb β€” Training + Video Generation
4
+ ==========================================================
5
+ Tab 1: Training (auto-starts on boot)
6
+ Tab 2: Generate Video (loads trained model + VQ-VAE, generates video from prompt)
7
  """
8
 
9
  import os
10
  import time
11
+ import re
12
  import threading
13
+ import numpy as np
14
  import gradio as gr
15
 
16
  LOG_FILE = "/tmp/training_log.txt"
17
+ GENERATE_LOG = "/tmp/generation_log.txt"
18
+
19
+ # Global model cache
20
+ _model = None
21
+ _tokenizer = None
22
+ _vq_vae = None
23
 
24
 
25
  def start_training_background():
 
38
 
39
 
40
  def refresh_log():
 
41
  return get_log()
42
 
43
 
44
+ def load_models():
45
+ """Load the trained LLM and VQ-VAE decoder (lazy, cached)."""
46
+ global _model, _tokenizer, _vq_vae
47
+
48
+ if _model is not None and _tokenizer is not None:
49
+ return _model, _tokenizer, _vq_vae
50
+
51
+ import torch
52
+
53
+ # ── Load VQ-VAE decoder ─────────────────────────────────────────────
54
+ vq_vae_path = "vq_vae_final.pt"
55
+ if os.path.exists(vq_vae_path):
56
+ import torch.nn as nn
57
+
58
+ class VQVAEDecoderOnly(nn.Module):
59
+ """Minimal VQ-VAE decoder for token β†’ pixel decoding."""
60
+ def __init__(self, codebook_size=1024, codebook_dim=256, latent_dim=256):
61
+ super().__init__()
62
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
63
+ self.proj = nn.Linear(codebook_dim, latent_dim)
64
+ # Decoder: upscale from 8x8 spatial to 64x64
65
+ self.decoder = nn.Sequential(
66
+ nn.ConvTranspose2d(latent_dim, 128, 4, stride=2, padding=1), # 8β†’16
67
+ nn.ReLU(),
68
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # 16β†’32
69
+ nn.ReLU(),
70
+ nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), # 32β†’64
71
+ nn.ReLU(),
72
+ nn.Conv2d(32, 3, 3, padding=1),
73
+ nn.Sigmoid(),
74
+ )
75
+
76
+ def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
77
+ """Decode a flat list of token IDs into a video frame."""
78
+ # token_ids: list of ints, length should be grid_h * grid_w
79
+ tokens = torch.tensor(token_ids[:grid_h * grid_w], dtype=torch.long)
80
+ if len(tokens) < grid_h * grid_w:
81
+ tokens = torch.cat([tokens, torch.zeros(grid_h * grid_w - len(tokens), dtype=torch.long)])
82
+
83
+ # Lookup codebook
84
+ z = self.codebook(tokens) # [H*W, D]
85
+ z = self.proj(z) # [H*W, latent_dim]
86
+ z = z.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2) # [1, C, H, W]
87
+
88
+ # Decode
89
+ frame = self.decoder(z) # [1, 3, 64, 64]
90
+ return frame
91
+
92
+ _vq_vae = VQVAEDecoderOnly()
93
+ state = torch.load(vq_vae_path, map_location="cpu", weights_only=False)
94
+ # Try to load relevant weights
95
+ if isinstance(state, dict):
96
+ if "codebook" in state or "state_dict" in state:
97
+ # Full checkpoint
98
+ sd = state.get("state_dict", state)
99
+ filtered = {k: v for k, v in sd.items() if not k.startswith("encoder")}
100
+ _vq_vae.load_state_dict(filtered, strict=False)
101
+ elif "model_state_dict" in state:
102
+ _vq_vae.load_state_dict(state["model_state_dict"], strict=False)
103
+ else:
104
+ _vq_vae.load_state_dict(state, strict=False)
105
+ print("βœ… VQ-VAE decoder loaded")
106
+
107
+ # ── Load trained LLM ────────────────────────────────────────────────
108
+ from transformers import AutoModelForCausalLM, AutoTokenizer
109
+
110
+ REPO_ID = "eeshaAI/zeeb"
111
+
112
+ print("πŸ“¦ Loading trained model from EeshaAI/zeeb...")
113
+ _tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
114
+ if _tokenizer.pad_token is None:
115
+ _tokenizer.pad_token = _tokenizer.eos_token
116
+
117
+ _model = AutoModelForCausalLM.from_pretrained(
118
+ REPO_ID,
119
+ trust_remote_code=True,
120
+ torch_dtype=torch.float32,
121
+ )
122
+ _model.eval()
123
+ print(f"βœ… Model loaded. Vocab size: {len(_tokenizer)}")
124
+
125
+ return _model, _tokenizer, _vq_vae
126
+
127
+
128
+ def generate_video(prompt: str, max_tokens: int = 128):
129
+ """Generate video from a text prompt using the trained LLM + VQ-VAE."""
130
+ import torch
131
+
132
+ log_lines = []
133
+ log_lines.append(f"🎬 Generating video for: '{prompt}'\n\n")
134
+
135
+ try:
136
+ # Load models
137
+ log_lines.append("πŸ“¦ Loading trained model + VQ-VAE...\n")
138
+ model, tokenizer, vq_vae = load_models()
139
+ log_lines.append("βœ… Models loaded.\n\n")
140
+ except Exception as e:
141
+ log_lines.append(f"❌ Failed to load models: {e}\n")
142
+ return None, "\n".join(log_lines)
143
+
144
+ # ── Format prompt ──────────────────────────────────────────────────
145
+ text = f"Create a video of: {prompt} <video_start>"
146
+ log_lines.append(f"πŸ“ Prompt formatted:\n {text}\n\n")
147
+
148
+ # ── Generate tokens ────────────────────────────────────────────────
149
+ log_lines.append("πŸ”₯ Generating visual tokens...\n")
150
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
151
+
152
+ with torch.no_grad():
153
+ output_ids = model.generate(
154
+ **inputs,
155
+ max_new_tokens=max_tokens,
156
+ do_sample=True,
157
+ temperature=0.8,
158
+ top_p=0.9,
159
+ pad_token_id=tokenizer.pad_token_id,
160
+ )
161
+
162
+ # Decode the full output
163
+ full_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
164
+ log_lines.append(f"πŸ“€ Raw output length: {len(full_text)} chars\n")
165
+
166
+ # Extract visual tokens between <video_start> and <video_end>
167
+ visual_token_ids = []
168
+ in_video = False
169
+
170
+ for token_id in output_ids[0].tolist():
171
+ decoded = tokenizer.decode([token_id])
172
+ if "<video_start>" in decoded:
173
+ in_video = True
174
+ continue
175
+ if "<video_end>" in decoded:
176
+ in_video = False
177
+ break
178
+ if in_video:
179
+ # Check if it's a <v_N> token
180
+ match = re.match(r"<v_(\d+)>", decoded.strip())
181
+ if match:
182
+ visual_token_ids.append(int(match.group(1)))
183
+
184
+ log_lines.append(f"🎨 Extracted {len(visual_token_ids)} visual tokens\n")
185
+
186
+ if not visual_token_ids:
187
+ log_lines.append("⚠️ No visual tokens generated! The model may need more training.\n")
188
+ log_lines.append(f"\nFull output:\n{full_text}\n")
189
+ # Try alternative: parse from full_text
190
+ all_v_tokens = re.findall(r"<v_(\d+)>", full_text)
191
+ if all_v_tokens:
192
+ visual_token_ids = [int(t) for t in all_v_tokens]
193
+ log_lines.append(f"\nπŸ”„ Alternative extraction found {len(visual_token_ids)} tokens\n")
194
+ else:
195
+ return None, "\n".join(log_lines)
196
+
197
+ # Show sample of tokens
198
+ sample_tokens = visual_token_ids[:20]
199
+ log_lines.append(f" Sample tokens: {sample_tokens}\n")
200
+ log_lines.append(f" Unique tokens: {len(set(visual_token_ids))}\n\n")
201
+
202
+ # ── Decode to video frames ──────────────────────────────────────────
203
+ log_lines.append("🎞️ Decoding tokens β†’ video frames via VQ-VAE...\n")
204
+
205
+ grid_h, grid_w = 8, 8
206
+ tokens_per_frame = grid_h * grid_w # 64
207
+ num_frames = max(1, len(visual_token_ids) // tokens_per_frame)
208
+ log_lines.append(f" Grid: {grid_h}Γ—{grid_w} = {tokens_per_frame} tokens/frame\n")
209
+ log_lines.append(f" Frames: {num_frames}\n\n")
210
+
211
+ frames = []
212
+
213
+ if vq_vae is not None:
214
+ for frame_idx in range(num_frames):
215
+ start = frame_idx * tokens_per_frame
216
+ end = start + tokens_per_frame
217
+ frame_tokens = visual_token_ids[start:end]
218
+
219
+ try:
220
+ frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
221
+ # Convert to numpy: [1, 3, 64, 64] β†’ [64, 64, 3] uint8
222
+ frame_np = (frame_tensor[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
223
+ frames.append(frame_np)
224
+ except Exception as e:
225
+ log_lines.append(f" ⚠️ Frame {frame_idx} decode error: {e}\n")
226
+ # Fallback: create frame from token values as colors
227
+ frame_np = np.zeros((64, 64, 3), dtype=np.uint8)
228
+ for i, t in enumerate(frame_tokens[:tokens_per_frame]):
229
+ row, col = divmod(i, grid_w)
230
+ cell_h, cell_w = 64 // grid_h, 64 // grid_w
231
+ if row < grid_h and col < grid_w:
232
+ # Use token value as a color
233
+ r = (t * 37) % 256
234
+ g = (t * 73) % 256
235
+ b = (t * 113) % 256
236
+ frame_np[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w] = [r, g, b]
237
+ frames.append(frame_np)
238
+ else:
239
+ # No VQ-VAE: create frames from token values as colored blocks
240
+ log_lines.append(" ⚠️ No VQ-VAE, using tokenβ†’color mapping\n")
241
+ for frame_idx in range(num_frames):
242
+ start = frame_idx * tokens_per_frame
243
+ end = start + tokens_per_frame
244
+ frame_tokens = visual_token_ids[start:end]
245
+ frame_np = np.zeros((64, 64, 3), dtype=np.uint8)
246
+ for i, t in enumerate(frame_tokens[:tokens_per_frame]):
247
+ row, col = divmod(i, grid_w)
248
+ cell_h, cell_w = 64 // grid_h, 64 // grid_w
249
+ if row < grid_h and col < grid_w:
250
+ r = (t * 37) % 256
251
+ g = (t * 73) % 256
252
+ b = (t * 113) % 256
253
+ frame_np[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w] = [r, g, b]
254
+ frames.append(frame_np)
255
+
256
+ if not frames:
257
+ log_lines.append("❌ No frames generated!\n")
258
+ return None, "\n".join(log_lines)
259
+
260
+ # ── Save as video ──────────────────────────────────────────────────
261
+ log_lines.append(f"πŸ’Ύ Saving {len(frames)} frames as video...\n")
262
+
263
+ try:
264
+ import imageio
265
+ output_path = "/tmp/generated_video.mp4"
266
+ # Upscale frames from 64x64 to 256x256 for better visibility
267
+ from PIL import Image
268
+ upscaled = []
269
+ for f in frames:
270
+ img = Image.fromarray(f)
271
+ img = img.resize((256, 256), Image.NEAREST)
272
+ upscaled.append(np.array(img))
273
+
274
+ # Save as mp4 (2 fps for slow playback since we have few frames)
275
+ imageio.mimsave(output_path, upscaled, fps=2)
276
+ log_lines.append(f"βœ… Video saved to {output_path}\n")
277
+ log_lines.append(f" Resolution: 256Γ—256\n")
278
+ log_lines.append(f" Frames: {len(upscaled)}\n")
279
+ log_lines.append(f" FPS: 2\n\n")
280
+ log_lines.append("πŸŽ‰ Video generation complete!\n")
281
+ return output_path, "\n".join(log_lines)
282
+ except ImportError:
283
+ # Fallback: save as GIF
284
+ try:
285
+ from PIL import Image
286
+ output_path = "/tmp/generated_video.gif"
287
+ pil_frames = [Image.fromarray(f).resize((256, 256), Image.NEAREST) for f in frames]
288
+ pil_frames[0].save(
289
+ output_path,
290
+ save_all=True,
291
+ append_images=pil_frames[1:],
292
+ duration=500,
293
+ loop=0,
294
+ )
295
+ log_lines.append(f"βœ… GIF saved to {output_path}\n")
296
+ return output_path, "\n".join(log_lines)
297
+ except Exception as e:
298
+ log_lines.append(f"❌ Failed to save video: {e}\n")
299
+ # Return first frame as image at least
300
+ img_path = "/tmp/generated_frame.png"
301
+ Image.fromarray(frames[0]).resize((256, 256), Image.NEAREST).save(img_path)
302
+ log_lines.append(f"πŸ“Έ Saved single frame to {img_path}\n")
303
+ return img_path, "\n".join(log_lines)
304
+ except Exception as e:
305
+ log_lines.append(f"❌ Video save error: {e}\n")
306
+ return None, "\n".join(log_lines)
307
+
308
+
309
+ # ── Auto-start training on boot ────────────────────────────────────────
310
  training_thread = threading.Thread(target=start_training_background, daemon=True)
311
  training_thread.start()
312
 
313
 
314
+ # ── Gradio UI ──────────────────────────────────────────────────────────
315
  with gr.Blocks(
316
+ title="Zeeb β€” Video-LLM",
317
  theme=gr.themes.Soft(),
318
  ) as demo:
319
 
320
  gr.Markdown(
321
  """
322
+ # 🎬 Zeeb β€” Video-LLM
323
+ **OLMo 2 1B Instruct** fine-tuned with **LoRA** to generate video tokens.
324
+ Model repo: [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
 
 
 
325
  """
326
  )
327
 
328
+ with gr.Tabs():
329
+ # ── Tab 1: Generate Video ───────────────────────────────────────
330
+ with gr.Tab("🎬 Generate Video"):
331
+ prompt_input = gr.Textbox(
332
+ label="Video Description",
333
+ placeholder="A cat jumping on a sofa",
334
+ lines=2,
335
+ )
336
+ max_tokens_slider = gr.Slider(
337
+ minimum=32, maximum=256, value=128, step=32,
338
+ label="Max Visual Tokens",
339
+ )
340
+ generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
341
+ video_output = gr.Video(label="Generated Video")
342
+ gen_log = gr.Textbox(
343
+ label="Generation Log",
344
+ lines=20,
345
+ interactive=False,
346
+ show_copy_button=True,
347
+ )
348
+ generate_btn.click(
349
+ fn=generate_video,
350
+ inputs=[prompt_input, max_tokens_slider],
351
+ outputs=[video_output, gen_log],
352
+ )
353
 
354
+ # ── Tab 2: Training ─────────────────────────────────────────────
355
+ with gr.Tab("πŸ”§ Training"):
356
+ gr.Markdown(
357
+ """
358
+ Training **starts automatically** when this Space boots.
359
+ Click **Refresh Log** to see progress.
360
+ """
361
+ )
362
+ refresh_btn = gr.Button("πŸ”„ Refresh Log")
363
+ logbox = gr.Textbox(
364
+ label="Training Log",
365
+ value=lambda: get_log(),
366
+ lines=25,
367
+ max_lines=200,
368
+ interactive=False,
369
+ show_copy_button=True,
370
+ )
371
+ refresh_btn.click(fn=refresh_log, outputs=logbox)
372
 
373
 
374
  if __name__ == "__main__":