# deploy/main.py
# Production server — serves API + static Next.js files
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
import torch
import torch.nn as nn
from torch.nn import functional as F
from huggingface_hub import hf_hub_download
import os
import math
import time

# ─────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────
# Hugging Face model repo that hosts the trained checkpoint files.
HF_REPO_ID = "debojitbasak/minigpt-models"
# Directory checkpoints are looked up in / downloaded to. It is a relative
# path, so it resolves against the process's current working directory.
CACHE_DIR = "./checkpoints"
# Import-time side effect: make sure the cache directory exists.
os.makedirs(CACHE_DIR, exist_ok=True)

# Run on the GPU when available; models and input tensors are placed here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")


# ─────────────────────────────────────────
# MODEL ARCHITECTURE
# ─────────────────────────────────────────
# NOTE: attribute names and layer order in these modules define the
# state_dict keys that load_state_dict() matches against the checkpoints.
# Do not rename or reorder them.

class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input to query/key/value vectors of width ``head_size``.
    The lower-triangular mask lets each position attend only to itself and
    earlier positions.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal mask registered as a buffer: it follows .to(device) and is
        # part of the state_dict, but it is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x must be rank 3 (B, T, C); the Linear layers require C == n_embd.
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        # NOTE(review): this scales by C**-0.5 (C == n_embd), not by
        # head_size**-0.5 as in standard scaled dot-product attention. The
        # shipped checkpoints were presumably trained with this exact scaling.
        # Changing it here would alter inference without retraining. TODO:
        # confirm against the training script.
        wei = q @ k.transpose(-2, -1) * C**-0.5
        # Hide future positions. Slicing to [:T, :T] supports T < block_size.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ self.value(x)


class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` attention heads in parallel, concatenate them, project.

    The output projection is ``n_embd -> n_embd``, so this assumes
    ``num_heads * head_size == n_embd``. ``Block`` passes
    ``head_size = n_embd // n_head``, which holds when n_head divides n_embd.
    """

    def __init__(self, n_embd, num_heads, head_size, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, head_size, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Concatenate the per-head outputs along the feature dimension.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4*n_embd, apply GELU, project back, dropout."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        # Sequential indices (net.0 / net.2 hold the weights) are part of the
        # state_dict keys, so the layer order must stay fixed.
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Self-attention comes first, then the feed-forward MLP. Each sublayer is
    wrapped in a residual connection and sees a LayerNorm-ed input.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, head_size, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Residual connections: normalize before each sublayer, then add back.
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token embeddings + learned position embeddings, then
    ``n_layer`` Blocks, then a final LayerNorm, then a linear head that
    produces logits over ``vocab_size`` tokens.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        super().__init__()
        # Maximum context length; generate() crops its input to this length.
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T). T must not exceed
                ``block_size``, because the position table has only that
                many rows.
            targets: optional token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size).
            ``loss`` is None when ``targets`` is None.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        # (T, n_embd) position embeddings broadcast across the batch dimension.
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            # Flatten batch and time so cross_entropy sees (N, C) vs (N,).
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        This method does not call ``eval()`` itself; put the model in eval
        mode first so dropout is disabled.

        Args:
            idx: (B, T) starting token ids.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this value, so it must be
                non-zero. Lower values give sharper distributions.
            top_k: if given, keep only the ``top_k`` highest logits per row
                before sampling.

        Returns:
            A (B, T + max_new_tokens) tensor: the input ids followed by the
            sampled ids.
        """
        for _ in range(max_new_tokens):
            # Keep only the last block_size tokens so position lookups stay
            # in range.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Only the last time step predicts the next token.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v = torch.topk(logits, min(top_k, logits.size(-1)))
                # v[0] holds the top-k values; [:, [-1]] is the k-th largest
                # per row. Everything below it is masked out.
                logits[logits < v[0][:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# ─────────────────────────────────────────
# MODEL LOADER
# ─────────────────────────────────────────
# Checkpoint filename inside HF_REPO_ID for each supported language. The keys
# are also the set of languages the server accepts.
HF_FILENAMES = {
    'english': 'english_maxed_model.pth',
    'bangla': 'bangla_improved_model.pth',
}

# Process-wide cache filled by load_model(): language -> (model, stoi, itos).
_models = {}


def get_checkpoint_path(language: str) -> str:
    """Return a local filesystem path to the checkpoint for *language*.

    The copy already in CACHE_DIR is used if present. Otherwise the file is
    downloaded from the Hugging Face repo HF_REPO_ID into CACHE_DIR.

    Raises:
        KeyError: if *language* is not a key of HF_FILENAMES.
    """
    filename = HF_FILENAMES[language]
    local_path = os.path.join(CACHE_DIR, filename)
    if os.path.exists(local_path):
        print(f" Found locally: {local_path}")
        return local_path
    print(f" Downloading {filename} from Hugging Face...")
    path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        repo_type="model",
        local_dir=CACHE_DIR,
    )
    print(f" Downloaded: {path}")
    return path


def load_model(language: str):
    """Build the model for *language* from its checkpoint and cache the result.

    The checkpoint must be a dict providing the keys read below:
    'hyperparameters', 'vocab_size', 'model_state_dict', 'stoi', 'itos' and
    'val_loss'.

    Returns:
        ``(model, stoi, itos)``: the model in eval mode on ``device``, the
        character -> token-id mapping, and the token-id -> string mapping.

    Raises:
        KeyError: for an unknown *language* or a checkpoint missing a key.
    """
    if language in _models:
        return _models[language]
    print(f"Loading {language} model...")
    path = get_checkpoint_path(language)
    # SECURITY: weights_only=False unpickles arbitrary objects, and loading a
    # malicious file can execute code. This is acceptable only because the
    # checkpoint comes from the project's own HF repo. Never point this at
    # untrusted files.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    hp = checkpoint['hyperparameters']
    model = GPTLanguageModel(
        vocab_size=checkpoint['vocab_size'],
        n_embd=hp['n_embd'],
        n_head=hp['n_head'],
        n_layer=hp['n_layer'],
        block_size=hp['block_size'],
        dropout=hp['dropout'],
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference only: disables dropout
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" Loaded! {params:.2f}M params | val_loss={checkpoint['val_loss']:.4f}")
    _models[language] = (model, checkpoint['stoi'], checkpoint['itos'])
    return _models[language]


def generate_text(language: str, prompt: str = "", max_tokens: int = 200,
                  temperature: float = 1.0, top_k=40) -> str:
    """Sample text from the *language* model, optionally continuing *prompt*.

    Characters of *prompt* that are not in the model's vocabulary are dropped
    before encoding. If none remain, generation is seeded with token id 0.

    Args:
        language: a key of HF_FILENAMES.
        prompt: text to continue. Empty or whitespace-only means "no prompt".
        max_tokens: number of new tokens to sample.
        temperature: sampling temperature; must be non-zero.
        top_k: restrict sampling to the k most likely tokens (None disables).

    Returns:
        With a non-blank prompt: only the newly generated continuation.
        Without one: the full decoded output, including the seed token (id 0).
    """
    model, stoi, itos = load_model(language)

    def encode(s):
        return [stoi[c] for c in s if c in stoi]

    def decode(ids):
        return ''.join([itos[i] for i in ids])

    has_prompt = bool(prompt and prompt.strip())
    if has_prompt:
        encoded = encode(prompt) or [0]
        context = torch.tensor(encoded, dtype=torch.long, device=device).unsqueeze(0)
    else:
        context = torch.zeros((1, 1), dtype=torch.long, device=device)

    output = model.generate(context, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k)
    token_ids = output[0].tolist()
    if has_prompt:
        # Drop the context by TOKEN count, not by len(prompt). Slicing the
        # decoded string by len(prompt) over-trimmed (and lost generated text)
        # whenever out-of-vocabulary characters were dropped from the prompt
        # or the [0] fallback was used.
        return decode(token_ids[context.shape[1]:])
    return decode(token_ids)


# ─────────────────────────────────────────
# LIFESPAN
# ─────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload every model at startup, best-effort.

    A failed preload is only logged. load_model() will try again the first
    time that language is requested.
    """
    print("Preloading models...")
    for lang in HF_FILENAMES:
        try:
            load_model(lang)
            print(f" {lang} model ready ✓")
        except Exception as e:
            print(f" {lang} model failed: {e}")
    yield


# ─────────────────────────────────────────
# APP
# ─────────────────────────────────────────
app = FastAPI(
    title="MiniGPT API",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully open CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ─────────────────────────────────────────
# REQUEST / RESPONSE
# ─────────────────────────────────────────
class GenerateRequest(BaseModel):
    """Body of POST /generate. Numeric bounds are enforced by pydantic."""

    language: str = Field(default="english")  # checked against the supported set by the route
    prompt: str = Field(default="")
    max_tokens: int = Field(default=200, ge=10, le=500)
    temperature: float = Field(default=1.0, ge=0.1, le=2.0)  # lower bound keeps division safe
    top_k: int = Field(default=40, ge=1, le=100)


class GenerateResponse(BaseModel):
    """Response of POST /generate."""

    generated_text: str
    prompt: str
    language: str
    time_taken: float  # seconds, rounded to 2 decimals
    tokens_generated: int
# ─────────────────────────────────────────
# API ROUTES
# ─────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe; does not touch the models."""
    return {"status": "ok"}


@app.get("/models")
def models():
    """Report, per supported language, whether its model is cached in memory
    and which checkpoint file it loads from."""
    return {
        lang: {
            "loaded": lang in _models,
            "file": filename,
        }
        for lang, filename in HF_FILENAMES.items()
    }


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate text for *req* and report how long generation took.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs the
    blocking model call in its threadpool rather than on the event loop.

    Raises:
        HTTPException(400): unsupported language.
        HTTPException(500): model loading or generation failed.
    """
    if req.language not in HF_FILENAMES:
        raise HTTPException(status_code=400, detail="language must be 'english' or 'bangla'")

    # perf_counter() is monotonic. time.time() can jump on clock adjustments
    # and yield wrong (even negative) durations.
    start = time.perf_counter()
    try:
        text = generate_text(
            language=req.language,
            prompt=req.prompt,
            max_tokens=req.max_tokens,
            temperature=req.temperature,
            top_k=req.top_k,
        )
    except Exception as e:
        # Top-level boundary: surface any failure as a 500 and chain the cause
        # so the server-side traceback is kept.
        # NOTE(review): str(e) is sent to the client verbatim and may expose
        # internal details (paths, etc.). Consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return GenerateResponse(
        generated_text=text,
        prompt=req.prompt,
        language=req.language,
        time_taken=round(time.perf_counter() - start, 2),
        # Character count. It equals the token count only if every vocab
        # entry decodes to a single character (the encoder is per-character).
        tokens_generated=len(text),
    )


# ─────────────────────────────────────────
# SERVE STATIC FRONTEND
# must be LAST — catches all remaining routes
# ─────────────────────────────────────────
# Mounted at "/" so it matches every path the routes above do not claim.
# html=True serves index.html for directory paths. "static" resolves against
# the current working directory, and StaticFiles raises at startup if the
# directory does not exist.
app.mount("/", StaticFiles(directory="static", html=True), name="static")