| |
| """ |
| Gradio App for EeshaAI/Zeeb — Video Generation + Training Pipeline |
| =================================================================== |
| Tab 1: Generate Video (uses trained model + VQ-VAE) |
| Tab 2: Run Full Pipeline (VQ-VAE training → dataset tokenization → LLM training → push) |
| """ |
|
|
| import os |
| import re |
| import threading |
| import numpy as np |
| import gradio as gr |
|
|
# Text log of the training pipeline. get_log() reads its tail and
# start_pipeline() passes the path to the pipeline thread. It lives under
# DATA_DIR (default "/data").
LOG_FILE = os.path.join(os.environ.get("DATA_DIR", "/data"), "pipeline_log.txt")


# Lazily populated by load_models(); they stay None until a load succeeds.
_model = None
_tokenizer = None
_vq_vae = None
# Held by load_models() while it checks / fills the cache above.
_loading_lock = threading.Lock()


# Tokenizer ids of the special video tokens. load_models() fills them in from
# the tokenizer; generate_video() reads them to build its sampling mask.
VIDEO_START_ID = None
VIDEO_END_ID = None
V_TOKEN_START_ID = None
V_TOKEN_END_ID = None
|
|
|
|
def _build_vqvae_class():
    """Define the VQ-VAE architecture and return its top-level ``VQVAE`` class.

    The classes are created inside a function so ``torch`` is imported only
    when models are actually loaded.  Attribute names (``encoder``,
    ``quantizer``, ``codebook``, ``proj_in``, ``proj_out``, ``decoder``,
    ``net``) determine the ``state_dict`` keys; checkpoints are loaded with
    ``strict=True``, so renaming any of them breaks loading.
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Encoder(nn.Module):
        """Four stride-2 convolutions mapping an image to a latent grid."""

        def __init__(self, in_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 256, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(256, latent_dim, 4, stride=2, padding=1),
            )

        def forward(self, x):
            return self.net(x)

    class VectorQuantizer(nn.Module):
        """Nearest-neighbour codebook lookup with a straight-through gradient."""

        def __init__(self, codebook_size=1024, codebook_dim=256, commitment_cost=0.25):
            super().__init__()
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.commitment_cost = commitment_cost
            self.codebook = nn.Embedding(codebook_size, codebook_dim)
            self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

        def forward(self, z):
            """Quantize ``z`` (unpacked as B, H, W, C).

            Returns ``(quantized, loss, indices)`` where ``indices`` has shape
            (B, H, W).
            """
            B, H, W, C = z.shape
            z_flat = z.reshape(-1, C)
            # Squared distance from every latent vector to every codebook entry.
            dist = (z_flat.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(-1)
            indices = dist.argmin(dim=1)
            z_q = self.codebook(indices).reshape(B, H, W, C)
            commitment_loss = F.mse_loss(z_flat, z_q.reshape(-1, C).detach())
            codebook_loss = F.mse_loss(z_q.reshape(-1, C), z_flat.detach())
            loss = codebook_loss + self.commitment_cost * commitment_loss
            # Forward pass uses z_q; the gradient flows to z unchanged.
            z_q_st = z + (z_q - z).detach()
            return z_q_st, loss, indices.reshape(B, H, W)

    class Decoder(nn.Module):
        """Four stride-2 transposed convolutions ending in a sigmoid."""

        def __init__(self, out_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), nn.Sigmoid(),
            )

        def forward(self, x):
            return self.net(x)

    class VQVAE(nn.Module):
        """Encoder + quantizer + decoder; only ``decode_tokens`` is used here."""

        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.quantizer = VectorQuantizer()
            self.proj_in = nn.Linear(256, 256)
            self.proj_out = nn.Linear(256, 256)
            self.decoder = Decoder()

        def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
            """Decode one frame from codebook indices.

            ``token_ids`` is truncated to ``grid_h * grid_w`` entries and, if
            shorter, right-padded with index 0.  Returns the decoder output
            for a single (1, C, grid_h, grid_w) latent grid.
            """
            if not torch.is_tensor(token_ids):
                # Accept lists, tuples and arrays alike.
                token_ids = torch.as_tensor(token_ids, dtype=torch.long)
            cells = grid_h * grid_w
            token_ids = token_ids[:cells]
            if len(token_ids) < cells:
                padding = torch.zeros(cells - len(token_ids), dtype=torch.long)
                token_ids = torch.cat([token_ids, padding])
            z_q = self.quantizer.codebook(token_ids)
            z_q = self.proj_out(z_q)
            z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
            return self.decoder(z_q)

    return VQVAE


def _load_vq_vae(vqvae_cls):
    """Return a VQ-VAE in eval mode, loaded from the first usable checkpoint.

    Candidate paths are tried in order; a path that is missing or fails to
    load is skipped.  When none loads, an untrained instance is returned and
    a warning is printed.
    """
    import torch

    persist_dir = os.path.join(os.environ.get("DATA_DIR", "/data"), "zeeb_checkpoints")
    candidates = [
        os.path.join(persist_dir, "vq_vae_best.pt"),
        os.path.join(persist_dir, "vq_vae_latest.pt"),
        "vq_vae_real.pt",
        "vq_vae_final.pt",
    ]

    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            vq_vae = vqvae_cls()
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects and
            # can execute code embedded in the file.  Only load checkpoints
            # written by this project's own pipeline; consider switching to
            # weights_only=True once all checkpoints are plain state dicts.
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)
            # Training checkpoints wrap the weights; bare state dicts do not.
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                checkpoint = checkpoint["model_state_dict"]
            vq_vae.load_state_dict(checkpoint, strict=True)
            vq_vae.eval()
            print(f"VQ-VAE loaded from {path}")
            return vq_vae
        except Exception as e:
            print(f"Failed to load VQ-VAE from {path}: {e}")

    vq_vae = vqvae_cls()
    vq_vae.eval()
    print("WARNING: Using untrained VQ-VAE (no checkpoint found)")
    return vq_vae


def _resolve_special_token_ids(tokenizer):
    """Return ids of (<video_start>, <video_end>, <v_0>, <v_1023>).

    Raises ``ValueError`` when any token is unknown to the tokenizer (the
    lookup returned ``None`` or the unk id), because the sampling mask built
    from such ids would be meaningless.
    """
    names = ("<video_start>", "<video_end>", "<v_0>", "<v_1023>")
    ids = [tokenizer.convert_tokens_to_ids(name) for name in names]
    unk_id = getattr(tokenizer, "unk_token_id", None)
    missing = [
        name
        for name, token_id in zip(names, ids)
        if token_id is None or (unk_id is not None and token_id == unk_id)
    ]
    if missing:
        raise ValueError(f"Tokenizer is missing special tokens: {', '.join(missing)}")
    return tuple(ids)


def load_models():
    """Load the trained LLM and VQ-VAE (lazy, cached).

    Returns:
        ``(model, tokenizer, vq_vae)``.  ``model`` and ``tokenizer`` are
        ``None`` when the LLM could not be loaded; in that case nothing is
        cached and the next call retries.  ``vq_vae`` is always an instance
        (untrained if no checkpoint could be loaded).

    Side effects:
        Sets the module-level cache and, on success only, the special-token
        id globals read by ``generate_video``.
    """
    global _model, _tokenizer, _vq_vae
    global VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID

    with _loading_lock:
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer, _vq_vae

        import torch

        _vq_vae = _load_vq_vae(_build_vqvae_class())

        from transformers import AutoModelForCausalLM, AutoTokenizer

        REPO_ID = "eeshaAI/zeeb"
        print("Loading trained model from EeshaAI/zeeb...")

        try:
            tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
            )
            model.eval()

            special_ids = _resolve_special_token_ids(tokenizer)
        except Exception as e:
            print(f"Failed to load model from hub: {e}")
            print("Will load on-demand when generating.")
            _model = None
            _tokenizer = None
        else:
            # Publish everything together so the cache and the token ids can
            # never describe different tokenizers.
            _model, _tokenizer = model, tokenizer
            VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID = special_ids
            print(f"Model loaded. Vocab: {len(_tokenizer)}")

        return _model, _tokenizer, _vq_vae
|
|
|
|
def _sample_visual_tokens(model, tokenizer, text, max_tokens, temperature, top_k):
    """Sample up to ``max_tokens`` visual-token indices after ``text``.

    Every step restricts the next-token distribution to the visual tokens
    plus ``<video_end>``, applies temperature and top-k, and samples.
    Sampling stops early when ``<video_end>`` is drawn.  Returned values are
    offsets from ``V_TOKEN_START_ID`` (codebook indices), not tokenizer ids.
    """
    import torch
    import torch.nn.functional as F

    # NOTE(review): truncation to 256 tokens cuts from the end by default, so
    # a very long prompt could lose the trailing <video_start> — confirm.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    current_ids = inputs["input_ids"].clone()

    allowed = None
    visual_token_ids = []

    with torch.no_grad():
        for _step in range(max_tokens):
            logits = model(input_ids=current_ids).logits[:, -1, :]

            if allowed is None:
                # Size the mask from the logits: the model's output width can
                # differ from len(tokenizer) (e.g. padded embedding matrices).
                allowed = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
                allowed[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
                allowed[VIDEO_END_ID] = True

            masked = logits.clone()
            masked[0, ~allowed] = float("-inf")
            masked = masked / max(temperature, 0.01)

            if top_k > 0:
                top_values, _top_indices = torch.topk(masked[0], min(top_k, masked.size(-1)))
                masked[0, masked[0] < top_values[-1]] = float("-inf")

            probs = F.softmax(masked, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_id = next_token.item()

            if next_id == VIDEO_END_ID:
                break

            visual_token_ids.append(next_id - V_TOKEN_START_ID)
            current_ids = torch.cat([current_ids, next_token.to(current_ids.device)], dim=-1)

    return visual_token_ids


def _decode_frames(vq_vae, visual_token_ids, log, grid_h=8, grid_w=8):
    """Decode token indices into a list of uint8 HxWx3 frames.

    Tokens are consumed ``grid_h * grid_w`` per frame; at least one frame is
    always produced and leftover tokens beyond the last full frame are
    ignored.  A frame that fails to decode is replaced by the colour-grid
    fallback and the error is appended to ``log``.
    """
    import torch

    tokens_per_frame = grid_h * grid_w
    num_frames = max(1, len(visual_token_ids) // tokens_per_frame)

    frames = []
    for fi in range(num_frames):
        frame_tokens = visual_token_ids[fi * tokens_per_frame:(fi + 1) * tokens_per_frame]
        try:
            # Inference only: avoid building an autograd graph per frame.
            with torch.no_grad():
                frame_tensor = vq_vae.decode_tokens(frame_tokens, grid_h, grid_w)
            pixels = frame_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
            frames.append((pixels * 255).astype(np.uint8))
        except Exception as e:
            log.append(f"  Frame decode error: {str(e)[:60]}\n")
            frames.append(_tokens_to_color(frame_tokens, grid_h, grid_w))
    return frames


def _save_video(frames, log):
    """Upscale frames to 256x256 and write them out; return ``(path, count)``.

    MP4 via imageio is tried first; on any failure (including imageio not
    being installed) an animated GIF is written with PIL instead and the
    reason is appended to ``log``.
    """
    from PIL import Image

    upscaled = [np.array(Image.fromarray(f).resize((256, 256), Image.BILINEAR)) for f in frames]

    try:
        import imageio
        out = "/tmp/generated_video.mp4"
        imageio.mimsave(out, upscaled, fps=2)
    except Exception as e:
        log.append(f"MP4 export failed ({e}); falling back to GIF\n")
        out = "/tmp/generated_video.gif"
        pils = [Image.fromarray(f) for f in upscaled]
        pils[0].save(out, save_all=True, append_images=pils[1:], duration=500, loop=0)

    return out, len(upscaled)


def generate_video(prompt: str, max_tokens: int = 64, temperature: float = 0.9, top_k: int = 50):
    """Generate video from a text prompt using constrained decoding + VQ-VAE.

    Args:
        prompt: Free-text description of the video.
        max_tokens: Maximum number of visual tokens to sample.
        temperature: Sampling temperature (clamped to at least 0.01).
        top_k: Keep only the ``top_k`` most likely tokens; ``0`` disables it.

    Returns:
        ``(video_path, log_text)``.  ``video_path`` is ``None`` when the
        models are unavailable or saving fails; ``log_text`` always describes
        what happened.
    """
    # UI sliders may deliver floats; range() and torch.topk need ints.
    max_tokens = int(max_tokens)
    top_k = int(top_k)

    log = [f"Generating video for: '{prompt}'\n\n"]

    try:
        log.append("Loading models...\n")
        model, tokenizer, vq_vae = load_models()
    except Exception as e:
        log.append(f"Load error: {e}\n")
        return None, "".join(log)

    if model is None or tokenizer is None:
        log.append("Model not loaded yet. Please wait or try again.")
        return None, "".join(log)
    if None in (VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID):
        log.append("Video token ids are not initialised; cannot generate.\n")
        return None, "".join(log)
    log.append("Models loaded.\n\n")

    text = f"Create a video of: {prompt} <video_start>"
    log.append(f"Prompt: {text}\n\n")
    log.append("Generating visual tokens (constrained decoding)...\n")

    visual_token_ids = _sample_visual_tokens(
        model, tokenizer, text, max_tokens, temperature, top_k
    )
    log.append(f"Generated {len(visual_token_ids)} visual tokens\n")

    if not visual_token_ids:
        # The model emitted <video_end> immediately; still show something.
        import random
        visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
        log.append("Fallback: random tokens\n")

    log.append(f"  Sample: {visual_token_ids[:20]}\n")
    log.append(f"  Unique: {len(set(visual_token_ids))}\n\n")

    log.append("Decoding tokens -> frames...\n")
    frames = _decode_frames(vq_vae, visual_token_ids, log)
    if not frames:
        return None, "".join(log)

    try:
        out, frame_count = _save_video(frames, log)
    except Exception as e:
        log.append(f"Save error: {e}\n")
        return None, "".join(log)

    log.append(f"Video saved ({frame_count} frames, 256x256)\nDone!\n")
    return out, "".join(log)
|
|
|
|
def _tokens_to_color(token_ids, grid_h=8, grid_w=8):
    """Fallback renderer: paint each token as a solid cell of a 128x128 image.

    Only the first ``grid_h * grid_w`` tokens are drawn, in row-major order;
    cells without a token stay black.  The colour of a cell is derived from
    the token value with three fixed multipliers modulo 256.
    """
    side = 128
    cell_h = side // grid_h
    cell_w = side // grid_w
    canvas = np.zeros((side, side, 3), dtype=np.uint8)

    for position, token in enumerate(token_ids[:grid_h * grid_w]):
        row, col = divmod(position, grid_w)
        top = row * cell_h
        left = col * cell_w
        color = [(token * factor) % 256 for factor in (37, 73, 113)]
        canvas[top:top + cell_h, left:left + cell_w] = color

    return canvas
|
|
|
|
def get_log():
    """Return the tail of the pipeline log (at most its last 5000 bytes).

    The file is read in binary mode because seeking to an arbitrary offset is
    only well-defined for binary files; the tail may start in the middle of a
    multi-byte UTF-8 character, so decoding replaces undecodable bytes
    instead of failing.

    Returns:
        The decoded tail, or ``"No pipeline log yet."`` when the log file
        cannot be opened or read.
    """
    try:
        with open(LOG_FILE, "rb") as f:
            f.seek(0, os.SEEK_END)
            size = f.tell()
            f.seek(max(0, size - 5000))
            tail = f.read()
    except OSError:
        return "No pipeline log yet."
    return tail.decode("utf-8", errors="replace")
|
|
|
|
# Thread running the training pipeline, if one was started from this process.
_pipeline_thread = None
# Serializes the "is it running? / start it" check in start_pipeline().
_pipeline_lock = threading.Lock()


def start_pipeline():
    """Start the full training pipeline in background.

    The pipeline runs ``train_full_pipeline.run_pipeline(LOG_FILE)`` in a
    daemon thread.  If a previously started pipeline thread is still alive,
    no second one is launched, so repeated button clicks cannot run several
    trainings against the same log and checkpoints at once.

    Returns:
        A status message for the UI.
    """
    global _pipeline_thread

    from train_full_pipeline import run_pipeline

    with _pipeline_lock:
        if _pipeline_thread is not None and _pipeline_thread.is_alive():
            return "Pipeline is already running. Click Refresh to see progress."
        _pipeline_thread = threading.Thread(target=run_pipeline, args=(LOG_FILE,), daemon=True)
        _pipeline_thread.start()
    return "Pipeline started! Click Refresh to see progress."
|
|
|
|
| |
def preload():
    """Warm the model cache by calling load_models().

    Any ``Exception`` is printed rather than propagated, so a failed preload
    only costs the warm-up.
    """
    try:
        load_models()
        print("Generation models preloaded!")
    except Exception as err:
        print(f"Preload error: {err}")


# Start warming the models in a daemon thread as soon as the module is
# imported, so the UI below can be built without waiting for the load.
threading.Thread(
    target=preload,
    daemon=True,
).start()
|
|
|
|
| |
# UI definition: one tab for generation, one for the training pipeline.
with gr.Blocks(title="Zeeb — Video-LLM", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
# Zeeb — Video-LLM
**OLMo 2 1B** + **LoRA** + **VQ-VAE** → Text-to-Video generation.
[EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
""")

    with gr.Tabs():
        # --- Tab 1: text prompt -> sampled visual tokens -> video file ---
        with gr.Tab("Generate Video"):
            prompt_input = gr.Textbox(
                label="Video Description",
                value="A cat jumping on a sofa",
                lines=2,
            )
            with gr.Row():
                max_tok = gr.Slider(32, 256, value=64, step=32, label="Max Visual Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.9, step=0.1, label="Temperature")
                top_k = gr.Slider(1, 200, value=50, step=1, label="Top-K")
            gen_btn = gr.Button("Generate Video", variant="primary", size="lg")
            video_out = gr.Video(label="Generated Video")
            gen_log = gr.Textbox(
                label="Log",
                lines=15,
                interactive=False,
                show_copy_button=True,
            )
            # Inputs are passed positionally in generate_video's parameter order.
            gen_btn.click(
                fn=generate_video,
                inputs=[prompt_input, max_tok, temperature, top_k],
                outputs=[video_out, gen_log],
            )

        # --- Tab 2: launch the background training pipeline and view its log ---
        with gr.Tab("Full Training Pipeline"):
            gr.Markdown("""
### Train from scratch with real data
1. **Phase 1**: Train VQ-VAE on 10K real images (COCO/imagenette)
2. **Phase 2**: Tokenize 10K image-text pairs through trained VQ-VAE
3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on 5K tokenized samples
4. **Phase 4**: Push trained model to EeshaAI/zeeb

Checkpoints saved to persistent storage (survives Space restarts).
Training takes several hours on CPU.
""")
            pipe_btn = gr.Button("Start Full Pipeline", variant="primary", size="lg")
            ref_btn = gr.Button("Refresh Log")
            # The initial value is a callable that returns the current log tail.
            pipe_log = gr.Textbox(
                label="Pipeline Log",
                value=get_log,
                lines=30,
                interactive=False,
                show_copy_button=True,
            )
            # Both buttons write into the same textbox: the start message
            # first, then the log tail on each refresh.
            pipe_btn.click(fn=start_pipeline, outputs=pipe_log)
            ref_btn.click(fn=get_log, outputs=pipe_log)
|
|
|
|
# Script entry point: serve the UI on all interfaces, port 7860, with
# show_error enabled.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
|