OpenTransformer committed on
Commit
8f28f62
·
verified ·
1 Parent(s): 4cb6700

Backup script inference_api.py

Browse files
Files changed (1) hide show
  1. scripts/inference_api.py +137 -0
scripts/inference_api.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """AGILLM-3 GPU Inference API"""
3
+ import os, sys, json, torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from flask import Flask, request, jsonify
7
+ from flask_cors import CORS
8
+ import tiktoken
9
+
10
+ app = Flask(__name__)
11
+ CORS(app)
12
+
13
class ModelConfig:
    """Hyperparameters for the AGILLM-3 transformer (GPT-2-style tokenizer)."""

    def __init__(self):
        self.vocab_size = 50257   # GPT-2 BPE vocabulary size (tiktoken "gpt2")
        self.d_model = 1024       # residual-stream width
        self.n_heads = 16         # attention heads (head_dim = 1024 / 16 = 64)
        self.n_layers = 24        # number of transformer blocks
        self.d_ff = 4096          # MLP hidden width (4 * d_model)
        self.max_seq_len = 2048   # positional-embedding table size / context window
        self.dropout = 0.0        # declared but unused by the inference modules below
21
+
22
class AGILLM3(nn.Module):
    """Decoder-only transformer LM: token + learned positional embeddings,
    a stack of pre-norm blocks, a final LayerNorm, and an untied LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layers)
        )
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.tok_emb(idx) + self.pos_emb(positions)
        for block in self.layers:
            hidden = block(hidden)
        return self.lm_head(self.ln_f(hidden))
41
+
42
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(LN(x)), then x + mlp(LN(x))."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each behind a residual skip."""
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
54
+
55
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Submodule names (`qkv`, `proj`) are preserved so existing checkpoint
    state_dicts still match.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)  # fused Q/K/V projection
        self.proj = nn.Linear(config.d_model, config.d_model)     # output projection

    def forward(self, x):
        """Attend causally over the sequence; input and output are (B, T, C)."""
        batch, seq_len, width = x.shape
        # Split the fused projection into per-head (B, heads, T, head_dim) tensors.
        q, k, v = (
            t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
            for t in self.qkv(x).split(width, dim=-1)
        )
        # Scaled dot-product scores; positions above the diagonal are future
        # tokens and get -inf so softmax zeroes them out.
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)
73
+
74
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Attribute names `fc1`/`fc2` are preserved for checkpoint compatibility.
    """

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)  # expand to d_ff
        self.fc2 = nn.Linear(config.d_ff, config.d_model)  # contract back to d_model

    def forward(self, x):
        """Apply the two-layer GELU MLP token-wise; the shape is preserved."""
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)
82
+
83
# Module-level state shared by the Flask handlers below.
model = None  # set by load_model(); None until a checkpoint is loaded
enc = tiktoken.get_encoding("gpt2")  # GPT-2 BPE tokenizer (matches vocab_size=50257)
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
86
+
87
def load_model(ckpt_path):
    """Build AGILLM3 on the global `device` and load weights from `ckpt_path`.

    Accepts either a raw state_dict or a training checkpoint wrapping one
    under 'model_state_dict'. Sets the module-level `model` and switches it
    to eval mode.

    Args:
        ckpt_path: path to a torch checkpoint (.pt) file.
    """
    global model
    print(f"Loading model on {device}...")
    model = AGILLM3(ModelConfig()).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted paths (weights_only=True would be safer if
    # the checkpoint format allows it).
    ckpt = torch.load(ckpt_path, map_location=device)
    state = ckpt.get('model_state_dict', ckpt)
    # strict=False keeps the original lenient behavior, but report what was
    # skipped so a wrong or partial checkpoint is no longer invisible.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"Warning: {len(missing)} missing keys (e.g. {missing[:3]})")
    if unexpected:
        print(f"Warning: {len(unexpected)} unexpected keys (e.g. {unexpected[:3]})")
    model.eval()
    print("Model ready!")
96
+
97
@torch.no_grad()
def generate(prompt, max_tokens=100, temperature=0.8):
    """Autoregressively sample a continuation of `prompt`.

    Args:
        prompt: text to condition on (encoded with the GPT-2 BPE tokenizer).
        max_tokens: maximum number of new tokens to sample.
        temperature: softmax temperature; values <= 0 now fall back to greedy
            argmax decoding instead of dividing by zero.

    Returns:
        The decoded prompt plus continuation; stops early on the GPT-2
        end-of-text token.
    """
    # Use the model's configured window instead of a hard-coded 2048 so the
    # crop stays correct if max_seq_len ever changes.
    context_len = model.config.max_seq_len
    tokens = torch.tensor([enc.encode(prompt)], device=device)
    for _ in range(max_tokens):
        # Crop to the context window so pos_emb indexing never overflows.
        logits = model(tokens[:, -context_len:])[:, -1, :]
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1)
        else:
            next_tok = logits.argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == enc.eot_token:
            break
    return enc.decode(tokens[0].tolist())
109
+
110
@app.route('/api/chat', methods=['POST'])
def chat():
    """Chat endpoint: wrap the user message in a User/Assistant prompt,
    sample a reply, and strip the scaffolding before returning it.

    Returns 400 for missing or malformed input, 500 on generation failure.
    """
    # silent=True: absent/malformed JSON yields None (handled as a 400)
    # instead of raising inside the handler and surfacing as a 500.
    data = request.get_json(silent=True) or {}
    message = data.get('message', '')
    if not message:
        return jsonify({'error': 'No message'}), 400
    try:
        prompt = f"User: {message}\nAssistant:"
        response = generate(prompt, max_tokens=150, temperature=0.7)
        # Keep only the model's turn: text after the last "Assistant:" and
        # before any follow-on "User:" turn the model may have produced.
        if "Assistant:" in response:
            response = response.split("Assistant:")[-1].strip()
        if "User:" in response:
            response = response.split("User:")[0].strip()
        return jsonify({'response': response})
    except Exception as e:
        # Top-level boundary: report the failure rather than a blank 500.
        return jsonify({'error': str(e)}), 500
126
+
127
@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe: report compute device and whether a model is loaded."""
    payload = {
        'status': 'ok',
        'device': device,
        'model_loaded': model is not None,
    }
    return jsonify(payload)
130
+
131
if __name__ == '__main__':
    import glob
    # Pick the newest checkpoint. NOTE(review): lexicographic sort only
    # matches training order if step numbers are zero-padded — confirm the
    # checkpoint naming scheme.
    ckpts = sorted(glob.glob('/workspace/ckpts_expansion/*.pt'))
    ckpt = ckpts[-1] if ckpts else '/workspace/checkpoint.pt'
    print(f"Using checkpoint: {ckpt}")
    load_model(ckpt)
    # Host/port are env-overridable; defaults are unchanged. Debug stays off —
    # the Werkzeug debugger must never face the open network.
    port = int(os.environ.get('PORT', '5000'))
    app.run(host=os.environ.get('HOST', '0.0.0.0'), port=port, threaded=True)