File size: 14,802 Bytes
defa07b
 
6e8dde1
 
 
 
defa07b
 
3a0f51d
5ab6307
c8810e1
5ab6307
91d65b4
defa07b
395c0d2
6e8dde1
5ab6307
 
 
 
83a2068
3a0f51d
6e8dde1
 
7eab64f
 
 
 
c8810e1
5ab6307
395c0d2
5ab6307
7eab64f
5ab6307
83a2068
 
 
 
 
6e8dde1
 
395c0d2
 
 
6e8dde1
395c0d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8dde1
395c0d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8dde1
 
 
395c0d2
6e8dde1
395c0d2
 
 
 
 
 
 
 
 
 
 
6e8dde1
 
395c0d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8dde1
395c0d2
6e8dde1
395c0d2
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8dde1
 
395c0d2
 
 
83a2068
395c0d2
83a2068
 
 
395c0d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eab64f
83a2068
5ab6307
 
395c0d2
7eab64f
5ab6307
7eab64f
5ab6307
395c0d2
5ab6307
 
395c0d2
5ab6307
395c0d2
 
 
5ab6307
395c0d2
6e8dde1
5ab6307
6e8dde1
5ab6307
395c0d2
 
7eab64f
5ab6307
6e8dde1
7eab64f
395c0d2
7eab64f
 
 
 
5ab6307
6e8dde1
 
5ab6307
7eab64f
 
6e8dde1
395c0d2
 
6e8dde1
 
395c0d2
 
 
 
 
 
 
 
 
 
 
6e8dde1
7eab64f
 
 
 
 
 
 
 
 
395c0d2
5ab6307
 
7eab64f
 
395c0d2
5ab6307
6e8dde1
 
5ab6307
395c0d2
 
5ab6307
6e8dde1
5ab6307
 
 
6e8dde1
 
 
 
 
395c0d2
6e8dde1
395c0d2
 
6e8dde1
5ab6307
 
6e8dde1
5ab6307
6e8dde1
5ab6307
 
395c0d2
 
83a2068
5ab6307
83a2068
6e8dde1
 
 
 
 
 
 
395c0d2
6e8dde1
5ab6307
395c0d2
6e8dde1
5ab6307
 
6e8dde1
395c0d2
 
 
83a2068
6e8dde1
 
83a2068
 
 
6e8dde1
 
 
2c311a6
 
 
 
 
 
6e8dde1
 
 
 
 
 
 
 
 
395c0d2
6e8dde1
 
395c0d2
83a2068
 
 
395c0d2
83a2068
395c0d2
83a2068
6e8dde1
defa07b
 
395c0d2
6e8dde1
defa07b
6e8dde1
395c0d2
6e8dde1
 
 
 
 
395c0d2
6e8dde1
395c0d2
 
 
 
 
6e8dde1
 
395c0d2
6e8dde1
395c0d2
6e8dde1
 
395c0d2
 
 
6e8dde1
 
395c0d2
 
6e8dde1
395c0d2
 
6e8dde1
 
 
 
defa07b
91d65b4
defa07b
3a0f51d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
#!/usr/bin/env python3
"""
Gradio App for EeshaAI/Zeeb — Video Generation + Training Pipeline
===================================================================
Tab 1: Generate Video (uses trained model + VQ-VAE)
Tab 2: Run Full Pipeline (VQ-VAE training → dataset tokenization → LLM training → push)
"""

import os
import re
import threading
import numpy as np
import gradio as gr

LOG_FILE = os.path.join(os.environ.get("DATA_DIR", "/data"), "pipeline_log.txt")

# Model cache populated lazily by load_models(); the lock serialises loading so
# concurrent callers do not load the models twice.
_model = _tokenizer = _vq_vae = None
_loading_lock = threading.Lock()

# Vocabulary IDs of the video delimiters and the first/last visual token.
# They stay None until load_models() resolves them from the tokenizer.
VIDEO_START_ID = VIDEO_END_ID = None
V_TOKEN_START_ID = V_TOKEN_END_ID = None


def _load_vq_vae():
    """Build the VQ-VAE and load weights from the first checkpoint that loads.

    Candidate checkpoints are tried in order (persistent-storage best, then
    latest, then two files in the working directory). If none exists or all
    fail to load, a randomly initialised model is returned with a warning.
    """
    import torch
    import torch.nn as nn

    # Full VQ-VAE model (same architecture as training)
    class Encoder(nn.Module):
        def __init__(self, in_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 256, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(256, latent_dim, 4, stride=2, padding=1),
            )

        def forward(self, x):
            return self.net(x)

    class VectorQuantizer(nn.Module):
        def __init__(self, codebook_size=1024, codebook_dim=256, commitment_cost=0.25):
            super().__init__()
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.commitment_cost = commitment_cost
            self.codebook = nn.Embedding(codebook_size, codebook_dim)
            self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

        def forward(self, z):
            # z is channels-last here (B, H, W, C), per the unpacking below.
            B, H, W, C = z.shape
            z_flat = z.reshape(-1, C)
            dist = (z_flat.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(-1)
            indices = dist.argmin(dim=1)
            z_q = self.codebook(indices).reshape(B, H, W, C)
            commitment_loss = torch.nn.functional.mse_loss(z_flat, z_q.reshape(-1, C).detach())
            codebook_loss = torch.nn.functional.mse_loss(z_q.reshape(-1, C), z_flat.detach())
            loss = codebook_loss + self.commitment_cost * commitment_loss
            # Straight-through estimator: forward uses z_q, gradients flow to z.
            z_q_st = z + (z_q - z).detach()
            return z_q_st, loss, indices.reshape(B, H, W)

    class Decoder(nn.Module):
        def __init__(self, out_channels=3, latent_dim=256):
            super().__init__()
            # Four stride-2 transposed convs upsample 16x (8x8 grid -> 128x128).
            self.net = nn.Sequential(
                nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), nn.Sigmoid(),
            )

        def forward(self, x):
            return self.net(x)

    class VQVAE(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.quantizer = VectorQuantizer()
            self.proj_in = nn.Linear(256, 256)
            self.proj_out = nn.Linear(256, 256)
            self.decoder = Decoder()

        def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
            """Decode codebook indices (row-major grid) into a (1, 3, H, W) image."""
            if isinstance(token_ids, list):
                token_ids = torch.tensor(token_ids, dtype=torch.long)
            token_ids = token_ids[:grid_h * grid_w]
            # Short sequences are zero-padded up to a full grid.
            if len(token_ids) < grid_h * grid_w:
                token_ids = torch.cat([token_ids, torch.zeros(grid_h * grid_w - len(token_ids), dtype=torch.long)])
            z_q = self.quantizer.codebook(token_ids)
            z_q = self.proj_out(z_q)
            z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
            return self.decoder(z_q)

    # Try loading from multiple locations
    persist_dir = os.path.join(os.environ.get("DATA_DIR", "/data"), "zeeb_checkpoints")
    vq_paths = [
        os.path.join(persist_dir, "vq_vae_best.pt"),
        os.path.join(persist_dir, "vq_vae_latest.pt"),
        "vq_vae_real.pt",
        "vq_vae_final.pt",
    ]

    for vq_path in vq_paths:
        if not os.path.exists(vq_path):
            continue
        try:
            vq_vae = VQVAE()
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only point these paths at trusted checkpoints.
            state_dict = torch.load(vq_path, map_location="cpu", weights_only=False)
            # Handle different save formats (raw state_dict or training checkpoint)
            if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
                state_dict = state_dict["model_state_dict"]
            vq_vae.load_state_dict(state_dict, strict=True)
            vq_vae.eval()
            print(f"VQ-VAE loaded from {vq_path}")
            return vq_vae
        except Exception as e:
            print(f"Failed to load VQ-VAE from {vq_path}: {e}")

    vq_vae = VQVAE()
    vq_vae.eval()
    print("WARNING: Using untrained VQ-VAE (no checkpoint found)")
    return vq_vae


def load_models():
    """Load the trained LLM, its tokenizer and the VQ-VAE (lazy, cached).

    Returns:
        tuple: ``(model, tokenizer, vq_vae)``. ``model`` and ``tokenizer`` are
        ``None`` if the Hub download failed or the tokenizer lacks the visual
        special tokens; a later call retries in that case. ``vq_vae`` is always
        a module, untrained if no checkpoint could be loaded.

    Side effects:
        Populates the module globals ``_model``, ``_tokenizer``, ``_vq_vae``
        and the token IDs ``VIDEO_START_ID``, ``VIDEO_END_ID``,
        ``V_TOKEN_START_ID``, ``V_TOKEN_END_ID``. All work runs under
        ``_loading_lock``.
    """
    global _model, _tokenizer, _vq_vae
    global VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID

    with _loading_lock:
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer, _vq_vae

        import torch

        # The VQ-VAE does not depend on the Hub: build it once, so a failed
        # LLM download does not re-read the checkpoint on every retry.
        if _vq_vae is None:
            _vq_vae = _load_vq_vae()

        # LLM
        from transformers import AutoModelForCausalLM, AutoTokenizer

        REPO_ID = "eeshaAI/zeeb"
        print("Loading trained model from EeshaAI/zeeb...")

        try:
            tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
            )
            model.eval()

            special = ("<video_start>", "<video_end>", "<v_0>", "<v_1023>")
            ids = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in special}
            # Unknown tokens map to None or to the unk id; either would make
            # generate_video build a meaningless vocabulary mask.
            unk_id = getattr(tokenizer, "unk_token_id", None)
            missing = [tok for tok, tid in ids.items()
                       if tid is None or (unk_id is not None and tid == unk_id)]
            if missing:
                raise ValueError(f"tokenizer is missing special tokens: {missing}")
            # generate_video masks the slice <v_0>..<v_1023> and maps IDs to
            # codebook indices by subtraction, so the block must be contiguous.
            if ids["<v_1023>"] - ids["<v_0>"] != 1023:
                raise ValueError("visual tokens <v_0>..<v_1023> are not contiguous in the vocabulary")

            VIDEO_START_ID = ids["<video_start>"]
            VIDEO_END_ID = ids["<video_end>"]
            V_TOKEN_START_ID = ids["<v_0>"]
            V_TOKEN_END_ID = ids["<v_1023>"]
            _tokenizer = tokenizer
            _model = model
            print(f"Model loaded. Vocab: {len(_tokenizer)}")
        except Exception as e:
            print(f"Failed to load model from hub: {e}")
            print("Will load on-demand when generating.")
            _model = None
            _tokenizer = None

        return _model, _tokenizer, _vq_vae


def generate_video(prompt: str, max_tokens: int = 64, temperature: float = 0.9, top_k: int = 50):
    """Generate a short video from a text prompt.

    Pipeline:

    1. Sample up to ``max_tokens`` tokens from the LLM with constrained decoding.
       Only ``<v_0>``..``<v_1023>`` and ``<video_end>`` are allowed.
    2. Decode each group of 64 tokens (an 8x8 grid) into one frame with the VQ-VAE.
    3. Upscale the frames to 256x256.
    4. Write them as an MP4, or as a GIF if MP4 writing fails.

    Args:
        prompt: the user's video description.
        max_tokens: maximum number of visual tokens to sample.
        temperature: softmax temperature; clamped to at least 0.01.
        top_k: keep only the ``top_k`` most likely allowed tokens; ``<= 0``
            disables the filter.

    Returns:
        ``(video_path_or_None, log_text)`` for the Gradio outputs.
    """
    import torch
    import torch.nn.functional as F

    # Gradio sliders may deliver floats; range() and torch.topk need ints.
    max_tokens = int(max_tokens)
    top_k = int(top_k)

    log = [f"Generating video for: '{prompt}'\n\n"]

    try:
        log.append("Loading models...\n")
        model, tokenizer, vq_vae = load_models()
        if model is None or tokenizer is None:
            return None, "Model not loaded yet. Please wait or try again."
        log.append("Models loaded.\n\n")
    except Exception as e:
        log.append(f"Load error: {e}\n")
        return None, "".join(log)

    # load_models() fills these in; without them no vocabulary mask can be built.
    if None in (VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID):
        log.append("Visual token IDs are not available from the tokenizer.\n")
        return None, "".join(log)

    # Format prompt
    text = f"Create a video of: {prompt} <video_start>"
    log.append(f"Prompt: {text}\n\n")
    log.append("Generating visual tokens (constrained decoding)...\n")

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    current_ids = inputs["input_ids"].clone()
    # Truncating a long prompt can cut off the trailing <video_start> marker;
    # re-append it so the model is still asked to continue with visual tokens.
    if not bool((current_ids == VIDEO_START_ID).any()):
        marker = torch.tensor([[VIDEO_START_ID]], dtype=current_ids.dtype)
        current_ids = torch.cat([current_ids, marker], dim=-1)
        log.append("Prompt was truncated; re-appended <video_start>.\n")

    # Allow-list of visual tokens + <video_end>, built on the first step.
    visual_mask = None
    visual_token_ids = []

    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids=current_ids).logits[:, -1, :]

            # Size the mask from the logits, not len(tokenizer): the output
            # embedding can be padded past the tokenizer length, and a size
            # mismatch would make the boolean indexing below fail.
            if visual_mask is None:
                visual_mask = torch.zeros(logits.size(-1), dtype=torch.bool)
                visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
                visual_mask[VIDEO_END_ID] = True

            # Mask to only visual tokens
            masked = logits.clone()
            masked[0, ~visual_mask] = float('-inf')

            # Temperature scaling
            masked = masked / max(temperature, 0.01)

            # Top-k filtering
            if top_k > 0:
                top_k_values, _ = torch.topk(masked[0], min(top_k, masked.size(-1)))
                threshold = top_k_values[-1]
                masked[0, masked[0] < threshold] = float('-inf')

            probs = F.softmax(masked, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_id = next_token.item()

            if next_id == VIDEO_END_ID:
                break

            # Map vocabulary ID -> codebook index (visual tokens are contiguous).
            visual_token_ids.append(next_id - V_TOKEN_START_ID)
            current_ids = torch.cat([current_ids, next_token], dim=-1)

    log.append(f"Generated {len(visual_token_ids)} visual tokens\n")

    if not visual_token_ids:
        import random
        visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
        log.append("Fallback: random tokens\n")

    log.append(f"   Sample: {visual_token_ids[:20]}\n")
    log.append(f"   Unique: {len(set(visual_token_ids))}\n\n")

    # Decode frames through VQ-VAE
    log.append("Decoding tokens -> frames...\n")
    grid_h, grid_w = 8, 8
    tokens_per_frame = grid_h * grid_w
    # A trailing partial frame is dropped unless it is the only frame
    # (decode_tokens zero-pads short inputs).
    num_frames = max(1, len(visual_token_ids) // tokens_per_frame)

    frames = []
    with torch.no_grad():
        for fi in range(num_frames):
            ft = visual_token_ids[fi * tokens_per_frame:(fi + 1) * tokens_per_frame]
            try:
                frame_tensor = vq_vae.decode_tokens(ft, grid_h, grid_w)
                # Decoder ends in a Sigmoid, so values lie in [0, 1].
                frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
                frames.append(frame_np)
            except Exception as e:
                log.append(f"  Frame decode error: {str(e)[:60]}\n")
                frames.append(_tokens_to_color(ft, grid_h, grid_w))

    if not frames:
        return None, "".join(log)

    # Save video
    try:
        from PIL import Image
        # Upscale to 256x256
        upscaled = [np.array(Image.fromarray(f).resize((256, 256), Image.BILINEAR)) for f in frames]

        try:
            import imageio
            out = "/tmp/generated_video.mp4"
            imageio.mimsave(out, upscaled, fps=2)
        except Exception:
            # imageio missing or MP4 writing failed: fall back to a GIF via PIL.
            out = "/tmp/generated_video.gif"
            pils = [Image.fromarray(f) for f in upscaled]
            pils[0].save(out, save_all=True, append_images=pils[1:], duration=500, loop=0)

        log.append(f"Video saved ({len(upscaled)} frames, 256x256)\nDone!\n")
        return out, "".join(log)
    except Exception as e:
        log.append(f"Save error: {e}\n")
        return None, "".join(log)


def _tokens_to_color(token_ids, grid_h=8, grid_w=8):
    """Fallback: convert tokens to colored grid."""
    frame = np.zeros((128, 128, 3), dtype=np.uint8)
    ch, cw = 128 // grid_h, 128 // grid_w
    for i, t in enumerate(token_ids[:grid_h * grid_w]):
        r, c = divmod(i, grid_w)
        frame[r*ch:(r+1)*ch, c*cw:(c+1)*cw] = [(t*37)%256, (t*73)%256, (t*113)%256]
    return frame


def get_log(max_bytes: int = 5000) -> str:
    """Return the tail of the pipeline log for display in the UI.

    The file is read in binary mode and decoded afterwards. Seeking a
    text-mode file to an arbitrary byte offset is unsupported; landing inside
    a multi-byte UTF-8 character raises a decode error, which previously made
    the UI claim no log existed. Line endings are normalised the way
    text-mode universal newlines would.

    Args:
        max_bytes: number of trailing bytes of the log to return.

    Returns:
        The decoded tail, or ``"No pipeline log yet."`` if the file can't be read.
    """
    try:
        with open(LOG_FILE, "rb") as f:
            f.seek(0, os.SEEK_END)
            size = f.tell()
            f.seek(max(0, size - max_bytes))
            data = f.read()
    except OSError:
        return "No pipeline log yet."
    # A cut-off leading character becomes U+FFFD instead of aborting the read.
    text = data.decode("utf-8", errors="replace")
    # CRLF and bare CR (e.g. progress-bar redraws) become LF, as in text mode.
    return text.replace("\r\n", "\n").replace("\r", "\n")


# Handle of the most recently started pipeline thread (None until first start).
_pipeline_thread = None
_pipeline_lock = threading.Lock()


def start_pipeline():
    """Start the full training pipeline on a background daemon thread.

    At most one pipeline runs at a time. A click while the previous run's
    thread is still alive returns a notice instead of launching a duplicate
    run that would write to the same ``LOG_FILE``.

    Returns:
        str: status message shown in the pipeline log textbox.
    """
    global _pipeline_thread

    with _pipeline_lock:
        if _pipeline_thread is not None and _pipeline_thread.is_alive():
            return "Pipeline already running. Click Refresh to see progress."
        from train_full_pipeline import run_pipeline
        _pipeline_thread = threading.Thread(target=run_pipeline, args=(LOG_FILE,), daemon=True)
        _pipeline_thread.start()
    return "Pipeline started! Click Refresh to see progress."


# Preload generation models in the background so the first request is faster.
def preload():
    """Warm the model cache by calling load_models() once.

    load_models() catches Hub download errors itself and returns
    ``model=None``, so the result is checked to avoid reporting success when
    the LLM did not load. Any other exception is printed instead of
    propagating out of the background thread.
    """
    try:
        model, _, _ = load_models()
    except Exception as e:
        print(f"Preload error: {e}")
        return
    if model is None:
        print("Preload finished, but the LLM is not loaded; "
              "it will be retried on the first generation request.")
    else:
        print("Generation models preloaded!")


threading.Thread(target=preload, daemon=True).start()


# Gradio UI
with gr.Blocks(title="Zeeb — Video-LLM", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
        # Zeeb — Video-LLM
        **OLMo 2 1B** + **LoRA** + **VQ-VAE** → Text-to-Video generation.
        [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
    """)

    with gr.Tabs():
        # Tab 1: sample visual tokens from the LLM and decode them into a video.
        with gr.Tab("Generate Video"):
            prompt_input = gr.Textbox(label="Video Description", value="A cat jumping on a sofa", lines=2)
            with gr.Row():
                max_tok = gr.Slider(32, 256, value=64, step=32, label="Max Visual Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.9, step=0.1, label="Temperature")
                top_k = gr.Slider(1, 200, value=50, step=1, label="Top-K")
            gen_btn = gr.Button("Generate Video", variant="primary", size="lg")
            video_out = gr.Video(label="Generated Video")
            gen_log = gr.Textbox(label="Log", lines=15, interactive=False, show_copy_button=True)
            gen_btn.click(fn=generate_video, inputs=[prompt_input, max_tok, temperature, top_k], outputs=[video_out, gen_log])

        # Tab 2: launch the background training pipeline and tail its log file.
        with gr.Tab("Full Training Pipeline"):
            gr.Markdown("""
            ### Train from scratch with real data
            1. **Phase 1**: Train VQ-VAE on 10K real images (COCO/imagenette)
            2. **Phase 2**: Tokenize 10K image-text pairs through trained VQ-VAE
            3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on 5K tokenized samples
            4. **Phase 4**: Push trained model to EeshaAI/zeeb

            Checkpoints saved to persistent storage (survives Space restarts).
            Training takes several hours on CPU.
            """)
            pipe_btn = gr.Button("Start Full Pipeline", variant="primary", size="lg")
            ref_btn = gr.Button("Refresh Log")
            pipe_log = gr.Textbox(label="Pipeline Log", value=lambda: get_log(), lines=30,
                                  interactive=False, show_copy_button=True)
            pipe_btn.click(fn=start_pipeline, outputs=pipe_log)
            ref_btn.click(fn=get_log, outputs=pipe_log)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)