| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel |
| from typing import Optional |
| from llama_cpp import Llama |
| import os |
|
|
| |
| MODEL_REPO = "newtechdevng/i_am_a_lawyer" |
| MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf" |
| SYSTEM_PROMPT = ( |
| "You are Ambuj, an expert AI assistant specialised in Indian law. " |
| "You provide accurate, well-structured legal information based on Indian statutes, " |
| "case law, and legal procedures. Always clarify that your responses are for " |
| "informational purposes only and not a substitute for professional legal advice." |
| ) |
|
|
| print("Loading model β¦") |
| llm = Llama.from_pretrained( |
| repo_id=MODEL_REPO, |
| filename=MODEL_FILE, |
| n_ctx=512, |
| n_threads=2, |
| n_batch=64, |
| n_gpu_layers=0, |
| verbose=False, |
| ) |
| print("Model ready β") |
|
|
| |
| app = FastAPI( |
| title="Indian Legal AI API", |
| description="API for the Ambuj Indian Legal Llama model", |
| version="1.0.0", |
| ) |
|
|
|
|
| |
| class Message(BaseModel): |
| role: str |
| content: str |
|
|
|
|
| class ChatRequest(BaseModel): |
| messages: list[Message] |
| max_tokens: Optional[int] = 256 |
| temperature: Optional[float] = 0.7 |
| stream: Optional[bool] = False |
|
|
|
|
| class ChatResponse(BaseModel): |
| role: str = "assistant" |
| content: str |
|
|
|
|
| |
| @app.get("/") |
| def root(): |
| return { |
| "name": "Indian Legal AI API", |
| "model": MODEL_FILE, |
| "endpoints": { |
| "POST /chat": "Send messages, get a response", |
| "POST /ask": "Simple single-question shortcut", |
| "GET /health": "Health check", |
| "GET /docs": "Swagger UI", |
| } |
| } |
|
|
|
|
| @app.get("/health") |
| def health(): |
| return {"status": "ok", "model_loaded": llm is not None} |
|
|
|
|
| @app.post("/chat") |
| def chat(request: ChatRequest): |
| |
| safe_tokens = min(request.max_tokens or 256, 256) |
|
|
| messages = [{"role": "system", "content": SYSTEM_PROMPT}] |
| for m in request.messages: |
| if m.role not in ("user", "assistant", "system"): |
| raise HTTPException(status_code=400, detail=f"Invalid role: {m.role}") |
| messages.append({"role": m.role, "content": m.content}) |
|
|
| if request.stream: |
| def generate(): |
| for chunk in llm.create_chat_completion( |
| messages=messages, |
| max_tokens=safe_tokens, |
| temperature=request.temperature, |
| stream=True, |
| ): |
| delta = chunk["choices"][0]["delta"].get("content", "") |
| if delta: |
| yield delta |
|
|
| return StreamingResponse(generate(), media_type="text/plain") |
|
|
| response = llm.create_chat_completion( |
| messages=messages, |
| max_tokens=safe_tokens, |
| temperature=request.temperature, |
| stream=False, |
| ) |
| content = response["choices"][0]["message"]["content"] |
| return ChatResponse(content=content) |
|
|
|
|
| class AskRequest(BaseModel): |
| question: str |
| max_tokens: Optional[int] = 256 |
| temperature: Optional[float] = 0.7 |
|
|
|
|
| @app.post("/ask") |
| def ask(request: AskRequest): |
| |
| safe_tokens = min(request.max_tokens or 256, 256) |
|
|
| messages = [ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": request.question}, |
| ] |
| response = llm.create_chat_completion( |
| messages=messages, |
| max_tokens=safe_tokens, |
| temperature=request.temperature, |
| stream=False, |
| ) |
| content = response["choices"][0]["message"]["content"] |
| return {"question": request.question, "answer": content} |