from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from llama_cpp import Llama
import os

# ── Model loading ──────────────────────────────────────────────────────────────
MODEL_REPO = "newtechdevng/i_am_a_lawyer"
MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"

SYSTEM_PROMPT = (
    "You are Ambuj, an expert AI assistant specialised in Indian law. "
    "You provide accurate, well-structured legal information based on Indian statutes, "
    "case law, and legal procedures. Always clarify that your responses are for "
    "informational purposes only and not a substitute for professional legal advice."
)

print("Loading model …")
llm = Llama.from_pretrained(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    n_ctx=512,        # ← was 4096 (killed RAM); 512 is enough for legal Q&A
    n_threads=2,      # ← was os.cpu_count(); free tier has 2 vCPUs, use both safely
    n_batch=64,       # ← smaller prompt batch = less peak RAM
    n_gpu_layers=0,   # ← no GPU on free tier, keep at 0
    verbose=False,
)
print("Model ready ✓")

# ── FastAPI app ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Indian Legal AI API",
    description="API for the Ambuj Indian Legal Llama model",
    version="1.0.0",
)

# ── Request / Response schemas ──────────────────────────────────────────────────
class Message(BaseModel):
    role: str      # "user" | "assistant" | "system"
    content: str

class ChatRequest(BaseModel):
    messages: list[Message]
    max_tokens: Optional[int] = 256    # ← was 512; lowered default
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False

class ChatResponse(BaseModel):
    role: str = "assistant"
    content: str

# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    return {
        "name": "Indian Legal AI API",
        "model": MODEL_FILE,
        "endpoints": {
            "POST /chat": "Send messages, get a response",
            "POST /ask": "Simple single-question shortcut",
            "GET /health": "Health check",
            "GET /docs": "Swagger UI",
        },
    }

@app.get("/health")
def health():
    return {"status": "ok", "model_loaded": llm is not None}

@app.post("/chat")
def chat(request: ChatRequest):
    # Hard cap max_tokens to prevent OOM on long generations
    safe_tokens = min(request.max_tokens or 256, 256)

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for m in request.messages:
        if m.role not in ("user", "assistant", "system"):
            raise HTTPException(status_code=400, detail=f"Invalid role: {m.role}")
        messages.append({"role": m.role, "content": m.content})

    if request.stream:
        def generate():
            for chunk in llm.create_chat_completion(
                messages=messages,
                max_tokens=safe_tokens,
                temperature=request.temperature,
                stream=True,
            ):
                delta = chunk["choices"][0]["delta"].get("content", "")
                if delta:
                    yield delta

        return StreamingResponse(generate(), media_type="text/plain")

    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=safe_tokens,
        temperature=request.temperature,
        stream=False,
    )
    content = response["choices"][0]["message"]["content"]
    return ChatResponse(content=content)

class AskRequest(BaseModel):
    question: str
    max_tokens: Optional[int] = 256    # ← was 512; lowered default
    temperature: Optional[float] = 0.7

@app.post("/ask")
def ask(request: AskRequest):
    # Hard cap max_tokens to prevent OOM on long generations
    safe_tokens = min(request.max_tokens or 256, 256)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": request.question},
    ]
    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=safe_tokens,
        temperature=request.temperature,
        stream=False,
    )
    content = response["choices"][0]["message"]["content"]
    return {"question": request.question, "answer": content}
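
# ── Example usage (illustrative sketch) ─────────────────────────────────────────
# A minimal sketch of how a client might call this API once it is served with
# uvicorn (e.g. `uvicorn app:app --host 0.0.0.0 --port 8000`). The module name
# app.py, the port, and the sample questions are assumptions for illustration;
# they are not defined anywhere in this file.
#
#   import requests
#
#   # Single-question shortcut via POST /ask
#   r = requests.post(
#       "http://localhost:8000/ask",
#       json={"question": "Explain anticipatory bail under Indian law"},
#   )
#   print(r.json()["answer"])
#
#   # Streaming chat: POST /chat with stream=true yields plain-text chunks
#   with requests.post(
#       "http://localhost:8000/chat",
#       json={"messages": [{"role": "user", "content": "What is Section 482 CrPC?"}],
#             "stream": True},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)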