# deploy/main.py
# Production server: serves the API plus static Next.js files
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
import torch
import torch.nn as nn
from torch.nn import functional as F
from huggingface_hub import hf_hub_download
import os
import time
# ─────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────
HF_REPO_ID = "debojitbasak/minigpt-models"
CACHE_DIR = "./checkpoints"
os.makedirs(CACHE_DIR, exist_ok=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
# ─────────────────────────────────────────
# MODEL ARCHITECTURE
# ─────────────────────────────────────────
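# A small decoder-only, character-level transformer in the style of
# Karpathy's nanoGPT lecture code: learned token + positional embeddings,
# pre-LayerNorm blocks with causal self-attention, and a linear LM head.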
class Head(nn.Module):
    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # causal mask: position t may only attend to positions <= t
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        # NOTE: scales by the input width C (= n_embd) rather than the usual
        # head_size; kept as-is so inference matches the trained checkpoint.
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ self.value(x)
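
# Several heads run in parallel; their outputs are concatenated along the
# channel dimension and projected back to n_embd.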
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads, head_size, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, head_size, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))
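
# Position-wise MLP: expand to 4*n_embd, GELU, project back down, dropout.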
class FeedForward(nn.Module):
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
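
# One transformer block with pre-LayerNorm residual connections:
# x = x + sa(ln1(x)); x = x + ffwd(ln2(x)).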
class Block(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, head_size, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss
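
    # Autoregressive sampling: crop the running context to the last
    # block_size tokens, take the logits at the final position, apply
    # temperature and optional top-k filtering, then sample one token.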
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# ─────────────────────────────────────────
# MODEL LOADER
# ─────────────────────────────────────────
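# Checkpoint files are expected to contain (see load_model below):
#   'model_state_dict', 'hyperparameters' (n_embd, n_head, n_layer,
#   block_size, dropout), 'vocab_size', 'stoi', 'itos', 'val_loss'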
HF_FILENAMES = {
    'english': 'english_maxed_model.pth',
    'bangla': 'bangla_improved_model.pth',
}

# in-process cache: language -> (model, stoi, itos)
_models = {}

def get_checkpoint_path(language: str) -> str:
    filename = HF_FILENAMES[language]
    local_path = os.path.join(CACHE_DIR, filename)
    if os.path.exists(local_path):
        print(f" Found locally: {local_path}")
        return local_path
    print(f" Downloading {filename} from Hugging Face...")
    path = hf_hub_download(
        repo_id = HF_REPO_ID,
        filename = filename,
        repo_type = "model",
        local_dir = CACHE_DIR,
    )
    print(f" Downloaded: {path}")
    return path
def load_model(language: str):
    if language in _models:
        return _models[language]
    print(f"Loading {language} model...")
    path = get_checkpoint_path(language)
    # weights_only=False: the checkpoint also stores plain Python objects
    # (hyperparameters, stoi/itos vocab maps), not just tensors
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hp = checkpoint['hyperparameters']
    model = GPTLanguageModel(
        vocab_size = checkpoint['vocab_size'],
        n_embd = hp['n_embd'],
        n_head = hp['n_head'],
        n_layer = hp['n_layer'],
        block_size = hp['block_size'],
        dropout = hp['dropout'],
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" Loaded! {params:.2f}M params | val_loss={checkpoint['val_loss']:.4f}")
    _models[language] = (model, checkpoint['stoi'], checkpoint['itos'])
    return _models[language]
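
# Char-level text generation: stoi/itos map individual characters to token
# ids and back, so encoding and decoding are per-character lookups.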
def generate_text(language, prompt="", max_tokens=200, temperature=1.0, top_k=40):
    model, stoi, itos = load_model(language)
    encode = lambda s: [stoi[c] for c in s if c in stoi]  # drops out-of-vocab characters
    decode = lambda l: ''.join([itos[i] for i in l])
    if prompt and prompt.strip():
        encoded = encode(prompt)
        if not encoded:
            encoded = [0]
        context = torch.tensor(encoded, dtype=torch.long, device=device).unsqueeze(0)
    else:
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
    output = model.generate(context, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k)
    full_text = decode(output[0].tolist())
    if prompt and prompt.strip():
        # strip the prompt by its encoded length, not len(prompt): encode()
        # may have dropped characters that are missing from the vocabulary
        return full_text[len(encoded):]
    return full_text
# ─────────────────────────────────────────
# LIFESPAN
# ─────────────────────────────────────────
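# Download and deserialize both models at startup so the first request does
# not pay that cost; a failed preload is logged but does not stop the server
# (loading is retried lazily on first use).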
@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Preloading models...")
    for lang in ['english', 'bangla']:
        try:
            load_model(lang)
            print(f" {lang} model ready ✓")
        except Exception as e:
            print(f" {lang} model failed: {e}")
    yield
# ─────────────────────────────────────────
# APP
# ─────────────────────────────────────────
app = FastAPI(
    title = "MiniGPT API",
    version = "1.0.0",
    lifespan = lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins = ["*"],
    allow_methods = ["*"],
    allow_headers = ["*"],
)
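
# CORS is wide open (all origins, methods, and headers); the static frontend
# below is served same-origin, so this mainly permits external API clients.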
# ─────────────────────────────────────────
# REQUEST / RESPONSE
# ─────────────────────────────────────────
class GenerateRequest(BaseModel):
    language: str = Field(default="english")
    prompt: str = Field(default="")
    max_tokens: int = Field(default=200, ge=10, le=500)
    temperature: float = Field(default=1.0, ge=0.1, le=2.0)
    top_k: int = Field(default=40, ge=1, le=100)

class GenerateResponse(BaseModel):
    generated_text: str
    prompt: str
    language: str
    time_taken: float
    tokens_generated: int
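
# Example request (illustrative values, within the Field bounds above):
#   POST /generate
#   {"language": "english", "prompt": "Once upon a time",
#    "max_tokens": 120, "temperature": 0.8, "top_k": 40}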
# ─────────────────────────────────────────
# API ROUTES
# ─────────────────────────────────────────
@app.get("/health")
def health():
    return {"status": "ok"}

@app.get("/models")
def models():
    return {
        lang: {
            "loaded": lang in _models,
            "file": HF_FILENAMES.get(lang)
        }
        for lang in ['english', 'bangla']
    }
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    if req.language not in ["english", "bangla"]:
        raise HTTPException(status_code=400, detail="language must be 'english' or 'bangla'")
    try:
        start = time.time()
        text = generate_text(
            language = req.language,
            prompt = req.prompt,
            max_tokens = req.max_tokens,
            temperature = req.temperature,
            top_k = req.top_k,
        )
        return GenerateResponse(
            generated_text = text,
            prompt = req.prompt,
            language = req.language,
            time_taken = round(time.time() - start, 2),
            tokens_generated = len(text),  # char-level model: one token per character
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ─────────────────────────────────────────
# SERVE STATIC FRONTEND
# must be mounted LAST: catches all remaining routes
# ─────────────────────────────────────────
app.mount("/", StaticFiles(directory="static", html=True), name="static")
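
# Run locally (assuming the usual uvicorn invocation for this layout):
#   uvicorn main:app --host 0.0.0.0 --port 8000
# StaticFiles requires the "static" directory to exist and contain the
# exported Next.js build.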