# Zeeb / app.py
# eeshaAI's picture
# Fix: efficient log reading (last 5000 chars only)
# 2c311a6 verified
#!/usr/bin/env python3
"""
Gradio App for EeshaAI/Zeeb — Video Generation + Training Pipeline
===================================================================
Tab 1: Generate Video (uses trained model + VQ-VAE)
Tab 2: Run Full Pipeline (VQ-VAE training → dataset tokenization → LLM training → push)
"""
import os
import re
import threading
import numpy as np
import gradio as gr
# Path of the training pipeline's log file; the DATA_DIR environment variable
# overrides the default "/data" root.
LOG_FILE = os.path.join(os.environ.get("DATA_DIR", "/data"), "pipeline_log.txt")

# Process-wide cache for the LLM, its tokenizer and the VQ-VAE. All three start
# out empty; `_loading_lock` is the lock taken while they are being filled in.
_model = _tokenizer = _vq_vae = None
_loading_lock = threading.Lock()

# Tokenizer IDs of the video delimiters and of the first/last visual token.
# They remain None until the tokenizer has been loaded.
VIDEO_START_ID = VIDEO_END_ID = None
V_TOKEN_START_ID = V_TOKEN_END_ID = None
def _build_vqvae():
    """Build an untrained VQ-VAE with the same layout as the training code.

    The attribute names and the order of layers inside each ``nn.Sequential``
    determine the state-dict keys, and checkpoints are loaded with
    ``strict=True`` -- so none of them may be renamed or reordered.
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Encoder(nn.Module):
        """Four stride-2 convolutions mapping an image to a latent feature map."""

        def __init__(self, in_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(128, 256, 4, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(256, latent_dim, 4, stride=2, padding=1),
            )

        def forward(self, x):
            return self.net(x)

    class VectorQuantizer(nn.Module):
        """Nearest-neighbour codebook lookup with a straight-through estimator."""

        def __init__(self, codebook_size=1024, codebook_dim=256, commitment_cost=0.25):
            super().__init__()
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.commitment_cost = commitment_cost
            self.codebook = nn.Embedding(codebook_size, codebook_dim)
            self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

        def forward(self, z):
            # z is unpacked as channels-last: (B, H, W, C).
            B, H, W, C = z.shape
            z_flat = z.reshape(-1, C)
            # Squared Euclidean distance from every latent to every code.
            dist = (z_flat.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(-1)
            indices = dist.argmin(dim=1)
            z_q = self.codebook(indices).reshape(B, H, W, C)
            commitment_loss = F.mse_loss(z_flat, z_q.reshape(-1, C).detach())
            codebook_loss = F.mse_loss(z_q.reshape(-1, C), z_flat.detach())
            loss = codebook_loss + self.commitment_cost * commitment_loss
            # Straight-through: forward uses z_q, gradient flows to z.
            z_q_st = z + (z_q - z).detach()
            return z_q_st, loss, indices.reshape(B, H, W)

    class Decoder(nn.Module):
        """Four stride-2 transposed convolutions; Sigmoid keeps output in [0, 1]."""

        def __init__(self, out_channels=3, latent_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
                nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), nn.Sigmoid(),
            )

        def forward(self, x):
            return self.net(x)

    class VQVAE(nn.Module):
        """Encoder + quantizer + decoder; only ``decode_tokens`` is used by this app."""

        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.quantizer = VectorQuantizer()
            # proj_in is not used for decoding here, but it must exist so that
            # strict state-dict loading finds all expected keys.
            self.proj_in = nn.Linear(256, 256)
            self.proj_out = nn.Linear(256, 256)
            self.decoder = Decoder()

        def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
            """Decode one frame from codebook indices.

            ``token_ids`` may be a tensor or any sequence of ints. It is
            truncated or zero-padded to exactly ``grid_h * grid_w`` entries
            and the decoder output for a batch of one is returned.
            """
            if not isinstance(token_ids, torch.Tensor):
                token_ids = torch.tensor(list(token_ids), dtype=torch.long)
            wanted = grid_h * grid_w
            token_ids = token_ids[:wanted]
            if len(token_ids) < wanted:
                padding = torch.zeros(wanted - len(token_ids), dtype=torch.long)
                token_ids = torch.cat([token_ids, padding])
            z_q = self.quantizer.codebook(token_ids)
            z_q = self.proj_out(z_q)
            # (tokens, C) -> (1, H, W, C) -> channels-first for the conv decoder.
            z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
            return self.decoder(z_q)

    return VQVAE()


def _load_vqvae():
    """Return a VQ-VAE in eval mode, using the first checkpoint that loads.

    Candidates are tried in order; a checkpoint that fails to load is reported
    and skipped. If none loads, an untrained model is returned.
    """
    import torch

    persist_dir = os.path.join(os.environ.get("DATA_DIR", "/data"), "zeeb_checkpoints")
    candidates = [
        os.path.join(persist_dir, "vq_vae_best.pt"),
        os.path.join(persist_dir, "vq_vae_latest.pt"),
        "vq_vae_real.pt",
        "vq_vae_final.pt",
    ]
    for vq_path in candidates:
        if not os.path.exists(vq_path):
            continue
        try:
            vq_vae = _build_vqvae()
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects and
            # can execute code embedded in the file. Only acceptable because
            # these checkpoints are produced by this app's own pipeline.
            state_dict = torch.load(vq_path, map_location="cpu", weights_only=False)
            # Handle different save formats (bare state dict vs. training checkpoint).
            if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
                state_dict = state_dict["model_state_dict"]
            vq_vae.load_state_dict(state_dict, strict=True)
            vq_vae.eval()
        except Exception as e:
            print(f"Failed to load VQ-VAE from {vq_path}: {e}")
            continue
        print(f"VQ-VAE loaded from {vq_path}")
        return vq_vae

    vq_vae = _build_vqvae()
    vq_vae.eval()
    print("WARNING: Using untrained VQ-VAE (no checkpoint found)")
    return vq_vae


def _resolve_video_token_ids(tokenizer):
    """Look up the IDs of the video special tokens.

    Returns ``(video_start, video_end, first_visual, last_visual)``.

    Raises:
        ValueError: if a token is unknown to the tokenizer (the lookup yields
            ``None`` or the unknown-token id), because constrained decoding
            would otherwise silently mask the wrong vocabulary range.
    """
    unk_id = getattr(tokenizer, "unk_token_id", None)
    resolved = []
    for token in ("<video_start>", "<video_end>", "<v_0>", "<v_1023>"):
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id is None or (unk_id is not None and token_id == unk_id):
            raise ValueError(f"tokenizer has no '{token}' token")
        resolved.append(token_id)
    return tuple(resolved)


def load_models():
    """Load the trained LLM and VQ-VAE (lazy, cached).

    Returns:
        ``(model, tokenizer, vq_vae)``. ``model`` and ``tokenizer`` are both
        ``None`` when the LLM could not be loaded; the next call retries.
        ``vq_vae`` is always a usable module (untrained if no checkpoint
        could be loaded).

    Side effects:
        On success, fills the module-level cache and the ``VIDEO_*`` /
        ``V_TOKEN_*`` token-ID globals. The whole load runs under
        ``_loading_lock``.
    """
    global _model, _tokenizer, _vq_vae
    global VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID

    with _loading_lock:
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer, _vq_vae

        import torch

        _vq_vae = _load_vqvae()

        # LLM
        from transformers import AutoModelForCausalLM, AutoTokenizer

        REPO_ID = "eeshaAI/zeeb"
        print("Loading trained model from EeshaAI/zeeb...")
        try:
            # NOTE: trust_remote_code=True runs Python shipped in the hub repo.
            tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
            )
            model.eval()
            token_ids = _resolve_video_token_ids(tokenizer)
        except Exception as e:
            print(f"Failed to load model from hub: {e}")
            print("Will load on-demand when generating.")
            _model = None
            _tokenizer = None
        else:
            # Publish everything only once the whole load has succeeded, so the
            # cache never holds a model without valid token IDs.
            VIDEO_START_ID, VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID = token_ids
            expected_span = _vq_vae.quantizer.codebook_size - 1
            if V_TOKEN_END_ID - V_TOKEN_START_ID != expected_span:
                print(
                    "WARNING: visual token IDs are not a contiguous block of "
                    f"{expected_span + 1} entries "
                    f"({V_TOKEN_START_ID}..{V_TOKEN_END_ID})"
                )
            _model, _tokenizer = model, tokenizer
            print(f"Model loaded. Vocab: {len(_tokenizer)}")

        return _model, _tokenizer, _vq_vae
def _sample_visual_tokens(model, input_ids, max_tokens, temperature, top_k):
    """Sample codebook indices autoregressively, restricted to visual tokens.

    At every step only the visual-token range and ``<video_end>`` may be
    sampled. Stops after ``max_tokens`` steps or when ``<video_end>`` is drawn.
    Returns a list of codebook indices (token id minus ``V_TOKEN_START_ID``).
    """
    import torch
    import torch.nn.functional as F

    current_ids = input_ids.clone()
    allowed = None
    codes = []
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids=current_ids).logits[:, -1, :]
            if allowed is None:
                # Size the mask from the logits rather than len(tokenizer): a
                # model's output dimension can be larger than the tokenizer
                # (padded embedding matrix), which would make a tokenizer-sized
                # boolean mask fail to index the logits.
                allowed = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
                allowed[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True
                allowed[VIDEO_END_ID] = True
            # Mask to only visual tokens
            masked = logits.clone()
            masked[0, ~allowed] = float("-inf")
            # Temperature scaling (floored to avoid division by ~0)
            masked = masked / max(temperature, 0.01)
            # Top-k filtering
            if top_k > 0:
                top_k_values, _ = torch.topk(masked[0], min(top_k, masked.size(-1)))
                threshold = top_k_values[-1]
                masked[0, masked[0] < threshold] = float("-inf")
            probs = F.softmax(masked, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_id = next_token.item()
            if next_id == VIDEO_END_ID:
                break
            codes.append(next_id - V_TOKEN_START_ID)
            current_ids = torch.cat([current_ids, next_token], dim=-1)
    return codes


def _save_frames_as_video(frames, log):
    """Upscale frames to 256x256 and write them as MP4, or GIF if that fails.

    Returns the path of the written file. Errors from the GIF fallback (or
    from PIL itself) propagate to the caller.
    """
    from PIL import Image

    # Upscale to 256x256
    upscaled = [np.array(Image.fromarray(f).resize((256, 256), Image.BILINEAR)) for f in frames]
    try:
        import imageio
        out = "/tmp/generated_video.mp4"
        imageio.mimsave(out, upscaled, fps=2)
    except Exception as e:
        # Deliberate best-effort fallback, but say why instead of hiding it.
        log.append(f"MP4 export failed ({str(e)[:60]}); writing GIF instead\n")
        out = "/tmp/generated_video.gif"
        pils = [Image.fromarray(f) for f in upscaled]
        pils[0].save(out, save_all=True, append_images=pils[1:], duration=500, loop=0)
    return out


def generate_video(prompt: str, max_tokens: int = 64, temperature: float = 0.9, top_k: int = 50):
    """Generate video from a text prompt using constrained decoding + VQ-VAE.

    Args:
        prompt: Free-text description of the video.
        max_tokens: Maximum number of visual tokens to sample.
        temperature: Sampling temperature (values below 0.01 are floored).
        top_k: Keep only the k most likely tokens; ``<= 0`` disables filtering.

    Returns:
        ``(path, log_text)`` where ``path`` is the written video file, or
        ``None`` when loading, decoding or saving failed.
    """
    import random

    log = [f"Generating video for: '{prompt}'\n\n"]
    try:
        log.append("Loading models...\n")
        model, tokenizer, vq_vae = load_models()
        if model is None or tokenizer is None:
            return None, "Model not loaded yet. Please wait or try again."
        log.append("Models loaded.\n\n")
    except Exception as e:
        log.append(f"Load error: {e}\n")
        return None, "".join(log)

    if None in (VIDEO_END_ID, V_TOKEN_START_ID, V_TOKEN_END_ID):
        # Without these IDs the decoding mask cannot be built.
        log.append("Visual token IDs are not set; cannot run constrained decoding.\n")
        return None, "".join(log)

    # UI sliders may deliver floats; range() and torch.topk need ints.
    max_tokens = int(max_tokens)
    top_k = int(top_k)

    # Format prompt
    text = f"Create a video of: {prompt} <video_start>"
    log.append(f"Prompt: {text}\n\n")
    log.append("Generating visual tokens (constrained decoding)...\n")
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

    visual_token_ids = _sample_visual_tokens(
        model, inputs["input_ids"], max_tokens, temperature, top_k
    )
    log.append(f"Generated {len(visual_token_ids)} visual tokens\n")
    if not visual_token_ids:
        visual_token_ids = [random.randint(0, 1023) for _ in range(64)]
        log.append("Fallback: random tokens\n")
    log.append(f" Sample: {visual_token_ids[:20]}\n")
    log.append(f" Unique: {len(set(visual_token_ids))}\n\n")

    # Decode frames through VQ-VAE
    log.append("Decoding tokens -> frames...\n")
    grid_h, grid_w = 8, 8
    tokens_per_frame = grid_h * grid_w
    # Always at least one frame; a short token list is padded by decode_tokens.
    num_frames = max(1, len(visual_token_ids) // tokens_per_frame)
    frames = []
    for fi in range(num_frames):
        ft = visual_token_ids[fi * tokens_per_frame:(fi + 1) * tokens_per_frame]
        try:
            frame_tensor = vq_vae.decode_tokens(ft, grid_h, grid_w)
            # CHW float in [0, 1] -> HWC uint8.
            frame_np = (frame_tensor[0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
            frames.append(frame_np)
        except Exception as e:
            log.append(f" Frame decode error: {str(e)[:60]}\n")
            frames.append(_tokens_to_color(ft, grid_h, grid_w))
    if not frames:
        return None, "".join(log)

    # Save video
    try:
        out = _save_frames_as_video(frames, log)
    except Exception as e:
        log.append(f"Save error: {e}\n")
        return None, "".join(log)
    log.append(f"Video saved ({len(frames)} frames, 256x256)\nDone!\n")
    return out, "".join(log)
def _tokens_to_color(token_ids, grid_h=8, grid_w=8):
    """Render tokens as a 128x128 RGB grid of flat-coloured cells.

    Used when VQ-VAE decoding fails. Cells are filled row by row; each token
    maps to a colour through three fixed multipliers modulo 256. Cells without
    a token, and any remainder of the canvas, stay black.
    """
    canvas = np.zeros((128, 128, 3), dtype=np.uint8)
    cell_h = 128 // grid_h
    cell_w = 128 // grid_w
    visible = token_ids[:grid_h * grid_w]
    for position, token in enumerate(visible):
        row = position // grid_w
        col = position % grid_w
        rgb = [(token * factor) % 256 for factor in (37, 73, 113)]
        top, left = row * cell_h, col * cell_w
        canvas[top:top + cell_h, left:left + cell_w] = rgb
    return canvas
def get_log():
    """Return the tail of the pipeline log (at most its last 5000 bytes).

    The file is read in binary mode: seeking to an arbitrary offset in text
    mode can land in the middle of a multi-byte UTF-8 character, and the
    resulting decode error used to be reported as a missing log.

    Returns:
        The decoded tail of ``LOG_FILE`` (undecodable bytes are replaced), or
        ``"No pipeline log yet."`` when the file cannot be opened or read.
    """
    tail_bytes = 5000
    try:
        with open(LOG_FILE, "rb") as f:
            f.seek(0, os.SEEK_END)
            size = f.tell()
            f.seek(max(0, size - tail_bytes))
            tail = f.read()
    except OSError:
        return "No pipeline log yet."

    start = 0
    if size > tail_bytes:
        # We started mid-file: skip UTF-8 continuation bytes (0b10xxxxxx) that
        # belong to a character cut in half by the seek (at most 3 of them).
        while start < min(3, len(tail)) and (tail[start] & 0xC0) == 0x80:
            start += 1
    return tail[start:].decode("utf-8", errors="replace")
# Thread of the most recently started pipeline run (None before the first run),
# and the lock that makes "check if running, then start" a single step.
_pipeline_thread = None
_pipeline_start_lock = threading.Lock()


def start_pipeline():
    """Start the full training pipeline in background.

    The pipeline runs in a daemon thread and is given ``LOG_FILE`` as its only
    argument. If a previously started run is still alive, no second thread is
    launched -- repeated clicks would otherwise run several trainings at once
    against the same log file.

    Returns:
        A status message for the UI.
    """
    global _pipeline_thread
    from train_full_pipeline import run_pipeline

    with _pipeline_start_lock:
        if _pipeline_thread is not None and _pipeline_thread.is_alive():
            return "Pipeline is already running. Click Refresh to see progress."
        _pipeline_thread = threading.Thread(target=run_pipeline, args=(LOG_FILE,), daemon=True)
        _pipeline_thread.start()
    return "Pipeline started! Click Refresh to see progress."
# Warm the model cache at import time so the first request does not pay for it.
def preload():
    """Call ``load_models()`` once, printing the outcome instead of raising."""
    try:
        load_models()
        print("Generation models preloaded!")
    except Exception as exc:
        print(f"Preload error: {exc}")


# Daemon thread: the preload runs in the background while the rest of the
# module is set up.
threading.Thread(daemon=True, target=preload).start()
# Gradio UI: one tab for generation, one for the training pipeline.
with gr.Blocks(title="Zeeb — Video-LLM", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Zeeb — Video-LLM
**OLMo 2 1B** + **LoRA** + **VQ-VAE** → Text-to-Video generation.
[EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb)
""")

    with gr.Tabs():
        # --- Tab 1: prompt + sampling controls -> video and textual log ---
        with gr.Tab("Generate Video"):
            prompt_input = gr.Textbox(
                label="Video Description",
                value="A cat jumping on a sofa",
                lines=2,
            )
            with gr.Row():
                max_tok = gr.Slider(32, 256, value=64, step=32, label="Max Visual Tokens")
                temperature = gr.Slider(0.1, 2.0, value=0.9, step=0.1, label="Temperature")
                top_k = gr.Slider(1, 200, value=50, step=1, label="Top-K")
            gen_btn = gr.Button("Generate Video", variant="primary", size="lg")
            video_out = gr.Video(label="Generated Video")
            gen_log = gr.Textbox(
                label="Log",
                lines=15,
                interactive=False,
                show_copy_button=True,
            )
            gen_btn.click(
                fn=generate_video,
                inputs=[prompt_input, max_tok, temperature, top_k],
                outputs=[video_out, gen_log],
            )

        # --- Tab 2: start the training pipeline and inspect its log ---
        with gr.Tab("Full Training Pipeline"):
            gr.Markdown("""
### Train from scratch with real data
1. **Phase 1**: Train VQ-VAE on 10K real images (COCO/imagenette)
2. **Phase 2**: Tokenize 10K image-text pairs through trained VQ-VAE
3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on 5K tokenized samples
4. **Phase 4**: Push trained model to EeshaAI/zeeb
Checkpoints saved to persistent storage (survives Space restarts).
Training takes several hours on CPU.
""")
            pipe_btn = gr.Button("Start Full Pipeline", variant="primary", size="lg")
            ref_btn = gr.Button("Refresh Log")
            # A callable `value` is evaluated by Gradio rather than stored, so
            # the box is filled from the current log file.
            pipe_log = gr.Textbox(
                label="Pipeline Log",
                value=lambda: get_log(),
                lines=30,
                interactive=False,
                show_copy_button=True,
            )
            pipe_btn.click(fn=start_pipeline, outputs=pipe_log)
            ref_btn.click(fn=get_log, outputs=pipe_log)

# Serve on all interfaces, port 7860, when run as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )