eeshaAI committed on
Commit
83a2068
·
verified ·
1 Parent(s): 5908d9b

Rewrite: generation-only app, preload models, no auto-training

Browse files
Files changed (1) hide show
  1. app.py +153 -209
app.py CHANGED
@@ -1,128 +1,97 @@
1
  #!/usr/bin/env python3
2
  """
3
- Gradio App for EeshaAI/Zeeb β€” Training + Video Generation
4
- ==========================================================
5
- Tab 1: Training (auto-starts on boot)
6
- Tab 2: Generate Video (loads trained model + VQ-VAE, generates video from prompt)
7
  """
8
 
9
  import os
10
- import time
11
  import re
12
  import threading
13
  import numpy as np
14
  import gradio as gr
15
 
16
- LOG_FILE = "/tmp/training_log.txt"
17
- GENERATE_LOG = "/tmp/generation_log.txt"
18
-
19
  # Global model cache
20
  _model = None
21
  _tokenizer = None
22
  _vq_vae = None
23
-
24
-
25
- def start_training_background():
26
- """Start training in a background thread on Space startup."""
27
- from train_on_hf_spaces import run_training_to_file
28
- run_training_to_file(LOG_FILE)
29
-
30
-
31
- def get_log():
32
- """Read the current training log."""
33
- try:
34
- with open(LOG_FILE, "r") as f:
35
- return f.read()
36
- except FileNotFoundError:
37
- return "⏳ Training has not started yet. Please wait..."
38
-
39
-
40
- def refresh_log():
41
- return get_log()
42
 
43
 
44
  def load_models():
45
  """Load the trained LLM and VQ-VAE decoder (lazy, cached)."""
46
  global _model, _tokenizer, _vq_vae
47
 
48
- if _model is not None and _tokenizer is not None:
49
- return _model, _tokenizer, _vq_vae
50
-
51
- import torch
52
-
53
- # ── Load VQ-VAE decoder ─────────────────────────────────────────────
54
- vq_vae_path = "vq_vae_final.pt"
55
- if os.path.exists(vq_vae_path):
56
- import torch.nn as nn
57
-
58
- class VQVAEDecoderOnly(nn.Module):
59
- """Minimal VQ-VAE decoder for token β†’ pixel decoding."""
60
- def __init__(self, codebook_size=1024, codebook_dim=256, latent_dim=256):
61
- super().__init__()
62
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
63
- self.proj = nn.Linear(codebook_dim, latent_dim)
64
- # Decoder: upscale from 8x8 spatial to 64x64
65
- self.decoder = nn.Sequential(
66
- nn.ConvTranspose2d(latent_dim, 128, 4, stride=2, padding=1), # 8β†’16
67
- nn.ReLU(),
68
- nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # 16β†’32
69
- nn.ReLU(),
70
- nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), # 32β†’64
71
- nn.ReLU(),
72
- nn.Conv2d(32, 3, 3, padding=1),
73
- nn.Sigmoid(),
74
- )
75
-
76
- def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
77
- """Decode a flat list of token IDs into a video frame."""
78
- # token_ids: list of ints, length should be grid_h * grid_w
79
- tokens = torch.tensor(token_ids[:grid_h * grid_w], dtype=torch.long)
80
- if len(tokens) < grid_h * grid_w:
81
- tokens = torch.cat([tokens, torch.zeros(grid_h * grid_w - len(tokens), dtype=torch.long)])
82
-
83
- # Lookup codebook
84
- z = self.codebook(tokens) # [H*W, D]
85
- z = self.proj(z) # [H*W, latent_dim]
86
- z = z.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2) # [1, C, H, W]
87
-
88
- # Decode
89
- frame = self.decoder(z) # [1, 3, 64, 64]
90
- return frame
91
-
92
- _vq_vae = VQVAEDecoderOnly()
93
- state = torch.load(vq_vae_path, map_location="cpu", weights_only=False)
94
- # Try to load relevant weights
95
- if isinstance(state, dict):
96
- if "codebook" in state or "state_dict" in state:
97
- # Full checkpoint
98
- sd = state.get("state_dict", state)
99
  filtered = {k: v for k, v in sd.items() if not k.startswith("encoder")}
100
  _vq_vae.load_state_dict(filtered, strict=False)
101
- elif "model_state_dict" in state:
102
- _vq_vae.load_state_dict(state["model_state_dict"], strict=False)
103
- else:
104
- _vq_vae.load_state_dict(state, strict=False)
105
- print("βœ… VQ-VAE decoder loaded")
106
-
107
- # ── Load trained LLM ────────────────────────────────────────────────
108
- from transformers import AutoModelForCausalLM, AutoTokenizer
109
-
110
- REPO_ID = "eeshaAI/zeeb"
111
-
112
- print("πŸ“¦ Loading trained model from EeshaAI/zeeb...")
113
- _tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
114
- if _tokenizer.pad_token is None:
115
- _tokenizer.pad_token = _tokenizer.eos_token
116
-
117
- _model = AutoModelForCausalLM.from_pretrained(
118
- REPO_ID,
119
- trust_remote_code=True,
120
- torch_dtype=torch.float32,
121
- )
122
- _model.eval()
123
- print(f"βœ… Model loaded. Vocab size: {len(_tokenizer)}")
124
 
125
- return _model, _tokenizer, _vq_vae
 
 
 
 
 
 
 
 
126
 
127
 
128
  def generate_video(prompt: str, max_tokens: int = 128):
@@ -133,17 +102,18 @@ def generate_video(prompt: str, max_tokens: int = 128):
133
  log_lines.append(f"🎬 Generating video for: '{prompt}'\n\n")
134
 
135
  try:
136
- # Load models
137
- log_lines.append("πŸ“¦ Loading trained model + VQ-VAE...\n")
138
  model, tokenizer, vq_vae = load_models()
139
  log_lines.append("βœ… Models loaded.\n\n")
140
  except Exception as e:
 
141
  log_lines.append(f"❌ Failed to load models: {e}\n")
 
142
  return None, "\n".join(log_lines)
143
 
144
  # ── Format prompt ──────────────────────────────────────────────────
145
  text = f"Create a video of: {prompt} <video_start>"
146
- log_lines.append(f"πŸ“ Prompt formatted:\n {text}\n\n")
147
 
148
  # ── Generate tokens ────────────────────────────────────────────────
149
  log_lines.append("πŸ”₯ Generating visual tokens...\n")
@@ -176,7 +146,6 @@ def generate_video(prompt: str, max_tokens: int = 128):
176
  in_video = False
177
  break
178
  if in_video:
179
- # Check if it's a <v_N> token
180
  match = re.match(r"<v_(\d+)>", decoded.strip())
181
  if match:
182
  visual_token_ids.append(int(match.group(1)))
@@ -184,23 +153,22 @@ def generate_video(prompt: str, max_tokens: int = 128):
184
  log_lines.append(f"🎨 Extracted {len(visual_token_ids)} visual tokens\n")
185
 
186
  if not visual_token_ids:
187
- log_lines.append("⚠️ No visual tokens generated! The model may need more training.\n")
188
- log_lines.append(f"\nFull output:\n{full_text}\n")
189
- # Try alternative: parse from full_text
190
  all_v_tokens = re.findall(r"<v_(\d+)>", full_text)
191
  if all_v_tokens:
192
  visual_token_ids = [int(t) for t in all_v_tokens]
193
- log_lines.append(f"\nπŸ”„ Alternative extraction found {len(visual_token_ids)} tokens\n")
194
  else:
 
 
195
  return None, "\n".join(log_lines)
196
 
197
- # Show sample of tokens
198
  sample_tokens = visual_token_ids[:20]
199
  log_lines.append(f" Sample tokens: {sample_tokens}\n")
200
  log_lines.append(f" Unique tokens: {len(set(visual_token_ids))}\n\n")
201
 
202
  # ── Decode to video frames ──────────────────────────────────────────
203
- log_lines.append("🎞️ Decoding tokens β†’ video frames via VQ-VAE...\n")
204
 
205
  grid_h, grid_w = 8, 8
206
  tokens_per_frame = grid_h * grid_w # 64
@@ -215,43 +183,22 @@ def generate_video(prompt: str, max_tokens: int = 128):
215
  start = frame_idx * tokens_per_frame
216
  end = start + tokens_per_frame
217
  frame_tokens = visual_token_ids[start:end]
218
-
219
  try:
220
  frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
221
- # Convert to numpy: [1, 3, 64, 64] β†’ [64, 64, 3] uint8
222
- frame_np = (frame_tensor[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
223
  frames.append(frame_np)
224
  except Exception as e:
225
  log_lines.append(f" ⚠️ Frame {frame_idx} decode error: {e}\n")
226
- # Fallback: create frame from token values as colors
227
- frame_np = np.zeros((64, 64, 3), dtype=np.uint8)
228
- for i, t in enumerate(frame_tokens[:tokens_per_frame]):
229
- row, col = divmod(i, grid_w)
230
- cell_h, cell_w = 64 // grid_h, 64 // grid_w
231
- if row < grid_h and col < grid_w:
232
- # Use token value as a color
233
- r = (t * 37) % 256
234
- g = (t * 73) % 256
235
- b = (t * 113) % 256
236
- frame_np[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w] = [r, g, b]
237
  frames.append(frame_np)
238
  else:
239
- # No VQ-VAE: create frames from token values as colored blocks
240
  log_lines.append(" ⚠️ No VQ-VAE, using tokenβ†’color mapping\n")
241
  for frame_idx in range(num_frames):
242
  start = frame_idx * tokens_per_frame
243
  end = start + tokens_per_frame
244
  frame_tokens = visual_token_ids[start:end]
245
- frame_np = np.zeros((64, 64, 3), dtype=np.uint8)
246
- for i, t in enumerate(frame_tokens[:tokens_per_frame]):
247
- row, col = divmod(i, grid_w)
248
- cell_h, cell_w = 64 // grid_h, 64 // grid_w
249
- if row < grid_h and col < grid_w:
250
- r = (t * 37) % 256
251
- g = (t * 73) % 256
252
- b = (t * 113) % 256
253
- frame_np[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w] = [r, g, b]
254
- frames.append(frame_np)
255
 
256
  if not frames:
257
  log_lines.append("❌ No frames generated!\n")
@@ -261,30 +208,25 @@ def generate_video(prompt: str, max_tokens: int = 128):
261
  log_lines.append(f"πŸ’Ύ Saving {len(frames)} frames as video...\n")
262
 
263
  try:
264
- import imageio
265
- output_path = "/tmp/generated_video.mp4"
266
- # Upscale frames from 64x64 to 256x256 for better visibility
267
  from PIL import Image
 
 
268
  upscaled = []
269
  for f in frames:
270
  img = Image.fromarray(f)
271
  img = img.resize((256, 256), Image.NEAREST)
272
  upscaled.append(np.array(img))
273
 
274
- # Save as mp4 (2 fps for slow playback since we have few frames)
275
- imageio.mimsave(output_path, upscaled, fps=2)
276
- log_lines.append(f"βœ… Video saved to {output_path}\n")
277
- log_lines.append(f" Resolution: 256Γ—256\n")
278
- log_lines.append(f" Frames: {len(upscaled)}\n")
279
- log_lines.append(f" FPS: 2\n\n")
280
- log_lines.append("πŸŽ‰ Video generation complete!\n")
281
- return output_path, "\n".join(log_lines)
282
- except ImportError:
283
- # Fallback: save as GIF
284
  try:
285
- from PIL import Image
 
 
 
 
 
286
  output_path = "/tmp/generated_video.gif"
287
- pil_frames = [Image.fromarray(f).resize((256, 256), Image.NEAREST) for f in frames]
288
  pil_frames[0].save(
289
  output_path,
290
  save_all=True,
@@ -292,23 +234,43 @@ def generate_video(prompt: str, max_tokens: int = 128):
292
  duration=500,
293
  loop=0,
294
  )
295
- log_lines.append(f"βœ… GIF saved to {output_path}\n")
296
- return output_path, "\n".join(log_lines)
297
- except Exception as e:
298
- log_lines.append(f"❌ Failed to save video: {e}\n")
299
- # Return first frame as image at least
300
- img_path = "/tmp/generated_frame.png"
301
- Image.fromarray(frames[0]).resize((256, 256), Image.NEAREST).save(img_path)
302
- log_lines.append(f"πŸ“Έ Saved single frame to {img_path}\n")
303
- return img_path, "\n".join(log_lines)
304
  except Exception as e:
 
305
  log_lines.append(f"❌ Video save error: {e}\n")
 
306
  return None, "\n".join(log_lines)
307
 
308
 
309
- # ── Auto-start training on boot ────────────────────────────────────────
310
- training_thread = threading.Thread(target=start_training_background, daemon=True)
311
- training_thread.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
 
314
  # ── Gradio UI ──────────────────────────────────────────────────────────
@@ -320,55 +282,37 @@ with gr.Blocks(
320
  gr.Markdown(
321
  """
322
  # 🎬 Zeeb β€” Video-LLM
323
- **OLMo 2 1B Instruct** fine-tuned with **LoRA** to generate video tokens.
324
- Model repo: [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
 
 
325
  """
326
  )
327
 
328
- with gr.Tabs():
329
- # ── Tab 1: Generate Video ───────────────────────────────────────
330
- with gr.Tab("🎬 Generate Video"):
331
- prompt_input = gr.Textbox(
332
- label="Video Description",
333
- placeholder="A cat jumping on a sofa",
334
- lines=2,
335
- )
336
- max_tokens_slider = gr.Slider(
337
- minimum=32, maximum=256, value=128, step=32,
338
- label="Max Visual Tokens",
339
- )
340
- generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
341
- video_output = gr.Video(label="Generated Video")
342
- gen_log = gr.Textbox(
343
- label="Generation Log",
344
- lines=20,
345
- interactive=False,
346
- show_copy_button=True,
347
- )
348
- generate_btn.click(
349
- fn=generate_video,
350
- inputs=[prompt_input, max_tokens_slider],
351
- outputs=[video_output, gen_log],
352
- )
353
 
354
- # ── Tab 2: Training ─────────────────────────────────────────────
355
- with gr.Tab("πŸ”§ Training"):
356
- gr.Markdown(
357
- """
358
- Training **starts automatically** when this Space boots.
359
- Click **Refresh Log** to see progress.
360
- """
361
- )
362
- refresh_btn = gr.Button("πŸ”„ Refresh Log")
363
- logbox = gr.Textbox(
364
- label="Training Log",
365
- value=lambda: get_log(),
366
- lines=25,
367
- max_lines=200,
368
- interactive=False,
369
- show_copy_button=True,
370
- )
371
- refresh_btn.click(fn=refresh_log, outputs=logbox)
372
 
373
 
374
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ Gradio App for EeshaAI/Zeeb — Video Generation
4
+ ================================================
5
+ Uses the trained OLMo 2 1B + LoRA model to generate video tokens,
6
+ then decodes them via VQ-VAE into a video file.
7
  """
8
 
9
  import os
 
10
  import re
11
  import threading
12
  import numpy as np
13
  import gradio as gr
14
 
 
 
 
15
  # Global model cache
16
  _model = None
17
  _tokenizer = None
18
  _vq_vae = None
19
+ _loading_lock = threading.Lock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
def load_models():
    """Load the trained LLM and VQ-VAE decoder (lazy, cached).

    The first successful call fills the module-level cache; later calls
    return the cached objects. The whole load runs under ``_loading_lock``
    so a background preload and a user request cannot load the models twice.

    Returns:
        tuple: ``(model, tokenizer, vq_vae)``. ``vq_vae`` is ``None`` when
        ``vq_vae_final.pt`` is not present in the working directory.

    Raises:
        Exception: whatever ``torch.load`` or the ``transformers`` loaders
            raise. Nothing is cached in that case, so the next call retries.
    """
    global _model, _tokenizer, _vq_vae

    with _loading_lock:
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer, _vq_vae

        import torch

        # ── Load VQ-VAE decoder ─────────────────────────────────────────
        vq_vae = None
        vq_vae_path = "vq_vae_final.pt"
        if os.path.exists(vq_vae_path):
            import torch.nn as nn

            class VQVAEDecoderOnly(nn.Module):
                """Minimal VQ-VAE decoder for token → pixel decoding."""

                def __init__(self, codebook_size=1024, codebook_dim=256, latent_dim=256):
                    super().__init__()
                    self.codebook = nn.Embedding(codebook_size, codebook_dim)
                    self.proj = nn.Linear(codebook_dim, latent_dim)
                    # Each ConvTranspose2d(kernel=4, stride=2, padding=1)
                    # doubles H and W, so the three of them upscale by 8x.
                    self.decoder = nn.Sequential(
                        nn.ConvTranspose2d(latent_dim, 128, 4, stride=2, padding=1),
                        nn.ReLU(),
                        nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
                        nn.ReLU(),
                        nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
                        nn.ReLU(),
                        nn.Conv2d(32, 3, 3, padding=1),
                        nn.Sigmoid(),
                    )

                def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
                    """Decode one frame's token IDs into a ``[1, 3, H, W]`` tensor.

                    Only the first ``grid_h * grid_w`` IDs are used; a short
                    list is right-padded with token 0.
                    """
                    n_cells = grid_h * grid_w
                    tokens = torch.tensor(token_ids[:n_cells], dtype=torch.long)
                    if len(tokens) < n_cells:
                        padding = torch.zeros(n_cells - len(tokens), dtype=torch.long)
                        tokens = torch.cat([tokens, padding])
                    # Inference only: skip autograd so the result can be
                    # converted with .numpy() without a prior .detach().
                    with torch.no_grad():
                        z = self.proj(self.codebook(tokens))
                        z = z.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
                        return self.decoder(z)

            vq_vae = VQVAEDecoderOnly()
            # NOTE(security): weights_only=False unpickles arbitrary objects.
            # Only acceptable while vq_vae_final.pt is a trusted file shipped
            # with the app; never point this at user-supplied checkpoints.
            state = torch.load(vq_vae_path, map_location="cpu", weights_only=False)
            if isinstance(state, dict):
                if "state_dict" in state:
                    sd = state["state_dict"]
                elif "model_state_dict" in state:
                    sd = state["model_state_dict"]
                else:
                    sd = state
            elif hasattr(state, "state_dict"):
                # Checkpoint was saved as a whole pickled module.
                sd = state.state_dict()
            else:
                sd = None

            if sd is None:
                print(f"⚠️ Unrecognized VQ-VAE checkpoint type {type(state).__name__}; "
                      "decoder weights were NOT loaded")
            else:
                # The encoder half is not needed for token → pixel decoding.
                filtered = {k: v for k, v in sd.items() if not k.startswith("encoder")}
                vq_vae.load_state_dict(filtered, strict=False)
                print("✅ VQ-VAE decoder loaded")
            vq_vae.eval()

        # ── Load trained LLM ────────────────────────────────────────────
        from transformers import AutoModelForCausalLM, AutoTokenizer

        REPO_ID = "eeshaAI/zeeb"
        print("📦 Loading trained model from EeshaAI/zeeb...")
        tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            REPO_ID,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        model.eval()
        print(f"✅ Model loaded. Vocab size: {len(tokenizer)}")

        # Publish to the cache only after everything loaded, so a failure
        # part-way through never leaves a half-initialized cache behind.
        _model, _tokenizer, _vq_vae = model, tokenizer, vq_vae
        return _model, _tokenizer, _vq_vae
95
 
96
 
97
  def generate_video(prompt: str, max_tokens: int = 128):
 
102
  log_lines.append(f"🎬 Generating video for: '{prompt}'\n\n")
103
 
104
  try:
105
+ log_lines.append("πŸ“¦ Loading trained model + VQ-VAE (first run takes ~3 min)...\n")
 
106
  model, tokenizer, vq_vae = load_models()
107
  log_lines.append("βœ… Models loaded.\n\n")
108
  except Exception as e:
109
+ import traceback
110
  log_lines.append(f"❌ Failed to load models: {e}\n")
111
+ log_lines.append(traceback.format_exc())
112
  return None, "\n".join(log_lines)
113
 
114
  # ── Format prompt ──────────────────────────────────────────────────
115
  text = f"Create a video of: {prompt} <video_start>"
116
+ log_lines.append(f"πŸ“ Prompt: {text}\n\n")
117
 
118
  # ── Generate tokens ────────────────────────────────────────────────
119
  log_lines.append("πŸ”₯ Generating visual tokens...\n")
 
146
  in_video = False
147
  break
148
  if in_video:
 
149
  match = re.match(r"<v_(\d+)>", decoded.strip())
150
  if match:
151
  visual_token_ids.append(int(match.group(1)))
 
153
  log_lines.append(f"🎨 Extracted {len(visual_token_ids)} visual tokens\n")
154
 
155
  if not visual_token_ids:
156
+ log_lines.append("⚠️ No visual tokens in structured format. Trying regex on full output...\n")
 
 
157
  all_v_tokens = re.findall(r"<v_(\d+)>", full_text)
158
  if all_v_tokens:
159
  visual_token_ids = [int(t) for t in all_v_tokens]
160
+ log_lines.append(f"πŸ”„ Regex found {len(visual_token_ids)} tokens\n")
161
  else:
162
+ log_lines.append("⚠️ No visual tokens at all. Showing raw output:\n")
163
+ log_lines.append(f"\n{full_text[:1000]}\n")
164
  return None, "\n".join(log_lines)
165
 
 
166
  sample_tokens = visual_token_ids[:20]
167
  log_lines.append(f" Sample tokens: {sample_tokens}\n")
168
  log_lines.append(f" Unique tokens: {len(set(visual_token_ids))}\n\n")
169
 
170
  # ── Decode to video frames ──────────────────────────────────────────
171
+ log_lines.append("🎞️ Decoding tokens β†’ video frames...\n")
172
 
173
  grid_h, grid_w = 8, 8
174
  tokens_per_frame = grid_h * grid_w # 64
 
183
  start = frame_idx * tokens_per_frame
184
  end = start + tokens_per_frame
185
  frame_tokens = visual_token_ids[start:end]
 
186
  try:
187
  frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
188
+ frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
 
189
  frames.append(frame_np)
190
  except Exception as e:
191
  log_lines.append(f" ⚠️ Frame {frame_idx} decode error: {e}\n")
192
+ # Fallback: color blocks
193
+ frame_np = _tokens_to_color_blocks(frame_tokens, grid_h, grid_w)
 
 
 
 
 
 
 
 
 
194
  frames.append(frame_np)
195
  else:
 
196
  log_lines.append(" ⚠️ No VQ-VAE, using tokenβ†’color mapping\n")
197
  for frame_idx in range(num_frames):
198
  start = frame_idx * tokens_per_frame
199
  end = start + tokens_per_frame
200
  frame_tokens = visual_token_ids[start:end]
201
+ frames.append(_tokens_to_color_blocks(frame_tokens, grid_h, grid_w))
 
 
 
 
 
 
 
 
 
202
 
203
  if not frames:
204
  log_lines.append("❌ No frames generated!\n")
 
208
  log_lines.append(f"πŸ’Ύ Saving {len(frames)} frames as video...\n")
209
 
210
  try:
 
 
 
211
  from PIL import Image
212
+
213
+ # Upscale 64x64 → 256x256
214
  upscaled = []
215
  for f in frames:
216
  img = Image.fromarray(f)
217
  img = img.resize((256, 256), Image.NEAREST)
218
  upscaled.append(np.array(img))
219
 
220
+ # Try imageio for MP4
 
 
 
 
 
 
 
 
 
221
  try:
222
+ import imageio
223
+ output_path = "/tmp/generated_video.mp4"
224
+ imageio.mimsave(output_path, upscaled, fps=2)
225
+ log_lines.append(f"βœ… Video saved as MP4: {output_path}\n")
226
+ except Exception:
227
+ # Fallback to GIF
228
  output_path = "/tmp/generated_video.gif"
229
+ pil_frames = [Image.fromarray(f) for f in upscaled]
230
  pil_frames[0].save(
231
  output_path,
232
  save_all=True,
 
234
  duration=500,
235
  loop=0,
236
  )
237
+ log_lines.append(f"βœ… Video saved as GIF: {output_path}\n")
238
+
239
+ log_lines.append(f" Resolution: 256Γ—256\n")
240
+ log_lines.append(f" Frames: {len(upscaled)}\n")
241
+ log_lines.append(f" FPS: 2\n\n")
242
+ log_lines.append("πŸŽ‰ Video generation complete!\n")
243
+ return output_path, "\n".join(log_lines)
 
 
244
  except Exception as e:
245
+ import traceback
246
  log_lines.append(f"❌ Video save error: {e}\n")
247
+ log_lines.append(traceback.format_exc())
248
  return None, "\n".join(log_lines)
249
 
250
 
251
+ def _tokens_to_color_blocks(token_ids, grid_h=8, grid_w=8):
252
+ """Convert token IDs to a color-block image as fallback."""
253
+ frame = np.zeros((64, 64, 3), dtype=np.uint8)
254
+ cell_h, cell_w = 64 // grid_h, 64 // grid_w
255
+ for i, t in enumerate(token_ids[:grid_h * grid_w]):
256
+ row, col = divmod(i, grid_w)
257
+ r = (t * 37) % 256
258
+ g = (t * 73) % 256
259
+ b = (t * 113) % 256
260
+ frame[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w] = [r, g, b]
261
+ return frame
262
+
263
+
264
+ # ── Preload models on boot in background ───────────────────────────────
265
def preload():
    """Warm the model cache at startup so the first request is fast.

    Failures are reported on stdout instead of being raised: this runs on a
    background thread, and a failed preload is not fatal because
    ``generate_video`` calls ``load_models()`` again on demand.
    """
    try:
        load_models()
    except Exception as e:  # thread boundary: report, never propagate
        import traceback

        print(f"⚠️ Preload error: {e}")
        # Full traceback, matching the error reporting in generate_video.
        traceback.print_exc()
    else:
        print("🚀 Models preloaded and ready!")


# Daemon thread: must not keep the process alive if the app exits mid-load.
preload_thread = threading.Thread(target=preload, daemon=True)
preload_thread.start()
274
 
275
 
276
  # ── Gradio UI ──────────────────────────────────────────────────────────
 
282
  gr.Markdown(
283
  """
284
  # 🎬 Zeeb β€” Video-LLM
285
+ **OLMo 2 1B Instruct** fine-tuned with **LoRA (r=4)** to generate video tokens.
286
+ Model: [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
287
+
288
+ Type a description and click Generate!
289
  """
290
  )
291
 
292
+ prompt_input = gr.Textbox(
293
+ label="Video Description",
294
+ placeholder="A cat jumping on a sofa",
295
+ lines=2,
296
+ value="A cat jumping on a sofa",
297
+ )
298
+ max_tokens_slider = gr.Slider(
299
+ minimum=32, maximum=256, value=128, step=32,
300
+ label="Max Visual Tokens to Generate",
301
+ )
302
+ generate_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
303
+ video_output = gr.Video(label="Generated Video")
304
+ gen_log = gr.Textbox(
305
+ label="Generation Log",
306
+ lines=20,
307
+ interactive=False,
308
+ show_copy_button=True,
309
+ )
 
 
 
 
 
 
 
310
 
311
+ generate_btn.click(
312
+ fn=generate_video,
313
+ inputs=[prompt_input, max_tokens_slider],
314
+ outputs=[video_output, gen_log],
315
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
 
318
  if __name__ == "__main__":