# chen-bot-api-v1 / main.py
# chenhsu92: Deploying backend with Dockerfile and Model (a6e0d48)
import torch
import torch.nn as nn
from torch.nn import functional as F
import pickle
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
# --- Hyperparameters (Must match your model.py) ---
# These sizes determine the parameter shapes of every layer below, so they have
# to agree with the values the checkpoint was trained with.
n_embd = 512      # embedding width used by the embeddings, attention and MLP
n_head = 8        # attention heads per transformer block
n_layer = 12      # number of stacked transformer blocks
dropout = 0.2     # dropout probability passed to every nn.Dropout
block_size = 512  # longest context: position table rows and causal-mask side

# Use the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# --- Model Architecture (Copied from your model.py) ---
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width ``head_size`` and
    lets every position attend only to itself and to earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections from the embedding width to this head's width.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular matrix of ones. Registered as a buffer so it follows
        # the module across devices without being a trainable parameter.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the attention output for a 3-D input ``x``.

        The middle dimension of ``x`` is the sequence length; the result keeps
        the leading two dimensions and has ``head_size`` as its last one.
        """
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # Scaled dot-product scores between every query and every key.
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Entries above the diagonal are future positions: give them -inf so
        # the softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        return self.dropout(weights) @ values
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, then mixed by a linear layer."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        attention_heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(attention_heads)
        # Maps the concatenated head outputs back to the embedding width.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x``, concatenate on the last axis and project."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x the width, ReLU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        return self.net(x)
class Block(nn.Module):
    """Transformer block: pre-norm self-attention, then a pre-norm MLP.

    Both sub-layers are wrapped in residual connections.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Each head gets an equal share of the embedding width.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))
class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model (inference-only variant).

    Args:
        vocab_size: Number of distinct tokens; sets the size of the token
            embedding table and of the output layer.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits for a batch of token ids.

        Args:
            idx: Integer tensor of shape ``(B, T)`` with ``T <= block_size``.
            targets: Accepted for signature compatibility; not used.

        Returns:
            ``(logits, None)`` where ``logits`` has shape
            ``(B, T, vocab_size)``. The second element is always ``None``
            because no loss is computed here.

        Raises:
            IndexError: If ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > block_size:
            # Reject an over-long sequence up front, before the position
            # embedding lookup is attempted with out-of-range indices.
            raise IndexError(
                f"sequence length {T} exceeds block_size {block_size}"
            )
        tok_emb = self.token_embedding_table(idx)
        # Build the position indices on the same device as the input rather
        # than on the module-global `device`, so the model keeps working
        # wherever it and its inputs actually live.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = self.ln_f(self.blocks(tok_emb + pos_emb))
        logits = self.lm_head(x)
        return logits, None

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: Integer tensor of shape ``(B, T)`` holding the prompt tokens.
            max_new_tokens: Number of sampling steps to run.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``: the prompt followed
            by the sampled continuation.
        """
        for _ in range(max_new_tokens):
            # Only the most recent `block_size` tokens fit in the context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Sample the next token from the distribution at the last position.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# --- Server Logic ---
app = FastAPI()
# NOTE(review): CORS accepts every origin, method and header. Confirm this is
# intended for the deployment rather than a development leftover.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load Metadata and Model
# SECURITY: pickle.load runs arbitrary code stored in the file. meta.pkl must
# only ever come from a trusted source (e.g. baked into the image at build).
with open('meta.pkl', 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Map a string to token ids, silently dropping characters not in ``stoi``."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Map an iterable of token ids back to the concatenated string."""
    return ''.join([itos[i] for i in l])


model = GPTLanguageModel(meta['vocab_size'])

# Load just the weights (state_dict) or the whole model.
# SECURITY: torch.load unpickles the file; only load trusted checkpoints.
# NOTE(review): whether a whole pickled model can be loaded depends on the
# installed torch version's `weights_only` default - confirm for the runtime.
try:
    checkpoint = torch.load("finetuned_model.pt", map_location=device)
    if isinstance(checkpoint, dict):
        model.load_state_dict(checkpoint)
    elif isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        # Anything else cannot be served; keep the freshly built model bound
        # to `model` and report the problem below.
        raise TypeError(
            f"unsupported checkpoint type: {type(checkpoint).__name__}"
        )
    print("Model loaded successfully!")
except Exception as e:
    # Best effort: the server still starts, but make the degraded state loud
    # instead of quietly serving an untrained model.
    print(f"Error: {e}")
    print("WARNING: checkpoint not loaded; replies will come from untrained weights.")

# Always finish inference setup, even after a failed load. Otherwise the model
# would stay in training mode (dropout active) and might not be on the device
# that /chat creates its input tensors on.
model.to(device)
model.eval()
# Request body accepted by POST /chat.
class ChatRequest(BaseModel):
    # Raw user message; the endpoint wraps it in the dialogue template.
    prompt: str
@app.post("/chat")
async def chat(request: ChatRequest):
    """Generate the bot's reply to a single user prompt.

    The prompt is wrapped in a ``User: ... / Chen Bot:`` template, up to 200
    new tokens are sampled from the model, and the continuation is cut at the
    first point where the model starts a new ``User:`` turn.

    Returns:
        ``{"reply": <str>}``; the reply is empty when nothing of the prompt
        template could be encoded.
    """
    # NOTE(review): generation is synchronous, so this coroutine occupies the
    # event loop for the whole sampling run. Consider a plain `def` endpoint
    # or an executor if concurrent requests matter.
    # Wrap the prompt to force a dialogue structure
    context = f"User: {request.prompt}\nChen Bot:"
    prompt_ids = encode(context)
    if not prompt_ids:
        # The model cannot run on an empty sequence.
        return {"reply": ""}

    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=200)[0].tolist()

    # Extract the response at token level. `encode` drops characters missing
    # from the vocabulary, so the encoded prompt can be shorter than `context`;
    # slicing the decoded text by len(context) would then cut off the start of
    # the reply.
    reply_ids = output_ids[len(prompt_ids):]
    raw_reply = decode(reply_ids)

    # Stop generation if the model tries to start a new "User:" turn
    clean_reply = raw_reply.split("User:")[0].strip()
    return {"reply": clean_reply}
if __name__ == "__main__":
    # Imported here so that merely importing this module does not need uvicorn.
    import uvicorn

    # Listen on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")