#!/usr/bin/env python3
"""
Gradio App for EeshaAI/Zeeb — Video Generation + Training Pipeline
===================================================================
Tab 1: Generate Video (uses trained model + VQ-VAE)
Tab 2: Run Full Pipeline (VQ-VAE training → dataset tokenization → LLM training → push)
"""
import os
import re
import threading

import numpy as np
import gradio as gr

LOG_FILE = os.path.join(os.environ.get("DATA_DIR", "/data"), "pipeline_log.txt")

# Global model cache
_model = None
_tokenizer = None
_vq_vae = None
_loading_lock = threading.Lock()

# Visual token ID range
VIDEO_START_ID = None
VIDEO_END_ID = None
V_TOKEN_START_ID = None
V_TOKEN_END_ID = None

# Background training thread started by start_pipeline(); guarded by
# _pipeline_lock so a second click cannot launch a concurrent run.
_pipeline_thread = None
_pipeline_lock = threading.Lock()


def _build_vqvae_class():
    """Define the VQ-VAE modules and return the ``VQVAE`` class.

    torch is imported here (not at module import) so the app can start
    before torch is loaded. Attribute names (``encoder``, ``quantizer``,
    ``proj_in``, ``proj_out``, ``decoder``, ``net``, ``codebook``) determine
    the state-dict keys and must not be renamed, because checkpoints are
    loaded with ``strict=True``.
    """
    import torch
    import torch.nn as nn

    # Full VQ-VAE model (same architecture as training)
    class Encoder(nn.Module):
        def __init__(self, in_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 256, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(256, latent_dim, 4, stride=2, padding=1),
            )

        def forward(self, x):
            return self.net(x)

    class VectorQuantizer(nn.Module):
        def __init__(self, codebook_size=1024, codebook_dim=256, commitment_cost=0.25):
            super().__init__()
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.commitment_cost = commitment_cost
            self.codebook = nn.Embedding(codebook_size, codebook_dim)
            self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

        def forward(self, z):
            # z is unpacked as channels-last: (B, H, W, C).
            B, H, W, C = z.shape
            z_flat = z.reshape(-1, C)
            # Squared distance from every latent vector to every codebook entry.
            dist = (z_flat.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(-1)
            indices = dist.argmin(dim=1)
            z_q = self.codebook(indices).reshape(B, H, W, C)
            commitment_loss = torch.nn.functional.mse_loss(z_flat, z_q.reshape(-1, C).detach())
            codebook_loss = torch.nn.functional.mse_loss(z_q.reshape(-1, C), z_flat.detach())
            loss = codebook_loss + self.commitment_cost * commitment_loss
            # Straight-through estimator: forward uses z_q, gradient flows to z.
            z_q_st = z + (z_q - z).detach()
            return z_q_st, loss, indices.reshape(B, H, W)

    class Decoder(nn.Module):
        def __init__(self, out_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            return self.net(x)

    class VQVAE(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.quantizer = VectorQuantizer()
            self.proj_in = nn.Linear(256, 256)
            self.proj_out = nn.Linear(256, 256)
            self.decoder = Decoder()

        def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
            """Decode one frame from codebook indices.

            ``token_ids`` is truncated to ``grid_h * grid_w`` entries and
            right-padded with index 0 when shorter. Returns the decoder
            output for a single-item batch.
            """
            if isinstance(token_ids, list):
                token_ids = torch.tensor(token_ids, dtype=torch.long)
            token_ids = token_ids[:grid_h * grid_w]
            if len(token_ids) < grid_h * grid_w:
                padding = torch.zeros(grid_h * grid_w - len(token_ids), dtype=torch.long)
                token_ids = torch.cat([token_ids, padding])
            z_q = self.quantizer.codebook(token_ids)
            z_q = self.proj_out(z_q)
            # (tokens, C) -> (1, H, W, C) -> channels-first for the conv decoder.
            z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
            return self.decoder(z_q)

    return VQVAE


def _load_vq_vae():
    """Return a VQ-VAE in eval mode, loaded from the first usable checkpoint.

    Falls back to an untrained instance (with a printed warning) when no
    candidate file exists or every candidate fails to load.
    """
    import torch

    VQVAE = _build_vqvae_class()

    # Try loading from multiple locations
    persist_dir = os.path.join(os.environ.get("DATA_DIR", "/data"), "zeeb_checkpoints")
    vq_paths = [
        os.path.join(persist_dir, "vq_vae_best.pt"),
        os.path.join(persist_dir, "vq_vae_latest.pt"),
        "vq_vae_real.pt",
        "vq_vae_final.pt",
    ]

    for vq_path in vq_paths:
        if not os.path.exists(vq_path):
            continue
        try:
            vq_vae = VQVAE()
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects.
            # Only acceptable while these checkpoint files are produced by
            # this project's own pipeline; do not point this at untrusted files.
            state_dict = torch.load(vq_path, map_location="cpu", weights_only=False)
            # Handle different save formats
            if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
                state_dict = state_dict["model_state_dict"]
            vq_vae.load_state_dict(state_dict, strict=True)
            vq_vae.eval()
            print(f"VQ-VAE loaded from {vq_path}")
            return vq_vae
        except Exception as e:
            # Best effort: a corrupt/mismatched file should not block the others.
            print(f"Failed to load VQ-VAE from {vq_path}: {e}")

    vq_vae = VQVAE()
    vq_vae.eval()
    print("WARNING: Using untrained VQ-VAE (no checkpoint found)")
    return vq_vae


def load_models():
    """Load the trained LLM and VQ-VAE (lazy, cached).

    Returns:
        ``(model, tokenizer, vq_vae)``. ``model`` and ``tokenizer`` are
        ``None`` when loading from the hub failed; the next call retries.

    Side effects:
        Populates the module-level cache and the ``VIDEO_*`` / ``V_TOKEN_*``
        token-id globals. The whole load runs under ``_loading_lock`` so
        concurrent callers do not load twice.
    """
    global _model, _tokenizer, _vq_vae
    global VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID

    with _loading_lock:
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer, _vq_vae

        import torch

        _vq_vae = _load_vq_vae()

        # LLM
        from transformers import AutoModelForCausalLM, AutoTokenizer

        REPO_ID = "eeshaAI/zeeb"
        print("Loading trained model from EeshaAI/zeeb...")
        try:
            _tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
            if _tokenizer.pad_token is None:
                _tokenizer.pad_token = _tokenizer.eos_token
            _model = AutoModelForCausalLM.from_pretrained(
                REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
            )
            _model.eval()

            # NOTE(review): all four lookups pass an empty string, so they
            # cannot resolve to four distinct special tokens. The token names
            # look like they were lost (angle-bracket tokens stripped?) —
            # restore them from the training script; they are not recoverable
            # from this file.
            VIDEO_START_ID = _tokenizer.convert_tokens_to_ids("")
            VIDEO_END_ID = _tokenizer.convert_tokens_to_ids("")
            V_TOKEN_START_ID = _tokenizer.convert_tokens_to_ids("")
            V_TOKEN_END_ID = _tokenizer.convert_tokens_to_ids("")

            special_ids = (VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID)
            if None in special_ids or len(set(special_ids)) < len(special_ids):
                # Surface the misconfiguration instead of silently producing
                # random-token videos later.
                print(
                    "WARNING: special video token IDs are missing or not distinct "
                    f"({special_ids}); constrained decoding will not work."
                )
            print(f"Model loaded. Vocab: {len(_tokenizer)}")
        except Exception as e:
            print(f"Failed to load model from hub: {e}")
            print("Will load on-demand when generating.")
            _model = None
            _tokenizer = None

        return _model, _tokenizer, _vq_vae


def generate_video(prompt: str, max_tokens: int = 64, temperature: float = 0.9, top_k: int = 50):
    """Generate video from a text prompt using constrained decoding + VQ-VAE.

    Args:
        prompt: Text description of the video.
        max_tokens: Upper bound on sampled visual tokens.
        temperature: Sampling temperature (floored at 0.01).
        top_k: Keep only the ``top_k`` highest logits; ``<= 0`` disables it.

    Returns:
        ``(path, log_text)`` where ``path`` is the saved video/GIF file, or
        ``None`` when generation could not complete.
    """
    import torch
    import torch.nn.functional as F

    # UI sliders may hand these over as floats; range()/topk() need ints.
    max_tokens = int(max_tokens)
    top_k = int(top_k)

    log = [f"Generating video for: '{prompt}'\n\n"]

    try:
        log.append("Loading models...\n")
        model, tokenizer, vq_vae = load_models()
        if model is None or tokenizer is None:
            return None, "Model not loaded yet. Please wait or try again."
        log.append("Models loaded.\n\n")
    except Exception as e:
        log.append(f"Load error: {e}\n")
        return None, "".join(log)

    # Without resolved ids the mask below cannot be built; report it instead
    # of failing with a TypeError on ``None + 1``.
    if None in (VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID):
        log.append("Special video token IDs could not be resolved from the tokenizer.\n")
        return None, "".join(log)

    # Format prompt
    # NOTE(review): the trailing space suggests a special token was stripped
    # from the end of this template — confirm against the training prompt format.
    text = f"Create a video of: {prompt} "
    log.append(f"Prompt: {text}\n\n")
    log.append("Generating visual tokens (constrained decoding)...\n")

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    current_ids = inputs["input_ids"].clone()

    # Constrained decoding: only allow visual tokens + video_end.
    # The mask is built lazily from the logits width, which may be larger
    # than len(tokenizer) when the model's embedding matrix is padded.
    visual_mask = None
    visual_token_ids = []

    with torch.no_grad():
        for _ in range(max_tokens):
            outputs = model(input_ids=current_ids)
            logits = outputs.logits[:, -1, :]

            if visual_mask is None:
                visual_mask = torch.zeros(
                    logits.size(-1), dtype=torch.bool, device=logits.device
                )
                visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
                visual_mask[VIDEO_END_ID] = True

            # Mask to only visual tokens
            masked = logits.clone()
            masked[0, ~visual_mask] = float('-inf')

            # Temperature scaling
            masked = masked / max(temperature, 0.01)

            # Top-k filtering
            if top_k > 0:
                top_k_values, _indices = torch.topk(masked[0], min(top_k, masked.size(-1)))
                threshold = top_k_values[-1]
                masked[0, masked[0] < threshold] = float('-inf')

            probs = F.softmax(masked, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_id = next_token.item()

            if next_id == VIDEO_END_ID:
                break

            # Store the codebook index, not the vocabulary id.
            visual_token_ids.append(next_id - V_TOKEN_START_ID)
            current_ids = torch.cat([current_ids, next_token], dim=-1)

    log.append(f"Generated {len(visual_token_ids)} visual tokens\n")

    if not visual_token_ids:
        import random
        visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
        log.append("Fallback: random tokens\n")

    log.append(f" Sample: {visual_token_ids[:20]}\n")
    log.append(f" Unique: {len(set(visual_token_ids))}\n\n")

    # Decode frames through VQ-VAE
    log.append("Decoding tokens -> frames...\n")
    grid_h, grid_w = 8, 8
    tokens_per_frame = grid_h * grid_w
    # Always emit at least one frame; decode_tokens pads a short frame.
    num_frames = max(1, len(visual_token_ids) // tokens_per_frame)

    frames = []
    for fi in range(num_frames):
        ft = visual_token_ids[fi * tokens_per_frame:(fi + 1) * tokens_per_frame]
        try:
            with torch.no_grad():
                frame_tensor = vq_vae.decode_tokens(ft, grid_h, grid_w)
            frame_np = (frame_tensor[0].permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
            # Output is 128x128 from the fixed decoder
            frames.append(frame_np)
        except Exception as e:
            log.append(f" Frame decode error: {str(e)[:60]}\n")
            frames.append(_tokens_to_color(ft, grid_h, grid_w))

    if not frames:
        return None, "".join(log)

    # Save video
    try:
        from PIL import Image

        # Upscale to 256x256
        upscaled = [
            np.array(Image.fromarray(f).resize((256, 256), Image.BILINEAR))
            for f in frames
        ]
        try:
            import imageio
            out = "/tmp/generated_video.mp4"
            imageio.mimsave(out, upscaled, fps=2)
        except Exception as e:
            # imageio (or its video backend) may be missing; fall back to GIF
            # but keep the reason visible in the log.
            log.append(f"MP4 export failed ({str(e)[:60]}); falling back to GIF\n")
            out = "/tmp/generated_video.gif"
            pils = [Image.fromarray(f) for f in upscaled]
            pils[0].save(out, save_all=True, append_images=pils[1:], duration=500, loop=0)

        log.append(f"Video saved ({len(upscaled)} frames, 256x256)\nDone!\n")
        return out, "".join(log)
    except Exception as e:
        log.append(f"Save error: {e}\n")
        return None, "".join(log)


def _tokens_to_color(token_ids, grid_h=8, grid_w=8):
    """Fallback: convert tokens to colored grid.

    Returns a 128x128x3 uint8 frame in which each of the first
    ``grid_h * grid_w`` tokens fills one cell with a colour derived from
    the token value.
    """
    frame = np.zeros((128, 128, 3), dtype=np.uint8)
    ch, cw = 128 // grid_h, 128 // grid_w
    for i, t in enumerate(token_ids[:grid_h * grid_w]):
        r, c = divmod(i, grid_w)
        frame[r * ch:(r + 1) * ch, c * cw:(c + 1) * cw] = [
            (t * 37) % 256,
            (t * 73) % 256,
            (t * 113) % 256,
        ]
    return frame


def get_log():
    """Return the tail (last ~5000 bytes) of the pipeline log.

    The file is read in binary mode: seeking to an arbitrary offset is not
    supported for text-mode files and can land inside a multi-byte
    character. Bytes that do not decode as UTF-8 are replaced rather than
    raising, so an existing log is never misreported as missing.
    """
    try:
        with open(LOG_FILE, "rb") as f:
            # Only read the last 5000 bytes for efficiency
            f.seek(0, os.SEEK_END)
            size = f.tell()
            f.seek(max(0, size - 5000))
            tail = f.read()
    except OSError:
        return "No pipeline log yet."
    return tail.decode("utf-8", errors="replace")


def start_pipeline():
    """Start the full training pipeline in background.

    Returns a status string for the UI. If a previously started pipeline
    thread is still alive, no second run is launched.
    """
    global _pipeline_thread

    from train_full_pipeline import run_pipeline

    with _pipeline_lock:
        if _pipeline_thread is not None and _pipeline_thread.is_alive():
            return "Pipeline is already running. Click Refresh to see progress."
        _pipeline_thread = threading.Thread(
            target=run_pipeline, args=(LOG_FILE,), daemon=True
        )
        _pipeline_thread.start()
    return "Pipeline started! Click Refresh to see progress."
# Preload generation models
def preload():
    """Warm the model cache in the background; report the outcome on stdout."""
    try:
        load_models()
        print("Generation models preloaded!")
    except Exception as exc:
        print(f"Preload error: {exc}")


# Daemon thread so a slow or failed preload never blocks UI start-up or exit.
threading.Thread(target=preload, daemon=True).start()


# Gradio UI
with gr.Blocks(title="Zeeb — Video-LLM", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Zeeb — Video-LLM
        **OLMo 2 1B** + **LoRA** + **VQ-VAE** → Text-to-Video generation.
        [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
        """
    )

    with gr.Tabs():
        # Tab 1: text prompt -> generated clip plus a textual generation log.
        with gr.Tab("Generate Video"):
            prompt_input = gr.Textbox(
                label="Video Description",
                value="A cat jumping on a sofa",
                lines=2,
            )
            with gr.Row():
                max_tok = gr.Slider(32, 256, value=64, step=32, label="Max Visual Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.9, step=0.1, label="Temperature")
                top_k = gr.Slider(1, 200, value=50, step=1, label="Top-K")
            gen_btn = gr.Button("Generate Video", variant="primary", size="lg")
            video_out = gr.Video(label="Generated Video")
            gen_log = gr.Textbox(
                label="Log",
                lines=15,
                interactive=False,
                show_copy_button=True,
            )
            gen_btn.click(
                fn=generate_video,
                inputs=[prompt_input, max_tok, temperature, top_k],
                outputs=[video_out, gen_log],
            )

        # Tab 2: launch the background training pipeline and poll its log.
        with gr.Tab("Full Training Pipeline"):
            gr.Markdown(
                """
                ### Train from scratch with real data

                1. **Phase 1**: Train VQ-VAE on 10K real images (COCO/imagenette)
                2. **Phase 2**: Tokenize 10K image-text pairs through trained VQ-VAE
                3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on 5K tokenized samples
                4. **Phase 4**: Push trained model to EeshaAI/zeeb

                Checkpoints saved to persistent storage (survives Space restarts).
                Training takes several hours on CPU.
                """
            )
            pipe_btn = gr.Button("Start Full Pipeline", variant="primary", size="lg")
            ref_btn = gr.Button("Refresh Log")
            # A callable value is evaluated by Gradio when the page loads, so
            # the box starts with the current log tail.
            pipe_log = gr.Textbox(
                label="Pipeline Log",
                value=get_log,
                lines=30,
                interactive=False,
                show_copy_button=True,
            )
            pipe_btn.click(fn=start_pipeline, outputs=pipe_log)
            ref_btn.click(fn=get_log, outputs=pipe_log)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)