from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from llama_cpp import Llama
import os
# ── Model loading ──────────────────────────────────────────────────────────────
MODEL_REPO = "newtechdevng/i_am_a_lawyer"
MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"
SYSTEM_PROMPT = (
    "You are Ambuj, an expert AI assistant specialised in Indian law. "
    "You provide accurate, well-structured legal information based on Indian statutes, "
    "case law, and legal procedures. Always clarify that your responses are for "
    "informational purposes only and not a substitute for professional legal advice."
)
print("Loading model …")
llm = Llama.from_pretrained(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    n_ctx=512,        # ← was 4096 (killed RAM); 512 is enough for legal Q&A
    n_threads=2,      # ← was os.cpu_count(); free tier has 2 vCPUs, use both safely
    n_batch=64,       # ← smaller prompt batch = less peak RAM
    n_gpu_layers=0,   # ← no GPU on free tier, keep at 0
    verbose=False,
)
print("Model ready βœ“")
# ── FastAPI app ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Indian Legal AI API",
    description="API for the Ambuj Indian Legal Llama model",
    version="1.0.0",
)
# ── Request / Response schemas ─────────────────────────────────────────────────
class Message(BaseModel):
    role: str      # "user" | "assistant" | "system"
    content: str


class ChatRequest(BaseModel):
    messages: list[Message]
    max_tokens: Optional[int] = 256   # ← was 512; lowered default
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False


class ChatResponse(BaseModel):
    role: str = "assistant"
    content: str
# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
return {
"name": "Indian Legal AI API",
"model": MODEL_FILE,
"endpoints": {
"POST /chat": "Send messages, get a response",
"POST /ask": "Simple single-question shortcut",
"GET /health": "Health check",
"GET /docs": "Swagger UI",
}
}
@app.get("/health")
def health():
return {"status": "ok", "model_loaded": llm is not None}
@app.post("/chat")
def chat(request: ChatRequest):
# Hard cap max_tokens to prevent OOM on long generations
safe_tokens = min(request.max_tokens or 256, 256)
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
for m in request.messages:
if m.role not in ("user", "assistant", "system"):
raise HTTPException(status_code=400, detail=f"Invalid role: {m.role}")
messages.append({"role": m.role, "content": m.content})
if request.stream:
def generate():
for chunk in llm.create_chat_completion(
messages=messages,
max_tokens=safe_tokens,
temperature=request.temperature,
stream=True,
):
delta = chunk["choices"][0]["delta"].get("content", "")
if delta:
yield delta
return StreamingResponse(generate(), media_type="text/plain")
response = llm.create_chat_completion(
messages=messages,
max_tokens=safe_tokens,
temperature=request.temperature,
stream=False,
)
content = response["choices"][0]["message"]["content"]
return ChatResponse(content=content)
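# ── Example: consuming the streaming /chat response ───────────────────────────
# Illustrative client-side sketch only (not used by the API). It assumes the
# server is reachable at http://localhost:8000 and that `requests` is installed;
# adjust the URL to your deployment.
#
#   import requests
#
#   payload = {
#       "messages": [{"role": "user", "content": "What is anticipatory bail?"}],
#       "stream": True,
#   }
#   with requests.post("http://localhost:8000/chat", json=payload, stream=True) as r:
#       for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)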
class AskRequest(BaseModel):
    question: str
    max_tokens: Optional[int] = 256   # ← was 512; lowered default
    temperature: Optional[float] = 0.7


@app.post("/ask")
def ask(request: AskRequest):
    # Hard cap max_tokens to prevent OOM on long generations
    safe_tokens = min(request.max_tokens or 256, 256)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": request.question},
    ]
    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=safe_tokens,
        temperature=request.temperature,
        stream=False,
    )
    content = response["choices"][0]["message"]["content"]
    return {"question": request.question, "answer": content}