import torch
import torch.nn as nn
from torch.nn import functional as F
import pickle
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
# --- Hyperparameters (Must match your model.py) ---
# These must agree with the checkpoint's training configuration, or
# load_state_dict below will fail on shape mismatches.
n_embd = 512      # embedding / residual-stream width
n_head = 8        # attention heads per transformer block
n_layer = 12      # number of transformer blocks
dropout = 0.2     # dropout probability (inactive in eval mode)
block_size = 512  # maximum context length in tokens
# Prefer the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Model Architecture (Copied from your model.py) ---
class Head(nn.Module):
    """One head of masked (causal) self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections for keys, queries and values.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask buffer: position t may only attend to <= t.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, T, C) — only T is needed to crop the causal mask.
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        # Scaled dot-product attention scores: (B, T, T).
        scores = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Forbid attention to future positions, then normalise per row.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return attn @ v
class MultiHeadAttention(nn.Module):
    """Several attention heads in parallel, concatenated then re-projected."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Each head attends independently; stack their outputs feature-wise.
        head_outs = [head(x) for head in self.heads]
        combined = torch.cat(head_outs, dim=-1)
        return self.dropout(self.proj(combined))
class FeedFoward(nn.Module):
    """Position-wise MLP: expand 4x, ReLU, project back, dropout.

    NOTE: the misspelled class name is kept on purpose — renaming it is
    safe for state_dict keys but would break any code importing the name.
    """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each residual."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # LayerNorm is applied *before* each sub-layer (pre-norm GPT style).
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
class GPTLanguageModel(nn.Module):
    """Decoder-only, character-level GPT.

    Args:
        vocab_size: number of distinct tokens in the vocabulary.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        Args:
            idx: (B, T) token ids, T <= block_size.
            targets: optional (B, T) next-token ids; when given, the
                cross-entropy loss is returned (it was previously accepted
                but silently ignored, always returning None).
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # FIX: embed positions on idx.device rather than the module-level
        # `device` global, so the model works wherever its inputs live.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = self.ln_f(self.blocks(tok_emb + pos_emb))
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; saves memory for callers that forget no_grad
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens appended to idx."""
        for _ in range(max_new_tokens):
            # Crop the conditioning context to the positional-embedding limit.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Sample from the distribution over the next token only.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# --- Server Logic ---
app = FastAPI()
# Allow browser clients from any origin to call the API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Load vocabulary metadata produced at training time.
# NOTE(review): pickle.load executes arbitrary code on load — only ever
# ship a meta.pkl you produced yourself.
with open('meta.pkl', 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Map a string to token ids, silently dropping out-of-vocab chars."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Map a sequence of token ids back to a string."""
    return ''.join([itos[i] for i in l])


model = GPTLanguageModel(meta['vocab_size'])
# The checkpoint may be a bare state_dict or a whole pickled model object.
try:
    checkpoint = torch.load("finetuned_model.pt", map_location=device)
    if isinstance(checkpoint, dict):
        model.load_state_dict(checkpoint)
    else:
        model = checkpoint
    model.to(device)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    # FIX: the original printed the error and kept going, so the server
    # would start up and serve a randomly-initialized model. Fail fast.
    print(f"Error: {e}")
    raise
class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    # Raw user message; the endpoint wraps it in a dialogue template.
    prompt: str
@app.post("/chat")
async def chat(request: ChatRequest):
    """Generate one chat reply for the given prompt.

    Returns:
        {"reply": str} — model text up to (not including) the next "User:" turn.
    """
    # Wrap the prompt to force a dialogue structure.
    context = f"User: {request.prompt}\nChen Bot:"
    input_ids = torch.tensor(encode(context), dtype=torch.long, device=device).unsqueeze(0)
    # Guard: a prompt made entirely of out-of-vocab chars encodes to nothing.
    if input_ids.shape[1] == 0:
        return {"reply": ""}
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=200)[0].tolist()
    # BUG FIX: the original sliced the decoded string with len(context), but
    # encode() drops characters missing from the vocab, so the decoded text
    # can be shorter than the raw context and the slice was misaligned.
    # Slicing by *token count* is exact: everything after the prompt tokens
    # is newly generated.
    raw_reply = decode(output_ids[input_ids.shape[1]:])
    # Stop generation if the model tries to start a new "User:" turn.
    clean_reply = raw_reply.split("User:")[0].strip()
    return {"reply": clean_reply}
if __name__ == "__main__":
    # Dev convenience: run the API with uvicorn when executed directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)