#!/usr/bin/env python # -*- coding: utf-8 -*- """ Cosmos-T3 — single-file Gradio app for inference/chat. """ from __future__ import annotations import os import sys import queue import threading from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F import gradio as gr from transformers import AutoTokenizer # ───────────────────────────────────────────────────────────── # CONFIG # ───────────────────────────────────────────────────────────── TOKENIZER_NAME = "Qwen/Qwen2.5-0.5B" BLOCK_SIZE = 1024 MAX_LEN = 1024 D_MODEL = 768 N_LAYERS = 12 N_HEADS = 12 N_KV_HEADS = 4 D_FF = 2048 ROPE_BASE = 10000 DROP_OUT = 0.0 USE_ENGRAM = True ENGRAM_EVERY = 4 ENGRAM_BUCKETS = 8192 ENGRAM_DIM = 64 ENGRAM_ORDER = 3 DEFAULT_SYSTEM_PROMPT = "Enable thinking features: INTUITION" STAGE_CKPT = { "pretrain": "Cosmos-T3-Pretrain.resume.pt", "finetune": "Cosmos-T3-Instruct.resume.pt", } STAGE_BUCKET = { "pretrain": "pretrain/checkpoints/Cosmos-T3-Pretrain.resume.pt", "finetune": "finetune/checkpoints/Cosmos-T3-Instruct.resume.pt", } HF_BUCKET_ID = "wop/Cosmos-SFT" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") PAD_ID = 0 STOP_TOKEN_IDS: set[int] = set() MODEL_LOCK = threading.Lock() def resolve_checkpoint(stage="finetune", work_dir="cosmos_t3_run", no_bucket=False): local = Path(work_dir) / STAGE_CKPT[stage] if local.exists(): return local if no_bucket: raise FileNotFoundError(f"Missing checkpoint: {local}") token = os.environ.get("HF_TOKEN", "empty") os.environ["HF_TOKEN"] = token from huggingface_hub import download_bucket_files remote = STAGE_BUCKET[stage] local.parent.mkdir(parents=True, exist_ok=True) print(f"Downloading from bucket: {HF_BUCKET_ID}/{remote}") download_bucket_files(HF_BUCKET_ID, files=[(remote, str(local))]) if not local.exists(): raise RuntimeError("Bucket download failed") return local # ───────────────────────────────────────────────────────────── # MODEL CORE # ───────────────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x): rms = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(rms + self.eps) return x * self.weight def rotate_half(x): x1 = x[..., ::2] x2 = x[..., 1::2] return torch.stack((-x2, x1), dim=-1).flatten(-2) def apply_rope(q, k, cos, sin): q = (q * cos) + (rotate_half(q) * sin) k = (k * cos) + (rotate_half(k) * sin) return q, k class GQAAttention(nn.Module): def __init__(self, d_model, n_heads, n_kv_heads, rope_base=10000, dropout=0.0): super().__init__() assert d_model % n_heads == 0 assert n_heads % n_kv_heads == 0 self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = d_model // n_heads self.rope_base = rope_base self.dropout = dropout self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(d_model, d_model, bias=False) def forward(self, x, rope_cos, rope_sin, past_kv=None, use_cache=False): bsz, seq_len, _ = x.shape q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) q, k = apply_rope(q, k, rope_cos, rope_sin) if past_kv is not None: past_k, past_v = past_kv k = torch.cat([past_k, k], dim=2) v = torch.cat([past_v, v], dim=2) present_kv = (k, v) if use_cache else None if self.n_kv_heads != self.n_heads: repeat = self.n_heads // self.n_kv_heads k = k.repeat_interleave(repeat, dim=1) v = v.repeat_interleave(repeat, dim=1) attn_out = F.scaled_dot_product_attention( q, k, v, is_causal=(past_kv is None), dropout_p=self.dropout if self.training else 0.0, ) attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, -1) attn_out = self.o_proj(attn_out) return (attn_out, present_kv) if use_cache else attn_out class SwiGLUMLP(nn.Module): def __init__(self, d_model, hidden_dim, dropout=0.0): super().__init__() self.gate = nn.Linear(d_model, hidden_dim, bias=False) self.up = nn.Linear(d_model, hidden_dim, bias=False) self.down = nn.Linear(hidden_dim, d_model, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x): x = F.silu(self.gate(x)) * self.up(x) return self.down(self.dropout(x)) class EngramMemory(nn.Module): def __init__(self, d_model, bucket_count, memory_dim, order, pad_id=0, dropout=0.0): super().__init__() self.bucket_count = bucket_count self.memory_dim = memory_dim self.order = order self.pad_id = pad_id self.bucket = nn.Embedding(bucket_count, memory_dim) self.query = nn.Linear(d_model, memory_dim, bias=False) self.project = nn.Linear(memory_dim, d_model, bias=False) self.gate = nn.Linear(d_model, d_model, bias=True) self.dropout = nn.Dropout(dropout) primes = [1, 1315423911, 2654435761, 97531, 433494437] self.register_buffer("primes", torch.tensor(primes[:order], dtype=torch.long), persistent=False) def hash_tokens(self, idx): batch, seq_len = idx.shape pad = torch.full((batch, self.order - 1), self.pad_id, device=idx.device, dtype=idx.dtype) history = torch.cat([pad, idx], dim=1) hashed = torch.zeros((batch, seq_len), device=idx.device, dtype=torch.long) for offset in range(self.order): slice_ = history[:, offset: offset + seq_len].long() hashed = (hashed * 1315423911 + slice_ * self.primes[offset]) % self.bucket_count return hashed def forward(self, x, idx): hashed = self.hash_tokens(idx) if hashed.size(1) != x.size(1): hashed = hashed[:, -x.size(1):] query = torch.tanh(self.query(x)) memory = self.bucket(hashed) * query memory = self.project(memory) gate = torch.sigmoid(self.gate(x)) return self.dropout(gate * memory) class Block(nn.Module): def __init__( self, d_model, n_heads, n_kv_heads, d_ff, rope_base, dropout=0.0, use_engram=False, engram_bucket_count=4096, engram_dim=96, engram_order=3, pad_id=0, ): super().__init__() self.norm1 = RMSNorm(d_model) self.attn = GQAAttention(d_model, n_heads, n_kv_heads, rope_base=rope_base, dropout=dropout) self.norm2 = RMSNorm(d_model) self.engram = ( EngramMemory(d_model, engram_bucket_count, engram_dim, engram_order, pad_id=pad_id, dropout=dropout) if use_engram else None ) self.norm3 = RMSNorm(d_model) self.mlp = SwiGLUMLP(d_model, d_ff, dropout=dropout) def forward(self, x, idx, rope_cos, rope_sin): x = x + self.attn(self.norm1(x), rope_cos, rope_sin) if self.engram is not None: x = x + self.engram(self.norm2(x), idx) x = x + self.mlp(self.norm3(x)) return x def forward_cached(self, x, idx_context, rope_cos, rope_sin, past_kv=None): attn_out, present_kv = self.attn( self.norm1(x), rope_cos, rope_sin, past_kv=past_kv, use_cache=True, ) x = x + attn_out if self.engram is not None: x = x + self.engram(self.norm2(x), idx_context) x = x + self.mlp(self.norm3(x)) return x, present_kv class CosmosT2_Accelerate_LLM(nn.Module): def __init__( self, vocab_size, d_model=D_MODEL, n_layers=N_LAYERS, n_heads=N_HEADS, n_kv_heads=N_KV_HEADS, d_ff=D_FF, max_len=MAX_LEN, rope_base=ROPE_BASE, dropout=DROP_OUT, use_engram=USE_ENGRAM, engram_every=ENGRAM_EVERY, engram_bucket_count=ENGRAM_BUCKETS, engram_dim=ENGRAM_DIM, engram_order=ENGRAM_ORDER, pad_id=0, ): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.n_layers = n_layers self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.rope_theta = float(rope_base) self.head_dim = d_model // n_heads self.max_len = max_len self.rope_base = rope_base self.pad_id = pad_id self.tok_emb = nn.Embedding(vocab_size, d_model) self.blocks = nn.ModuleList() for layer_index in range(n_layers): block_uses_engram = use_engram and ((layer_index + 1) % engram_every == 0) self.blocks.append( Block( d_model=d_model, n_heads=n_heads, n_kv_heads=n_kv_heads, d_ff=d_ff, rope_base=rope_base, dropout=dropout, use_engram=block_uses_engram, engram_bucket_count=engram_bucket_count, engram_dim=engram_dim, engram_order=engram_order, pad_id=pad_id, ) ) self.norm_f = RMSNorm(d_model) def build_rope(self, seq_len, device, dtype, start_pos=0): inv_freq = 1.0 / ( self.rope_theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim) ) positions = torch.arange(start_pos, start_pos + seq_len, device=device).float() freqs = torch.outer(positions, inv_freq) cos = freqs.cos().repeat_interleave(2, dim=-1).to(dtype)[None, None, :, :] sin = freqs.sin().repeat_interleave(2, dim=-1).to(dtype)[None, None, :, :] return cos, sin def forward(self, idx, targets=None): if idx.size(1) > self.max_len: idx = idx[:, -self.max_len:] seq_len = idx.size(1) rope_cos, rope_sin = self.build_rope(seq_len, idx.device, self.tok_emb.weight.dtype) x = self.tok_emb(idx) for block in self.blocks: x = block(x, idx, rope_cos, rope_sin) x = self.norm_f(x) logits = F.linear(x, self.tok_emb.weight) loss = None if targets is not None: loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) return logits, loss def trim_kv_cache(self, past_kv, max_tokens): if past_kv is None: return None max_tokens = max(0, int(max_tokens)) trimmed = [] for k, v in past_kv: if max_tokens == 0: k = k[:, :, :0, :].contiguous() v = v[:, :, :0, :].contiguous() elif k.size(2) > max_tokens: k = k[:, :, -max_tokens:, :].contiguous() v = v[:, :, -max_tokens:, :].contiguous() trimmed.append((k, v)) return trimmed @torch.no_grad() def forward_cached(self, idx, past_kv=None, cache_pos=0, max_ctx=None, idx_context=None): self.eval() max_ctx = self.max_len if max_ctx is None else int(max_ctx) if past_kv is None: idx = idx[:, -max_ctx:] idx_context = idx cache_pos = 0 else: keep_past = max(0, max_ctx - idx.size(1)) past_kv = self.trim_kv_cache(past_kv, keep_past) idx_context = idx if idx_context is None else idx_context[:, -max_ctx:] seq_len = idx.size(1) rope_cos, rope_sin = self.build_rope( seq_len, idx.device, self.tok_emb.weight.dtype, start_pos=cache_pos, ) x = self.tok_emb(idx) present_kv = [] for layer_index, block in enumerate(self.blocks): layer_past = None if past_kv is None else past_kv[layer_index] x, layer_present = block.forward_cached( x, idx_context, rope_cos, rope_sin, past_kv=layer_past, ) present_kv.append(layer_present) x = self.norm_f(x) logits = F.linear(x, self.tok_emb.weight) return logits, present_kv, cache_pos + seq_len def sample_next(self, logits, temperature=0.8, top_k=50): if logits.dim() == 3: logits = logits[:, -1, :] if temperature <= 1e-6: return torch.argmax(logits, dim=-1, keepdim=True) logits = logits / temperature if top_k and top_k > 0: values, _ = torch.topk(logits, min(top_k, logits.size(-1))) cutoff = values[:, [-1]] logits = logits.masked_fill(logits < cutoff, float("-inf")) probs = F.softmax(logits, dim=-1) return torch.multinomial(probs, num_samples=1) @torch.no_grad() def prefill_cache(self, idx, max_ctx=None): logits, past_kv, cache_pos = self.forward_cached(idx, past_kv=None, cache_pos=0, max_ctx=max_ctx) return logits[:, -1, :], past_kv, cache_pos @torch.no_grad() def decode_cached(self, idx, past_kv, cache_pos, idx_context, max_ctx=None): logits, past_kv, cache_pos = self.forward_cached( idx, past_kv=past_kv, cache_pos=cache_pos, max_ctx=max_ctx, idx_context=idx_context, ) return logits[:, -1, :], past_kv, cache_pos @torch.no_grad() def generate( self, idx, max_new_tokens=256, temperature=0.9, top_k=45, max_ctx=None, stop_ids=None, on_token=None, ): self.eval() max_ctx = self.max_len if max_ctx is None else int(max_ctx) idx = idx[:, -max_ctx:] logits, past_kv, cache_pos = self.prefill_cache(idx, max_ctx=max_ctx) stop_ids = STOP_TOKEN_IDS if stop_ids is None else stop_ids for step in range(max_new_tokens): nxt = self.sample_next(logits, temperature=temperature, top_k=top_k) if stop_ids and nxt.numel() == 1 and int(nxt.item()) in stop_ids: break if on_token is not None: on_token(int(nxt.item())) idx = torch.cat([idx, nxt], dim=1) if step + 1 < max_new_tokens: logits, past_kv, cache_pos = self.decode_cached( nxt, past_kv, cache_pos, idx[:, -max_ctx:], max_ctx=max_ctx, ) return idx # ───────────────────────────────────────────────────────────── # HELPERS # ───────────────────────────────────────────────────────────── def _resolve_stop_ids(tok): ids = set() for t in ("<|im_end|>", "<|endoftext|>"): i = tok.convert_tokens_to_ids(t) if isinstance(i, int) and i >= 0 and i != tok.unk_token_id: ids.add(i) if tok.eos_token_id is not None: ids.add(tok.eos_token_id) return ids def _looks_like_state_dict(d): if not isinstance(d, dict) or not d: return False tensor_vals = [v for v in d.values() if torch.is_tensor(v)] if len(tensor_vals) < max(4, 0.5 * len(d)): return False return any("." in str(k) for k in d.keys()) def _extract_state_dict(blob): if _looks_like_state_dict(blob): return blob if isinstance(blob, dict): for key in ("model_state_dict", "model", "model_state", "state_dict", "weights", "net", "module", "ema", "ema_model"): inner = blob.get(key) if _looks_like_state_dict(inner): return inner if isinstance(inner, dict): for k2, v2 in inner.items(): if _looks_like_state_dict(v2): return v2 for v in blob.values(): if _looks_like_state_dict(v): return v if isinstance(v, dict): for v2 in v.values(): if _looks_like_state_dict(v2): return v2 raise ValueError( "Could not find a model state_dict in the checkpoint. " f"Top-level keys were: {list(blob.keys())}" ) raise ValueError(f"Unexpected checkpoint type: {type(blob)}") def load_model(ckpt_path, tokenizer): blob = torch.load(ckpt_path, map_location="cpu", weights_only=False) cfg = {} if isinstance(blob, dict): for key in ("model_config", "config"): if isinstance(blob.get(key), dict): cfg = blob[key] break model = CosmosT2_Accelerate_LLM( vocab_size=cfg.get("vocab_size", len(tokenizer)), d_model=cfg.get("d_model", D_MODEL), n_layers=cfg.get("n_layers", N_LAYERS), n_heads=cfg.get("n_heads", N_HEADS), n_kv_heads=cfg.get("n_kv_heads", N_KV_HEADS), d_ff=cfg.get("d_ff", D_FF), max_len=cfg.get("max_len", MAX_LEN), rope_base=cfg.get("rope_base", ROPE_BASE), dropout=0.0, use_engram=cfg.get("use_engram", USE_ENGRAM), engram_every=cfg.get("engram_every", ENGRAM_EVERY), engram_bucket_count=cfg.get("engram_buckets", ENGRAM_BUCKETS), engram_dim=cfg.get("engram_dim", ENGRAM_DIM), engram_order=cfg.get("engram_order", ENGRAM_ORDER), pad_id=tokenizer.pad_token_id or 0, ) state = _extract_state_dict(blob) missing, unexpected = model.load_state_dict(state, strict=False) if missing: print(f"[warn] missing keys: {len(missing)} (e.g. {missing[:3]})") if unexpected: print(f"[warn] unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})") model.eval() return model def build_prompt_ids(tokenizer, user_text, stage, system_prompt, history=None): if stage == "pretrain": ids = tokenizer(user_text, add_special_tokens=False, return_attention_mask=False)["input_ids"] return ids messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) for role, content in (history or []): messages.append({"role": role, "content": content}) messages.append({"role": "user", "content": user_text}) text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"] # ───────────────────────────────────────────────────────────── # LOAD ON STARTUP # ───────────────────────────────────────────────────────────── print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token PAD_ID = tokenizer.pad_token_id STOP_TOKEN_IDS = _resolve_stop_ids(tokenizer) # ───────────────────────────── # FIX: resolve FIRST # ───────────────────────────── CKPT_PATH = resolve_checkpoint(stage="finetune") print(f"Loading model checkpoint: {CKPT_PATH}") model = load_model(CKPT_PATH, tokenizer) model.to(device) model.eval() n_params = sum(p.numel() for p in model.parameters()) print(f"Model ready: {n_params/1e6:.1f}M params | device={device}") # ───────────────────────────────────────────────────────────── # GRADIO STREAMING # ───────────────────────────────────────────────────────────── def history_to_role_messages(history): messages = [] for user_msg, assistant_msg in history or []: messages.append(("user", user_msg)) messages.append(("assistant", assistant_msg)) return messages def chat_stream(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT): role_history = history_to_role_messages(history) prompt_ids = build_prompt_ids( tokenizer=tokenizer, user_text=message, stage="finetune", system_prompt=system_prompt, history=role_history, ) idx = torch.tensor([prompt_ids], dtype=torch.long, device=device) q: queue.Queue[str | object] = queue.Queue() END = object() def worker(): try: def on_token(tid: int): txt = tokenizer.decode([tid], skip_special_tokens=True) q.put(txt) with MODEL_LOCK: with torch.inference_mode(): model.generate( idx, max_new_tokens=256, temperature=0.9, top_k=45, max_ctx=MAX_LEN, on_token=on_token, ) finally: q.put(END) threading.Thread(target=worker, daemon=True).start() output = "" while True: item = q.get() if item is END: break output += item yield output def chat(message, history): yield from chat_stream(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT) # ───────────────────────────────────────────────────────────── # UI # ───────────────────────────────────────────────────────────── demo = gr.ChatInterface( fn=chat, title="Cosmos-T3 API", description="Streaming inference API (backend for your frontend)", ) if __name__ == "__main__": demo.queue().launch( server_name="0.0.0.0", server_port=7860, )