|
|
|
|
|
"""AGILLM-3 GPU Inference API""" |
|
|
import os, sys, json, torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from flask import Flask, request, jsonify |
|
|
from flask_cors import CORS |
|
|
import tiktoken |
|
|
|
|
|
app = Flask(__name__)

# Allow cross-origin requests so a browser front-end served from another
# host/port can call this API.
CORS(app)
|
|
|
|
|
class ModelConfig:
    """Architecture hyperparameters for AGILLM-3 (GPT-2-style sizes)."""

    vocab_size: int = 50257   # GPT-2 BPE vocabulary size (matches tiktoken "gpt2" below)

    d_model: int = 1024       # embedding / residual-stream width

    n_heads: int = 16         # attention heads (head_dim = 1024 // 16 = 64)

    n_layers: int = 24        # number of transformer blocks

    d_ff: int = 4096          # MLP hidden width (4 * d_model)

    max_seq_len: int = 2048   # positional-embedding table size, i.e. max context length

    dropout: float = 0.0      # NOTE(review): declared but not used by any module visible in this file
|
|
|
|
|
class AGILLM3(nn.Module):
    """Decoder-only transformer language model.

    Token + learned positional embeddings, a stack of pre-norm transformer
    blocks, a final LayerNorm, and an untied linear LM head.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positions; table size fixes the max context length.
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Return next-token logits.

        Args:
            idx: (B, T) integer token ids, T <= config.max_seq_len.

        Returns:
            (B, T, vocab_size) logits tensor.

        Raises:
            ValueError: if T exceeds the positional-embedding table; previously
                this surfaced as an opaque embedding index error.
        """
        B, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.config.max_seq_len}"
            )
        tok_emb = self.tok_emb(idx)
        pos_emb = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return self.lm_head(x)
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each wrapped in a
    residual connection (LayerNorm applied before each sub-layer)."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = MLP(config)

    def forward(self, x):
        # Residual around attention, then residual around the feed-forward.
        h = x + self.attn(self.ln1(x))
        out = h + self.mlp(self.ln2(h))
        return out
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head scaled-dot-product self-attention with a causal mask.

    A fused Linear produces Q, K, V in one matmul; the upper-triangular mask
    prevents positions from attending to later tokens.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # Fused projection: one matmul emits Q, K and V concatenated.
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.proj = nn.Linear(config.d_model, config.d_model)
        # Lazily-built causal mask cache (plain attribute, not a buffer, so the
        # checkpoint state_dict layout is unchanged). The original rebuilt the
        # mask on every forward pass.
        self._mask = None

    def _causal_mask(self, T, device):
        """Return a (T, T) bool mask, True above the diagonal; cached across calls."""
        if self._mask is None or self._mask.size(0) < T or self._mask.device != device:
            self._mask = torch.triu(
                torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1
            )
        return self._mask[:T, :T]

    def forward(self, x):
        """x: (B, T, d_model) -> (B, T, d_model)."""
        B, T, C = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, n_heads, T, head_dim) for each of q, k, v.
        q, k, v = [t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) for t in qkv]
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Mask out future positions before softmax.
        att = att.masked_fill(self._causal_mask(T, x.device), float('-inf'))
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
|
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear."""

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)

    def forward(self, x):
        # Expand to the hidden width, apply the nonlinearity, project back.
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
|
# Populated by load_model(); stays None until a checkpoint is loaded.
model = None

# GPT-2 BPE tokenizer — vocabulary size matches ModelConfig.vocab_size (50257).
enc = tiktoken.get_encoding("gpt2")

# Prefer GPU when available; all tensors and the model live on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
def load_model(ckpt_path):
    """Instantiate AGILLM3 and load weights from *ckpt_path* into the global `model`.

    Accepts either a raw state dict or a training checkpoint wrapping it
    under 'model_state_dict'. Puts the model in eval mode.
    """
    global model
    print(f"Loading model on {device}...")
    model = AGILLM3(ModelConfig()).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only point this at
    # trusted checkpoints (weights_only=True may break older wrapped checkpoints).
    ckpt = torch.load(ckpt_path, map_location=device)
    state = ckpt.get('model_state_dict', ckpt)
    # strict=False previously hid key mismatches silently; surface them so a
    # wrong/partial checkpoint is visible in the logs instead of producing garbage.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"Warning: {len(missing)} missing keys (e.g. {missing[:3]})")
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected keys (e.g. {unexpected[:3]})")
    model.eval()
    print("Model ready!")
|
|
|
|
|
@torch.no_grad()
def generate(prompt, max_tokens=100, temperature=0.8):
    """Autoregressively sample a continuation of *prompt*.

    Args:
        prompt: text to condition on.
        max_tokens: maximum number of new tokens to sample.
        temperature: softmax temperature; values <= 0 fall back to greedy
            decoding (previously dividing by 0 produced NaNs and crashed
            torch.multinomial).

    Returns:
        The decoded prompt plus generated text (stops early at the EOT token).
    """
    tokens = enc.encode(prompt)
    tokens = torch.tensor([tokens], device=device)
    # Crop to the model's actual context window (was hard-coded 2048, which
    # silently desynchronizes if ModelConfig.max_seq_len changes).
    ctx_len = model.config.max_seq_len
    for _ in range(max_tokens):
        logits = model(tokens[:, -ctx_len:])[:, -1, :]
        if temperature <= 0:
            next_tok = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == enc.eot_token:
            break
    return enc.decode(tokens[0].tolist())
|
|
|
|
|
@app.route('/api/chat', methods=['POST'])
def chat():
    """POST {"message": str} -> {"response": str}.

    Wraps the user message in a chat template, samples a completion, and trims
    it back to just the assistant's turn.
    """
    try:
        # get_json(silent=True) returns None for absent/malformed JSON instead
        # of raising — previously request.json raised and the client got a 500
        # for what is really a 400-class error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON body'}), 400
        message = data.get('message', '')
        if not message:
            return jsonify({'error': 'No message'}), 400
        prompt = f"User: {message}\nAssistant:"
        response = generate(prompt, max_tokens=150, temperature=0.7)
        # Keep only the text after the last "Assistant:" marker, and cut off
        # anything where the model started hallucinating the next user turn.
        if "Assistant:" in response:
            response = response.split("Assistant:")[-1].strip()
        if "User:" in response:
            response = response.split("User:")[0].strip()
        return jsonify({'response': response})
    except Exception as e:
        # Top-level boundary: report the failure to the client as a 500.
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe: reports the compute device and whether weights are loaded."""
    payload = {
        'status': 'ok',
        'device': device,
        'model_loaded': model is not None,
    }
    return jsonify(payload)
|
|
|
|
|
if __name__ == '__main__':

    import glob

    # Pick the "latest" checkpoint by filename sort.
    # NOTE(review): lexicographic sort only matches newest-by-step if the step
    # numbers in the filenames are zero-padded — confirm the naming scheme.
    ckpts = sorted(glob.glob('/workspace/ckpts_expansion/*.pt'))

    # Fall back to a fixed default path when no expansion checkpoints exist.
    ckpt = ckpts[-1] if ckpts else '/workspace/checkpoint.pt'

    print(f"Using checkpoint: {ckpt}")

    load_model(ckpt)

    # threaded=True lets the Flask dev server handle concurrent requests;
    # not a production WSGI deployment.
    app.run(host='0.0.0.0', port=5000, threaded=True)
|
|
|