File size: 4,732 Bytes
8f28f62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
#!/usr/bin/env python3
"""AGILLM-3 GPU Inference API"""
import os, sys, json, torch
import torch.nn as nn
import torch.nn.functional as F
from flask import Flask, request, jsonify
from flask_cors import CORS
import tiktoken
# Flask application object; the @app.route decorators in this file register
# their handlers on it.
app = Flask(__name__)
# Enable CORS handling on the app using flask-cors with no explicit options.
# NOTE(review): presumably the defaults allow requests from any origin —
# confirm this is intended before exposing the service publicly.
CORS(app)
class ModelConfig:
    """Hyperparameters used to build the AGILLM3 network.

    The values are plain class attributes, so they can be read either from
    the class itself or from an instance created with ``ModelConfig()``.
    """

    # Number of rows in the token embedding and width of the output logits.
    vocab_size: int = 50257
    # Width of the embeddings and of every transformer layer.
    d_model: int = 1024
    # Number of attention heads; each head gets d_model // n_heads channels.
    n_heads: int = 16
    # Number of stacked transformer blocks.
    n_layers: int = 24
    # Hidden width of the feed-forward (MLP) sub-layer.
    d_ff: int = 4096
    # Number of rows in the position embedding, i.e. the longest input
    # sequence the model can embed.
    max_seq_len: int = 2048
    # Not read by any of the modules shown in this file.
    dropout: float = 0.0
class AGILLM3(nn.Module):
    """Transformer language model mapping token ids to next-token logits.

    Token embeddings and learned position embeddings are summed, passed
    through ``config.n_layers`` ``TransformerBlock`` modules, normalised by a
    final ``LayerNorm`` and projected to vocabulary logits by a bias-free
    linear head.
    """

    def __init__(self, config):
        """Build all sub-modules from *config*.

        Args:
            config: object exposing ``vocab_size``, ``d_model``,
                ``max_seq_len`` and ``n_layers`` (plus whatever
                ``TransformerBlock`` reads from it).
        """
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Return logits for every position of *idx*.

        Args:
            idx: 2-D integer tensor of token ids, unpacked as (batch, time).

        Returns:
            Tensor with one ``vocab_size``-wide logit vector per input
            position.

        Raises:
            ValueError: if the time dimension exceeds ``config.max_seq_len``.
        """
        _batch, seq_len = idx.shape
        # Fail early with a readable message: an out-of-range lookup in
        # pos_emb would otherwise surface as an opaque indexing error (a
        # device-side assert on CUDA).
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len "
                f"{self.config.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=idx.device)
        # pos_emb has no batch dimension; the addition broadcasts it.
        x = self.tok_emb(idx) + self.pos_emb(positions)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(self.ln_f(x))
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each residual."""

    def __init__(self, config):
        """Create the two LayerNorms, the attention module and the MLP.

        Args:
            config: object exposing ``d_model`` (plus whatever
                ``CausalSelfAttention`` and ``MLP`` read from it).
        """
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply both sub-layers to *x* and return the result.

        Each sub-layer sees a LayerNorm-ed copy of its input and its output
        is added back onto the un-normalised input (residual connection).
        """
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where a position cannot see later ones."""

    def __init__(self, config):
        """Create the fused q/k/v projection and the output projection.

        Args:
            config: object exposing ``d_model`` and ``n_heads``.

        Raises:
            ValueError: if ``d_model`` is not a multiple of ``n_heads``.
        """
        super().__init__()
        # Without this check the mismatch only shows up later as a reshape
        # error inside forward().
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # A single linear layer produces q, k and v side by side.
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        """Attend over *x* and return a tensor of the same shape.

        Args:
            x: 3-D tensor unpacked as (batch, time, channels).
        """
        B, T, C = x.shape
        # Split the fused projection into q, k, v and move the head axis in
        # front of the time axis: (B, n_heads, T, head_dim).
        q, k, v = (
            part.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            for part in self.qkv(x).chunk(3, dim=-1)
        )
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # True strictly above the diagonal, i.e. for every (query, key) pair
        # in which the key lies after the query.
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att = att.masked_fill(mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Merge the heads back into the channel dimension.
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between."""

    def __init__(self, config):
        """Create the expanding and the contracting linear layers.

        Args:
            config: object exposing ``d_model`` and ``d_ff``.
        """
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``; the last dimension stays d_model."""
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)
# Global model instance; stays None until load_model() assigns it.
model = None
# GPT-2 encoding from tiktoken; generate() uses it to encode prompts and
# decode sampled tokens.
enc = tiktoken.get_encoding("gpt2")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(ckpt_path):
    """Build an AGILLM3, load weights from *ckpt_path* and publish it.

    The network is stored in the module-level ``model`` only after the
    checkpoint has been loaded, so a failed load never leaves a randomly
    initialised network visible to the request handlers.

    Args:
        ckpt_path: path of a checkpoint readable by ``torch.load``. It is
            either a state dict or a dict holding one under
            ``'model_state_dict'``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. when the
        file is missing or a tensor shape does not match.
    """
    global model
    print(f"Loading model on {device}...")
    net = AGILLM3(ModelConfig()).to(device)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    state = ckpt.get('model_state_dict', ckpt)
    # strict=False tolerates key mismatches; report them so that a partially
    # loaded (partly random) model does not go unnoticed.
    result = net.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(
            f"Warning: {len(result.missing_keys)} model parameter(s) not found "
            f"in checkpoint, e.g. {result.missing_keys[:5]}"
        )
    if result.unexpected_keys:
        print(
            f"Warning: {len(result.unexpected_keys)} checkpoint key(s) not used "
            f"by the model, e.g. {result.unexpected_keys[:5]}"
        )
    net.eval()
    model = net
    print("Model ready!")
@torch.no_grad()
def generate(prompt, max_tokens=100, temperature=0.8):
    """Continue *prompt* with up to *max_tokens* tokens from the model.

    Args:
        prompt: text to continue.
        max_tokens: upper bound on the number of generated tokens.
        temperature: softmax temperature for sampling. Values <= 0 select
            the most likely token at every step (greedy decoding) instead
            of dividing by zero.

    Returns:
        The decoded prompt followed by the generated text. Generation stops
        early at the end-of-text token, which is not included in the
        returned string.

    Raises:
        RuntimeError: if no model has been loaded yet.
        ValueError: if *prompt* encodes to zero tokens.
    """
    if model is None:
        raise RuntimeError("Model is not loaded; call load_model() first")
    # disallowed_special=() makes tiktoken encode special-token look-alikes
    # (e.g. a literal "<|endoftext|>" typed by a user) as ordinary text
    # instead of raising.
    prompt_ids = enc.encode(prompt, disallowed_special=())
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    # Only the most recent tokens that fit the position embedding are fed in.
    context_len = model.config.max_seq_len
    for _ in range(max_tokens):
        logits = model(tokens[:, -context_len:])[:, -1, :]
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1)
        else:
            next_tok = torch.argmax(logits, dim=-1, keepdim=True)
        # Stop before appending so the end-of-text marker never leaks into
        # the decoded output.
        if next_tok.item() == enc.eot_token:
            break
        tokens = torch.cat([tokens, next_tok], dim=1)
    return enc.decode(tokens[0].tolist())
@app.route('/api/chat', methods=['POST'])
def chat():
    """Answer a chat message.

    Expects a JSON object with a non-empty ``message`` field and returns
    ``{'response': <text>}``.

    Status codes:
        400: the body is not a JSON object or ``message`` is missing/empty.
        500: generation failed; the error text is returned under ``error``.
    """
    # silent=True yields None for a missing, malformed or non-JSON body, so
    # client mistakes are answered with 400 instead of falling through to
    # the generic 500 handler.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    message = data.get('message', '')
    if not message:
        return jsonify({'error': 'No message'}), 400

    prompt = f"User: {message}\nAssistant:"
    try:
        response = generate(prompt, max_tokens=150, temperature=0.7)
        if response.startswith(prompt):
            # generate() echoes the prompt; dropping exactly that prefix
            # keeps the first assistant turn even when the model goes on to
            # write further "Assistant:" turns of its own.
            response = response[len(prompt):]
        elif "Assistant:" in response:
            response = response.split("Assistant:")[-1]
        # Cut the reply off where the model starts inventing the next
        # user turn.
        response = response.split("User:")[0].strip()
        return jsonify({'response': response})
    except Exception as e:
        # Top-level boundary: record the traceback server-side and report
        # the failure to the client.
        app.logger.exception("Chat request failed")
        return jsonify({'error': str(e)}), 500
@app.route('/api/health', methods=['GET'])
def health():
    """Report service status, the compute device and whether a model is set."""
    return jsonify({
        'status': 'ok',
        'device': device,
        'model_loaded': model is not None,
    })
def latest_checkpoint(pattern, default):
    """Return the highest-named file matching *pattern*, or *default*.

    File names are compared in natural order, i.e. runs of digits are
    compared as numbers, so ``step_10000.pt`` ranks above ``step_9000.pt``
    (a plain string sort would pick the latter).

    Args:
        pattern: glob pattern of candidate checkpoint files.
        default: value returned when nothing matches.
    """
    import glob
    import re

    def natural_key(path):
        # re.split with a capture group alternates text and digit runs, so
        # the odd positions are always the numbers.
        pieces = re.split(r'(\d+)', os.path.basename(path))
        ordered = [int(p) if i % 2 else p for i, p in enumerate(pieces)]
        # The raw path breaks ties between names such as "1.pt" and "01.pt".
        return ordered, path

    matches = glob.glob(pattern)
    return max(matches, key=natural_key) if matches else default


if __name__ == '__main__':
    ckpt = latest_checkpoint('/workspace/ckpts_expansion/*.pt',
                             '/workspace/checkpoint.pt')
    print(f"Using checkpoint: {ckpt}")
    load_model(ckpt)
    # NOTE(review): this starts Flask's built-in server listening on all
    # interfaces — confirm that is acceptable for the deployment.
    app.run(host='0.0.0.0', port=5000, threaded=True)
|