# AGILLM-3-chat / app.py
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download, HfApi
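# Multi-head self-attention with a learned low-rank projection: per-head queries
# and keys are mapped through a shared dk x r matrix U, so attention scores are
# computed in an r-dimensional subspace while values keep the full head dimension.
# Supports an optional (k, v) cache for incremental decoding.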
class TuneableAttentionMHA(nn.Module):
def __init__(self, d: int, h: int, r: int):
super().__init__()
self.h, self.dk, self.r = h, d // h, r
self.q = nn.Linear(d, d, bias=False)
self.k = nn.Linear(d, d, bias=False)
self.v = nn.Linear(d, d, bias=False)
self.U = nn.Parameter(torch.randn(self.dk, r))
self.proj = nn.Linear(h * self.dk, d, bias=False)
def forward(self, x, mask=None, kv_cache=None, use_cache=False):
B, N, _ = x.shape
q = (self.q(x).view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
k_new = (self.k(x).view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
v_new = self.v(x).view(B, N, self.h, self.dk).transpose(1, 2)
if kv_cache is not None:
k_cached, v_cached = kv_cache
k = torch.cat([k_cached, k_new], dim=2)
v = torch.cat([v_cached, v_new], dim=2)
else:
k, v = k_new, v_new
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
if mask is not None:
att = att + mask
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
if use_cache:
return self.proj(z), (k, v)
return self.proj(z)
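# Pre-LayerNorm transformer block: residual attention sub-layer followed by a
# residual 4x-wide ReLU feed-forward sub-layer; threads the KV cache through
# when use_cache is set.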
class Block(nn.Module):
def __init__(self, d: int, h: int, r: int):
super().__init__()
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.mha = TuneableAttentionMHA(d, h, r)
self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
def forward(self, x, mask, kv_cache=None, use_cache=False):
if use_cache:
attn_out, new_kv = self.mha(self.ln1(x), mask, kv_cache, use_cache=True)
x = x + attn_out
return x + self.ff(self.ln2(x)), new_kv
else:
x = x + self.mha(self.ln1(x), mask)
return x + self.ff(self.ln2(x))
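# Token embedding + stack of Blocks + final LayerNorm. No explicit positional
# encoding is added in this module. With use_cache, a per-layer list of (k, v)
# caches is passed in and the updated caches are returned for the next step.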
class Encoder(nn.Module):
def __init__(self, vocab, d, layers, heads, rank):
super().__init__()
self.emb = nn.Embedding(vocab, d)
self.blocks = nn.ModuleList([Block(d, heads, rank) for _ in range(layers)])
self.ln = nn.LayerNorm(d)
def forward(self, ids, mask, kv_caches=None, use_cache=False):
x = self.emb(ids)
if use_cache:
new_kvs = []
for i, blk in enumerate(self.blocks):
kv = kv_caches[i] if kv_caches else None
x, new_kv = blk(x, mask, kv, use_cache=True)
new_kvs.append(new_kv)
return self.ln(x), new_kvs
for blk in self.blocks:
x = blk(x, mask)
return self.ln(x)
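# Autoregressive LM head: a single linear projection from hidden size to vocab logits.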
class ARHead(nn.Module):
def __init__(self, d, vocab):
super().__init__()
self.proj = nn.Linear(d, vocab)
def forward(self, h):
return self.proj(h)
MODEL_REPO = "OpenTransformer/AGILLM-3-large"
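# List every file in the model repo and pick the .pt checkpoint with the highest
# "step" number in its filename (root-level files and subfolders alike).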
def get_latest_checkpoint():
api = HfApi()
files = api.list_repo_files(MODEL_REPO, revision="main")
ckpts = [f for f in files if f.endswith(".pt") and "step" in f]
if not ckpts:
raise ValueError("No checkpoints found in repo")
ckpts.sort(key=lambda x: int(x.split("step")[1].split(".")[0]), reverse=True)
return ckpts[0]
def get_current_step_from_name(ckpt_name):
return int(ckpt_name.split("step")[1].split(".")[0])
# Global model state
model_state = {
"core": None,
"ar_head": None,
"step": 0,
"ckpt_name": None,
"compiled": False,
"tokenizer": None,
"vocab": 0,
}
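# Download a checkpoint, rebuild the model from its stored config, load the
# weights, attempt torch.compile() on both modules, run a small warmup forward
# pass, and stash everything in model_state.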
def load_model(ckpt_name=None, force_download=False):
global model_state
if ckpt_name is None:
ckpt_name = get_latest_checkpoint()
print(f"Loading checkpoint: {ckpt_name}")
ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name, force_download=force_download)
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
cfg = ckpt["cfg"]
D, LAYERS, HEADS, RANK = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
step = ckpt.get("step", 0)
core = Encoder(model_state["vocab"], D, LAYERS, HEADS, RANK)
ar_head = ARHead(D, model_state["vocab"])
core.load_state_dict(ckpt["core"])
ar_head.load_state_dict(ckpt["ar"])
core.eval()
ar_head.eval()
# Try to compile
compiled = False
try:
core = torch.compile(core, mode="reduce-overhead")
ar_head = torch.compile(ar_head, mode="reduce-overhead")
compiled = True
print("Compilation successful!")
except Exception as e:
print(f"Compilation failed: {e}")
# Warmup
try:
with torch.no_grad():
dummy = torch.randint(0, 1000, (1, 10))
mask = torch.zeros((1, 1, 10, 10))
_ = ar_head(core(dummy, mask))
    except Exception as e:
        # Warmup is best-effort; a failure here is non-fatal.
        print(f"Warmup failed: {e}")
model_state["core"] = core
model_state["ar_head"] = ar_head
model_state["step"] = step
model_state["ckpt_name"] = ckpt_name
model_state["compiled"] = compiled
params = sum(p.numel() for p in core.parameters()) + sum(p.numel() for p in ar_head.parameters())
print(f"Loaded: {params:,} params @ step {step:,}")
return step, ckpt_name
def check_for_updates():
"""Check if a newer checkpoint exists and load it if so."""
latest = get_latest_checkpoint()
latest_step = get_current_step_from_name(latest)
current_step = model_state["step"]
if latest_step > current_step:
# Force fresh download, bypass cache
new_step, new_name = load_model(latest, force_download=True)
return f"✅ Updated! Step {current_step:,}{new_step:,}"
else:
return f"Already on latest (step {current_step:,})"
# Initial load
print("Loading tokenizer...")
model_state["tokenizer"] = AutoTokenizer.from_pretrained(MODEL_REPO, subfolder="tokenizer", trust_remote_code=True)
model_state["vocab"] = len(model_state["tokenizer"])
print("Finding latest checkpoint...")
load_model()
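# KV-cached autoregressive decoding with temperature scaling and top-p (nucleus)
# sampling: the prompt is processed once under a causal mask, then each new token
# attends to the cached keys/values so only a single position is computed per step.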
@torch.no_grad()
def generate(prompt, max_tokens=50, temperature=0.8, top_p=0.9):
tokenizer = model_state["tokenizer"]
core = model_state["core"]
ar_head = model_state["ar_head"]
ids = tokenizer.encode(prompt, return_tensors="pt")
start_time = time.time()
tokens_generated = 0
prompt_len = ids.size(1)
mask = torch.triu(torch.full((1, 1, prompt_len, prompt_len), float("-inf")), 1)
h, kv_caches = core(ids, mask, use_cache=True)
for _ in range(max_tokens):
if ids.size(1) >= 2048:
break
logits = ar_head(h)[:, -1, :] / max(temperature, 0.01)
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
probs = F.softmax(sorted_logits, dim=-1)
cumsum = torch.cumsum(probs, dim=-1)
sorted_logits[cumsum - probs > top_p] = float('-inf')
probs = F.softmax(sorted_logits, dim=-1)
next_tok = sorted_idx.gather(-1, torch.multinomial(probs, 1))
ids = torch.cat([ids, next_tok], dim=1)
tokens_generated += 1
if next_tok.item() == tokenizer.eos_token_id:
break
mask = torch.zeros((1, 1, 1, ids.size(1)))
h, kv_caches = core(next_tok, mask, kv_caches, use_cache=True)
elapsed = time.time() - start_time
tok_s = tokens_generated / elapsed if elapsed > 0 else 0
text = tokenizer.decode(ids[0], skip_special_tokens=True)
text = " ".join(text.split())
return text, tok_s, tokens_generated, elapsed
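# Gradio chat callback: flattens the (user, assistant) history into a plain
# "User: ... / Assistant: ..." prompt, generates a completion, and appends a
# small stats footer (tok/s, token count, latency, training step).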
def chat(message, history, max_tokens=50, temperature=0.8):
if max_tokens is None: max_tokens = 50
if temperature is None: temperature = 0.8
prompt = "".join(f"User: {h[0]}\nAssistant: {h[1]}\n" for h in history)
prompt += f"User: {message}\nAssistant:"
response, tok_s, n_tok, elapsed = generate(prompt, int(max_tokens), float(temperature))
if "Assistant:" in response:
response = response.split("Assistant:")[-1].strip()
compiled_tag = " ⚡" if model_state["compiled"] else ""
step = model_state["step"]
return f"{response}\n\n---\n**{tok_s:.1f} tok/s**{compiled_tag} | {n_tok} tokens | {elapsed:.1f}s | step {step:,}"
def get_status():
step = model_state["step"]
compiled = "⚡ compiled" if model_state["compiled"] else "eager"
return f"Step {step:,} | {compiled}"
with gr.Blocks(title="AGILLM-3 Chat") as demo:
gr.Markdown(f"# AGILLM-3 Large (698M)")
gr.Markdown("KV-cached + torch.compile() | 0.5% pretrained | expect beautiful gibberish")
with gr.Row():
status = gr.Textbox(value=get_status(), label="Model Status", interactive=False)
refresh_btn = gr.Button("🔄 Check for Updates", variant="secondary")
refresh_btn.click(fn=check_for_updates, outputs=status)
chatbot = gr.ChatInterface(
chat,
additional_inputs=[
gr.Slider(10, 200, value=50, step=10, label="Max Tokens"),
gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
],
examples=[["Hello!"], ["What is 2+2?"]],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()
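# Minimal local usage sketch: `python app.py` starts the Gradio server; generate()
# can also be called directly, e.g. generate("Hello", max_tokens=20, temperature=0.7).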