#!/usr/bin/env python3
"""AGILLM-3 GPU Inference API"""

import json
import os
import sys

import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS

app = Flask(__name__)
CORS(app)


class ModelConfig:
    """Hyper-parameters used to build the AGILLM3 network."""

    vocab_size = 50257   # rows in the token embedding / width of the LM head
    d_model = 1024       # hidden size of every layer
    n_heads = 16         # attention heads per block
    n_layers = 24        # number of transformer blocks
    d_ff = 4096          # hidden size of the MLP inside each block
    max_seq_len = 2048   # rows in the learned positional embedding
    dropout = 0.0        # not read by any module defined in this file


class AGILLM3(nn.Module):
    """Decoder-only transformer with learned positional embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``config.max_seq_len``; the positional
                embedding table has no rows for those positions.
        """
        _, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len "
                f"{self.config.max_seq_len}"
            )
        positions = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.ln_f(x))


class TransformerBlock(nn.Module):
    """Pre-LayerNorm residual block: self-attention followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position sees only the past."""

    def __init__(self, config):
        super().__init__()
        if config.d_model % config.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return the same shape."""
        B, T, C = x.shape
        # One fused projection, split into per-head q/k/v: (B, heads, T, head_dim).
        q, k, v = [
            t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            for t in self.qkv(x).chunk(3, dim=-1)
        ]
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # True strictly above the diagonal, i.e. on every "future" position.
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att = att.masked_fill(mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between."""

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``."""
        return self.fc2(F.gelu(self.fc1(x)))


# Module-level inference state; ``model`` stays None until load_model() succeeds.
model = None
enc = tiktoken.get_encoding("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(ckpt_path):
    """Build AGILLM3, load weights from ``ckpt_path`` and publish it globally.

    The checkpoint may be a raw state dict or a dict that holds one under
    the ``'model_state_dict'`` key.
    """
    global model
    print(f"Loading model on {device}...")
    net = AGILLM3(ModelConfig()).to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    state = ckpt.get('model_state_dict', ckpt)
    # strict=False tolerates key mismatches, so report them instead of
    # silently serving a partially initialised network.
    result = net.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} missing keys in "
              f"checkpoint, e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected keys in "
              f"checkpoint, e.g. {result.unexpected_keys[:5]}")
    net.eval()
    # Publish only after loading succeeded so /api/health never reports a
    # half-initialised model as loaded.
    model = net
    print("Model ready!")


@torch.no_grad()
def generate(prompt, max_tokens=100, temperature=0.8):
    """Sample a continuation of ``prompt`` and return prompt + continuation.

    Args:
        prompt: Text to continue; must encode to at least one token.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature. Values <= 0 select greedy
            (argmax) decoding instead of dividing by zero.

    Returns:
        The decoded text of the prompt tokens followed by the generated
        tokens. Generation stops early at the end-of-text token, which is
        not included in the output.

    Raises:
        RuntimeError: if no model has been loaded yet.
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    if model is None:
        raise RuntimeError("Model not loaded")
    # disallowed_special=() makes user text that happens to contain a
    # special-token string (e.g. "<|endoftext|>") encode as ordinary text
    # instead of raising.
    token_ids = enc.encode(prompt, disallowed_special=())
    if not token_ids:
        raise ValueError("Prompt is empty")
    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)
    context_len = model.config.max_seq_len
    for _ in range(max_tokens):
        # Only the most recent context_len tokens fit the positional table.
        logits = model(tokens[:, -context_len:])[:, -1, :]
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1)
        else:
            next_tok = logits.argmax(dim=-1, keepdim=True)
        if next_tok.item() == enc.eot_token:
            break
        tokens = torch.cat([tokens, next_tok], dim=1)
    return enc.decode(tokens[0].tolist())


@app.route('/api/chat', methods=['POST'])
def chat():
    """POST /api/chat: answer ``{"message": ...}`` with ``{"response": ...}``.

    Returns 400 for a missing/invalid body or empty message, 503 while the
    model is not loaded, and 500 if generation fails.
    """
    # silent=True returns None for a non-JSON body instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid JSON body'}), 400
    message = data.get('message', '')
    if not message:
        return jsonify({'error': 'No message'}), 400
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 503
    try:
        prompt = f"User: {message}\nAssistant:"
        response = generate(prompt, max_tokens=150, temperature=0.7)
        # generate() returns prompt + completion: keep the text after the
        # last "Assistant:" and drop any follow-up "User:" turn.
        if "Assistant:" in response:
            response = response.split("Assistant:")[-1].strip()
        if "User:" in response:
            response = response.split("User:")[0].strip()
        return jsonify({'response': response})
    except Exception as e:
        # Top-level request boundary: log the traceback, report the failure.
        app.logger.exception("Chat request failed")
        return jsonify({'error': str(e)}), 500


@app.route('/api/health', methods=['GET'])
def health():
    """GET /api/health: report the device and whether a model is loaded."""
    return jsonify({
        'status': 'ok',
        'device': device,
        'model_loaded': model is not None,
    })


def main():
    """Pick a checkpoint, load it and start the Flask server on port 5000."""
    import glob

    # NOTE(review): the newest checkpoint is chosen by lexicographic file
    # name order; confirm the naming scheme sorts chronologically.
    ckpts = sorted(glob.glob('/workspace/ckpts_expansion/*.pt'))
    ckpt = ckpts[-1] if ckpts else '/workspace/checkpoint.pt'
    print(f"Using checkpoint: {ckpt}")
    load_model(ckpt)
    app.run(host='0.0.0.0', port=5000, threaded=True)


if __name__ == '__main__':
    main()