eeshaAI commited on
Commit
6e8dde1
Β·
verified Β·
1 Parent(s): 7eab64f

Update app.py: full training pipeline with real datasets

Browse files
Files changed (1) hide show
  1. app.py +151 -218
app.py CHANGED
@@ -1,11 +1,9 @@
1
  #!/usr/bin/env python3
2
  """
3
- Gradio App for EeshaAI/Zeeb β€” Video Generation
4
- ================================================
5
- Uses the trained OLMo 2 1B + LoRA model to generate video tokens,
6
- then decodes them via VQ-VAE into a video file.
7
-
8
- Uses constrained decoding: after <video_start>, only <v_N> tokens are allowed.
9
  """
10
 
11
  import os
@@ -14,14 +12,16 @@ import threading
14
  import numpy as np
15
  import gradio as gr
16
 
 
 
17
  # Global model cache
18
  _model = None
19
  _tokenizer = None
20
  _vq_vae = None
21
  _loading_lock = threading.Lock()
22
 
23
- # Visual token ID range (from tokenizer: <v_0>=100281, <v_1023>=101304)
24
- VIDEO_START_ID = None # Will be set after tokenizer loads
25
  VIDEO_END_ID = None
26
  V_TOKEN_START_ID = None
27
  V_TOKEN_END_ID = None
@@ -37,53 +37,49 @@ def load_models():
37
  return _model, _tokenizer, _vq_vae
38
 
39
  import torch
40
-
41
- # ── Load VQ-VAE decoder ─────────────────────────────────────────
42
- vq_vae_path = "vq_vae_final.pt"
43
- if os.path.exists(vq_vae_path):
44
- import torch.nn as nn
45
-
46
- class VQVAEDecoderOnly(nn.Module):
47
- """Minimal VQ-VAE decoder for token β†’ pixel decoding."""
48
- def __init__(self, codebook_size=1024, codebook_dim=256, latent_dim=256):
49
- super().__init__()
50
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
51
- self.proj = nn.Linear(codebook_dim, latent_dim)
52
- self.decoder = nn.Sequential(
53
- nn.ConvTranspose2d(latent_dim, 128, 4, stride=2, padding=1),
54
- nn.ReLU(),
55
- nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
56
- nn.ReLU(),
57
- nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
58
- nn.ReLU(),
59
- nn.Conv2d(32, 3, 3, padding=1),
60
- nn.Sigmoid(),
61
- )
62
-
63
- def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
64
- tokens = torch.tensor(token_ids[:grid_h * grid_w], dtype=torch.long)
65
- if len(tokens) < grid_h * grid_w:
66
- tokens = torch.cat([tokens, torch.zeros(grid_h * grid_w - len(tokens), dtype=torch.long)])
67
- z = self.codebook(tokens)
68
- z = self.proj(z)
69
- z = z.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
70
- frame = self.decoder(z)
71
- return frame
72
-
73
- _vq_vae = VQVAEDecoderOnly()
74
- state = torch.load(vq_vae_path, map_location="cpu", weights_only=False)
75
- if isinstance(state, dict):
76
- if "state_dict" in state:
77
- sd = state["state_dict"]
78
- elif "model_state_dict" in state:
79
- sd = state["model_state_dict"]
80
- else:
81
- sd = state
82
  filtered = {k: v for k, v in sd.items() if not k.startswith("encoder")}
83
  _vq_vae.load_state_dict(filtered, strict=False)
84
- print("βœ… VQ-VAE decoder loaded")
 
 
 
 
 
 
85
 
86
- # ── Load trained LLM ────────────────────────────────────────────
87
  from transformers import AutoModelForCausalLM, AutoTokenizer
88
 
89
  REPO_ID = "eeshaAI/zeeb"
@@ -93,250 +89,187 @@ def load_models():
93
  _tokenizer.pad_token = _tokenizer.eos_token
94
 
95
  _model = AutoModelForCausalLM.from_pretrained(
96
- REPO_ID,
97
- trust_remote_code=True,
98
- torch_dtype=torch.float32,
99
  )
100
  _model.eval()
101
- print(f"βœ… Model loaded. Vocab size: {len(_tokenizer)}")
102
 
103
- # Set visual token ID ranges
104
  VIDEO_START_ID = _tokenizer.convert_tokens_to_ids("<video_start>")
105
  VIDEO_END_ID = _tokenizer.convert_tokens_to_ids("<video_end>")
106
  V_TOKEN_START_ID = _tokenizer.convert_tokens_to_ids("<v_0>")
107
  V_TOKEN_END_ID = _tokenizer.convert_tokens_to_ids("<v_1023>")
108
- print(f" <video_start>={VIDEO_START_ID}, <video_end>={VIDEO_END_ID}")
109
- print(f" <v_0>={V_TOKEN_START_ID}, <v_1023>={V_TOKEN_END_ID}")
110
 
111
  return _model, _tokenizer, _vq_vae
112
 
113
 
114
- def generate_video(prompt: str, max_tokens: int = 128):
115
  """Generate video from a text prompt using constrained decoding + VQ-VAE."""
116
  import torch
117
  import torch.nn.functional as F
118
 
119
- log_lines = []
120
- log_lines.append(f"🎬 Generating video for: '{prompt}'\n\n")
121
 
122
  try:
123
- log_lines.append("πŸ“¦ Loading trained model + VQ-VAE...\n")
124
  model, tokenizer, vq_vae = load_models()
125
- log_lines.append("βœ… Models loaded.\n\n")
126
  except Exception as e:
127
- import traceback
128
- log_lines.append(f"❌ Failed to load models: {e}\n")
129
- log_lines.append(traceback.format_exc())
130
- return None, "\n".join(log_lines)
131
 
132
- # ── Format prompt ──────────────────────────────────────────────────
133
  text = f"Create a video of: {prompt} <video_start>"
134
- log_lines.append(f"πŸ“ Prompt: {text}\n\n")
135
-
136
- # ── Constrained token generation ────────────────────────────────────
137
- # After <video_start>, we FORCE the model to only pick from <v_0>...<v_1023>
138
- # This is done by masking the logits at each step
139
- log_lines.append("πŸ”₯ Generating visual tokens (constrained decoding)...\n")
140
 
141
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
142
- input_ids = inputs["input_ids"]
143
-
144
- visual_token_ids = []
145
- current_ids = input_ids.clone()
146
 
147
- # Create a mask that only allows visual token IDs
148
  vocab_size = len(tokenizer)
149
  visual_mask = torch.zeros(vocab_size, dtype=torch.bool)
150
  visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
151
- # Also allow <video_end> so the model can stop
152
  visual_mask[VIDEO_END_ID] = True
153
 
 
 
154
  with torch.no_grad():
155
  for step in range(max_tokens):
156
- # Forward pass
157
  outputs = model(input_ids=current_ids)
158
- next_token_logits = outputs.logits[:, -1, :] # [1, vocab_size]
159
-
160
- # Apply constraint: only allow visual tokens + <video_end>
161
- masked_logits = next_token_logits.clone()
162
- masked_logits[0, ~visual_mask] = float('-inf')
163
-
164
- # Sample from the constrained distribution
165
- probs = F.softmax(masked_logits / 0.8, dim=-1) # temperature=0.8
166
-
167
- # Check if <video_end> has high probability
168
- end_prob = probs[0, VIDEO_END_ID].item()
169
-
170
- # Sample
171
- next_token = torch.multinomial(probs, num_samples=1) # [1, 1]
172
  next_id = next_token.item()
173
 
174
- # If the model chose <video_end>, stop
175
  if next_id == VIDEO_END_ID:
176
- log_lines.append(f" Model chose <video_end> at step {step} (end_prob={end_prob:.4f})\n")
177
  break
178
 
179
- # Convert token ID to visual token index
180
  visual_idx = next_id - V_TOKEN_START_ID
181
  visual_token_ids.append(visual_idx)
182
-
183
- # Append to sequence
184
  current_ids = torch.cat([current_ids, next_token], dim=-1)
185
 
186
- log_lines.append(f"🎨 Generated {len(visual_token_ids)} visual tokens\n")
187
 
188
  if not visual_token_ids:
189
- log_lines.append("⚠️ No visual tokens generated even with constrained decoding.\n")
190
- log_lines.append(" Falling back to random token sampling from VQ-VAE codebook.\n")
191
- # Fallback: generate random visual tokens
192
  import random
193
  visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
194
- log_lines.append(f" Generated {len(visual_token_ids)} random tokens as fallback\n")
195
 
196
- sample_tokens = visual_token_ids[:20]
197
- log_lines.append(f" Sample tokens: {sample_tokens}\n")
198
- unique = len(set(visual_token_ids))
199
- log_lines.append(f" Unique tokens: {unique} / {len(visual_token_ids)}\n\n")
200
-
201
- # ── Decode to video frames ──────────────────────────────────────────
202
- log_lines.append("🎞️ Decoding tokens β†’ video frames via VQ-VAE...\n")
203
 
 
 
204
  grid_h, grid_w = 8, 8
205
- tokens_per_frame = grid_h * grid_w # 64
206
  num_frames = max(1, len(visual_token_ids) // tokens_per_frame)
207
- log_lines.append(f" Grid: {grid_h}Γ—{grid_w} = {tokens_per_frame} tokens/frame\n")
208
- log_lines.append(f" Frames: {num_frames}\n\n")
209
 
210
  frames = []
211
-
212
- for frame_idx in range(num_frames):
213
- start_idx = frame_idx * tokens_per_frame
214
- end_idx = start_idx + tokens_per_frame
215
- frame_tokens = visual_token_ids[start_idx:end_idx]
216
-
217
- if vq_vae is not None:
218
- try:
219
- frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
220
- frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
221
- frames.append(frame_np)
222
- except Exception as e:
223
- log_lines.append(f" ⚠️ Frame {frame_idx} VQ-VAE error: {e}, using color blocks\n")
224
- frames.append(_tokens_to_color_blocks(frame_tokens, grid_h, grid_w))
225
- else:
226
- frames.append(_tokens_to_color_blocks(frame_tokens, grid_h, grid_w))
227
 
228
  if not frames:
229
- log_lines.append("❌ No frames generated!\n")
230
- return None, "\n".join(log_lines)
231
-
232
- # ── Save as video ──────────────────────────────────────────────────
233
- log_lines.append(f"πŸ’Ύ Saving {len(frames)} frames as video...\n")
234
 
 
235
  try:
236
  from PIL import Image
 
237
 
238
- # Upscale 64x64 β†’ 256x256
239
- upscaled = []
240
- for f in frames:
241
- img = Image.fromarray(f)
242
- img = img.resize((256, 256), Image.NEAREST)
243
- upscaled.append(np.array(img))
244
-
245
- # Try imageio for MP4
246
  try:
247
  import imageio
248
- output_path = "/tmp/generated_video.mp4"
249
- imageio.mimsave(output_path, upscaled, fps=2)
250
- log_lines.append(f"βœ… Video saved as MP4: {output_path}\n")
251
- except Exception:
252
- output_path = "/tmp/generated_video.gif"
253
- pil_frames = [Image.fromarray(f) for f in upscaled]
254
- pil_frames[0].save(
255
- output_path,
256
- save_all=True,
257
- append_images=pil_frames[1:],
258
- duration=500,
259
- loop=0,
260
- )
261
- log_lines.append(f"βœ… Video saved as GIF: {output_path}\n")
262
-
263
- log_lines.append(f" Resolution: 256Γ—256\n")
264
- log_lines.append(f" Frames: {len(upscaled)}\n")
265
- log_lines.append(f" FPS: 2\n\n")
266
- log_lines.append("πŸŽ‰ Video generation complete!\n")
267
- return output_path, "\n".join(log_lines)
268
  except Exception as e:
269
- import traceback
270
- log_lines.append(f"❌ Video save error: {e}\n")
271
- log_lines.append(traceback.format_exc())
272
- return None, "\n".join(log_lines)
273
 
274
 
275
- def _tokens_to_color_blocks(token_ids, grid_h=8, grid_w=8):
276
- """Convert token IDs to a color-block image as fallback."""
277
  frame = np.zeros((64, 64, 3), dtype=np.uint8)
278
- cell_h, cell_w = 64 // grid_h, 64 // grid_w
279
  for i, t in enumerate(token_ids[:grid_h * grid_w]):
280
- row, col = divmod(i, grid_w)
281
- r = (t * 37) % 256
282
- g = (t * 73) % 256
283
- b = (t * 113) % 256
284
- frame[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w] = [r, g, b]
285
  return frame
286
 
287
 
288
- # ── Preload models on boot in background ───────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  def preload():
290
  try:
291
  load_models()
292
- print("πŸš€ Models preloaded and ready!")
293
  except Exception as e:
294
  print(f"⚠️ Preload error: {e}")
295
 
296
- preload_thread = threading.Thread(target=preload, daemon=True)
297
- preload_thread.start()
298
 
299
 
300
  # ── Gradio UI ──────────────────────────────────────────────────────────
301
- with gr.Blocks(
302
- title="Zeeb β€” Video-LLM",
303
- theme=gr.themes.Soft(),
304
- ) as demo:
305
 
306
- gr.Markdown(
307
- """
308
  # 🎬 Zeeb β€” Video-LLM
309
- **OLMo 2 1B Instruct** fine-tuned with **LoRA (r=4)** to generate video tokens.
310
- Model: [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
311
-
312
- Uses **constrained decoding** β€” after `<video_start>`, only visual tokens are allowed.
313
- """
314
- )
315
-
316
- prompt_input = gr.Textbox(
317
- label="Video Description",
318
- placeholder="A cat jumping on a sofa",
319
- lines=2,
320
- value="A cat jumping on a sofa",
321
- )
322
- max_tokens_slider = gr.Slider(
323
- minimum=32, maximum=256, value=128, step=32,
324
- label="Max Visual Tokens to Generate",
325
- )
326
- generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
327
- video_output = gr.Video(label="Generated Video")
328
- gen_log = gr.Textbox(
329
- label="Generation Log",
330
- lines=25,
331
- interactive=False,
332
- show_copy_button=True,
333
- )
334
-
335
- generate_btn.click(
336
- fn=generate_video,
337
- inputs=[prompt_input, max_tokens_slider],
338
- outputs=[video_output, gen_log],
339
- )
340
 
341
 
342
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ Gradio App for EeshaAI/Zeeb β€” Video Generation + Training Pipeline
4
+ ===================================================================
5
+ Tab 1: Generate Video (uses trained model + VQ-VAE)
6
+ Tab 2: Run Full Pipeline (VQ-VAE training β†’ dataset tokenization β†’ LLM training β†’ push)
 
 
7
  """
8
 
9
  import os
 
12
  import numpy as np
13
  import gradio as gr
14
 
15
+ LOG_FILE = "/tmp/pipeline_log.txt"
16
+
17
  # Global model cache
18
  _model = None
19
  _tokenizer = None
20
  _vq_vae = None
21
  _loading_lock = threading.Lock()
22
 
23
+ # Visual token ID range
24
+ VIDEO_START_ID = None
25
  VIDEO_END_ID = None
26
  V_TOKEN_START_ID = None
27
  V_TOKEN_END_ID = None
 
37
  return _model, _tokenizer, _vq_vae
38
 
39
  import torch
40
+ import torch.nn as nn
41
+
42
+ # ── VQ-VAE decoder ─────────────────────────────────────────────
43
+ class VQVAEDecoderOnly(nn.Module):
44
+ def __init__(self, codebook_size=1024, codebook_dim=256, latent_dim=256):
45
+ super().__init__()
46
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
47
+ self.proj = nn.Linear(codebook_dim, latent_dim)
48
+ self.decoder = nn.Sequential(
49
+ nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), nn.ReLU(),
50
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
51
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
52
+ nn.Conv2d(64, 3, 3, padding=1), nn.Sigmoid(),
53
+ )
54
+
55
+ def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
56
+ tokens = torch.tensor(token_ids[:grid_h * grid_w], dtype=torch.long)
57
+ if len(tokens) < grid_h * grid_w:
58
+ tokens = torch.cat([tokens, torch.zeros(grid_h * grid_w - len(tokens), dtype=torch.long)])
59
+ z = self.codebook(tokens)
60
+ z = self.proj(z)
61
+ z = z.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
62
+ frame = self.decoder(z)
63
+ return frame
64
+
65
+ # Try loading from local file first, then from model repo
66
+ vq_vae_loaded = False
67
+ for vq_path in ["vq_vae_real.pt", "vq_vae_final.pt"]:
68
+ if os.path.exists(vq_path):
69
+ _vq_vae = VQVAEDecoderOnly()
70
+ state = torch.load(vq_path, map_location="cpu", weights_only=False)
71
+ sd = state.get("state_dict", state.get("model_state_dict", state)) if isinstance(state, dict) else state
 
 
 
 
 
 
 
 
 
 
72
  filtered = {k: v for k, v in sd.items() if not k.startswith("encoder")}
73
  _vq_vae.load_state_dict(filtered, strict=False)
74
+ vq_vae_loaded = True
75
+ print(f"βœ… VQ-VAE loaded from {vq_path}")
76
+ break
77
+
78
+ if not vq_vae_loaded:
79
+ _vq_vae = VQVAEDecoderOnly()
80
+ print("⚠️ Using untrained VQ-VAE (no checkpoint found)")
81
 
82
+ # ── LLM ─────────────────────────────────────────────────────────
83
  from transformers import AutoModelForCausalLM, AutoTokenizer
84
 
85
  REPO_ID = "eeshaAI/zeeb"
 
89
  _tokenizer.pad_token = _tokenizer.eos_token
90
 
91
  _model = AutoModelForCausalLM.from_pretrained(
92
+ REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
 
 
93
  )
94
  _model.eval()
 
95
 
 
96
  VIDEO_START_ID = _tokenizer.convert_tokens_to_ids("<video_start>")
97
  VIDEO_END_ID = _tokenizer.convert_tokens_to_ids("<video_end>")
98
  V_TOKEN_START_ID = _tokenizer.convert_tokens_to_ids("<v_0>")
99
  V_TOKEN_END_ID = _tokenizer.convert_tokens_to_ids("<v_1023>")
100
+ print(f"βœ… Model loaded. Vocab: {len(_tokenizer)}")
 
101
 
102
  return _model, _tokenizer, _vq_vae
103
 
104
 
105
+ def generate_video(prompt: str, max_tokens: int = 64):
106
  """Generate video from a text prompt using constrained decoding + VQ-VAE."""
107
  import torch
108
  import torch.nn.functional as F
109
 
110
+ log = [f"🎬 Generating video for: '{prompt}'\n\n"]
 
111
 
112
  try:
113
+ log.append("πŸ“¦ Loading models...\n")
114
  model, tokenizer, vq_vae = load_models()
115
+ log.append("βœ… Models loaded.\n\n")
116
  except Exception as e:
117
+ log.append(f"❌ Load error: {e}\n")
118
+ return None, "".join(log)
 
 
119
 
120
+ # Format prompt
121
  text = f"Create a video of: {prompt} <video_start>"
122
+ log.append(f"πŸ“ Prompt: {text}\n\n")
123
+ log.append("πŸ”₯ Generating visual tokens (constrained decoding)...\n")
 
 
 
 
124
 
125
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
126
+ current_ids = inputs["input_ids"].clone()
 
 
 
127
 
128
+ # Constrained decoding mask
129
  vocab_size = len(tokenizer)
130
  visual_mask = torch.zeros(vocab_size, dtype=torch.bool)
131
  visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
 
132
  visual_mask[VIDEO_END_ID] = True
133
 
134
+ visual_token_ids = []
135
+
136
  with torch.no_grad():
137
  for step in range(max_tokens):
 
138
  outputs = model(input_ids=current_ids)
139
+ logits = outputs.logits[:, -1, :]
140
+ masked = logits.clone()
141
+ masked[0, ~visual_mask] = float('-inf')
142
+ probs = F.softmax(masked / 0.8, dim=-1)
143
+ next_token = torch.multinomial(probs, num_samples=1)
 
 
 
 
 
 
 
 
 
144
  next_id = next_token.item()
145
 
 
146
  if next_id == VIDEO_END_ID:
 
147
  break
148
 
 
149
  visual_idx = next_id - V_TOKEN_START_ID
150
  visual_token_ids.append(visual_idx)
 
 
151
  current_ids = torch.cat([current_ids, next_token], dim=-1)
152
 
153
+ log.append(f"🎨 Generated {len(visual_token_ids)} visual tokens\n")
154
 
155
  if not visual_token_ids:
 
 
 
156
  import random
157
  visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
158
+ log.append("⚠️ Fallback: random tokens\n")
159
 
160
+ log.append(f" Sample: {visual_token_ids[:20]}\n")
161
+ log.append(f" Unique: {len(set(visual_token_ids))}\n\n")
 
 
 
 
 
162
 
163
+ # Decode frames
164
+ log.append("🎞️ Decoding tokens β†’ frames...\n")
165
  grid_h, grid_w = 8, 8
166
+ tokens_per_frame = grid_h * grid_w
167
  num_frames = max(1, len(visual_token_ids) // tokens_per_frame)
 
 
168
 
169
  frames = []
170
+ for fi in range(num_frames):
171
+ ft = visual_token_ids[fi*tokens_per_frame:(fi+1)*tokens_per_frame]
172
+ try:
173
+ frame_tensor = vq_vae.decode_tokens(ft, grid_h, grid_w)
174
+ frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
175
+ frames.append(frame_np)
176
+ except:
177
+ frames.append(_tokens_to_color(ft, grid_h, grid_w))
 
 
 
 
 
 
 
 
178
 
179
  if not frames:
180
+ return None, "".join(log)
 
 
 
 
181
 
182
+ # Save video
183
  try:
184
  from PIL import Image
185
+ upscaled = [np.array(Image.fromarray(f).resize((256, 256), Image.NEAREST)) for f in frames]
186
 
 
 
 
 
 
 
 
 
187
  try:
188
  import imageio
189
+ out = "/tmp/generated_video.mp4"
190
+ imageio.mimsave(out, upscaled, fps=2)
191
+ except:
192
+ out = "/tmp/generated_video.gif"
193
+ pils = [Image.fromarray(f) for f in upscaled]
194
+ pils[0].save(out, save_all=True, append_images=pils[1:], duration=500, loop=0)
195
+
196
+ log.append(f"βœ… Video saved ({len(upscaled)} frames, 256Γ—256)\n\nπŸŽ‰ Done!\n")
197
+ return out, "".join(log)
 
 
 
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
+ log.append(f"❌ Save error: {e}\n")
200
+ return None, "".join(log)
 
 
201
 
202
 
203
+ def _tokens_to_color(token_ids, grid_h=8, grid_w=8):
 
204
  frame = np.zeros((64, 64, 3), dtype=np.uint8)
205
+ ch, cw = 64 // grid_h, 64 // grid_w
206
  for i, t in enumerate(token_ids[:grid_h * grid_w]):
207
+ r, c = divmod(i, grid_w)
208
+ frame[r*ch:(r+1)*ch, c*cw:(c+1)*cw] = [(t*37)%256, (t*73)%256, (t*113)%256]
 
 
 
209
  return frame
210
 
211
 
212
+ def get_log():
213
+ try:
214
+ with open(LOG_FILE, "r") as f:
215
+ return f.read()
216
+ except:
217
+ return "No pipeline log yet."
218
+
219
+
220
+ def start_pipeline():
221
+ """Start the full training pipeline in background."""
222
+ from train_full_pipeline import run_pipeline
223
+ t = threading.Thread(target=run_pipeline, args=(LOG_FILE,), daemon=True)
224
+ t.start()
225
+ return "πŸš€ Pipeline started! Click Refresh to see progress."
226
+
227
+
228
+ # ── Preload generation models ───────────────────────────────────────────
229
  def preload():
230
  try:
231
  load_models()
232
+ print("πŸš€ Generation models preloaded!")
233
  except Exception as e:
234
  print(f"⚠️ Preload error: {e}")
235
 
236
+ threading.Thread(target=preload, daemon=True).start()
 
237
 
238
 
239
  # ── Gradio UI ──────────────────────────────────────────────────────────
240
+ with gr.Blocks(title="Zeeb β€” Video-LLM", theme=gr.themes.Soft()) as demo:
 
 
 
241
 
242
+ gr.Markdown("""
 
243
  # 🎬 Zeeb β€” Video-LLM
244
+ **OLMo 2 1B** + **LoRA** + **VQ-VAE** β†’ Text-to-Video generation.
245
+ [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
246
+ """)
247
+
248
+ with gr.Tabs():
249
+ with gr.Tab("🎬 Generate Video"):
250
+ prompt_input = gr.Textbox(label="Video Description", value="A cat jumping on a sofa", lines=2)
251
+ max_tok = gr.Slider(32, 128, value=64, step=32, label="Max Visual Tokens")
252
+ gen_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
253
+ video_out = gr.Video(label="Generated Video")
254
+ gen_log = gr.Textbox(label="Log", lines=15, interactive=False, show_copy_button=True)
255
+ gen_btn.click(fn=generate_video, inputs=[prompt_input, max_tok], outputs=[video_out, gen_log])
256
+
257
+ with gr.Tab("πŸ”§ Full Training Pipeline"):
258
+ gr.Markdown("""
259
+ ### Train from scratch with real data
260
+ 1. **Phase 1**: Train VQ-VAE on 50K COCO images (real photos!)
261
+ 2. **Phase 2**: Tokenize 10K OpenVid-1M clips (or 50K COCO images as fallback)
262
+ 3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on tokenized data
263
+ 4. **Phase 4**: Push trained model to EeshaAI/zeeb
264
+
265
+ ⚠️ This takes **many hours** on CPU. The Space may need restarts.
266
+ """)
267
+ pipe_btn = gr.Button("πŸš€ Start Full Pipeline", variant="primary", size="lg")
268
+ ref_btn = gr.Button("πŸ”„ Refresh Log")
269
+ pipe_log = gr.Textbox(label="Pipeline Log", value=lambda: get_log(), lines=30,
270
+ interactive=False, show_copy_button=True)
271
+ pipe_btn.click(fn=start_pipeline, outputs=pipe_log)
272
+ ref_btn.click(fn=get_log, outputs=pipe_log)
 
 
273
 
274
 
275
  if __name__ == "__main__":