from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
from llama_cpp import Llama
import os

# ── Model loading ──────────────────────────────────────────────────────────────
MODEL_REPO = "newtechdevng/i_am_a_lawyer"
MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"
SYSTEM_PROMPT = (
    "You are Ambuj, an expert AI assistant specialised in Indian law. "
    "You provide accurate, well-structured legal information based on Indian statutes, "
    "case law, and legal procedures. Always clarify that your responses are for "
    "informational purposes only and not a substitute for professional legal advice."
)

print("Loading model …")
llm = Llama.from_pretrained(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    n_ctx=512,          # ← was 4096 (killed RAM); 512 is enough for legal Q&A
    n_threads=2,        # ← was os.cpu_count(); free tier has 2 vCPUs, use both safely
    n_batch=64,         # ← smaller prompt batch = less peak RAM
    n_gpu_layers=0,     # ← no GPU on free tier, keep at 0
    verbose=False,
)
print("Model ready βœ“")

# ── FastAPI app ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Indian Legal AI API",
    description="API for the Ambuj Indian Legal Llama model",
    version="1.0.0",
)


# ── Request / Response schemas ─────────────────────────────────────────────────
class Message(BaseModel):
    role: str        # "user" | "assistant" | "system"
    content: str


class ChatRequest(BaseModel):
    messages: list[Message]
    max_tokens: Optional[int] = 256          # ← was 512; lowered default
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False


class ChatResponse(BaseModel):
    role: str = "assistant"
    content: str


# ── Routes ─────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    return {
        "name": "Indian Legal AI API",
        "model": MODEL_FILE,
        "endpoints": {
            "POST /chat":   "Send messages, get a response",
            "POST /ask":    "Simple single-question shortcut",
            "GET  /health": "Health check",
            "GET  /docs":   "Swagger UI",
        }
    }


@app.get("/health")
def health():
    return {"status": "ok", "model_loaded": llm is not None}


@app.post("/chat")
def chat(request: ChatRequest):
    # Hard cap max_tokens to prevent OOM on long generations
    safe_tokens = min(request.max_tokens or 256, 256)

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for m in request.messages:
        if m.role not in ("user", "assistant", "system"):
            raise HTTPException(status_code=400, detail=f"Invalid role: {m.role}")
        messages.append({"role": m.role, "content": m.content})

    if request.stream:
        def generate():
            for chunk in llm.create_chat_completion(
                messages=messages,
                max_tokens=safe_tokens,
                temperature=request.temperature,
                stream=True,
            ):
                delta = chunk["choices"][0]["delta"].get("content", "")
                if delta:
                    yield delta

        return StreamingResponse(generate(), media_type="text/plain")

    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=safe_tokens,
        temperature=request.temperature,
        stream=False,
    )
    content = response["choices"][0]["message"]["content"]
    return ChatResponse(content=content)
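
# Illustrative /chat call (hypothetical payload; assumes a local server on port 8000):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Explain anticipatory bail under the CrPC."}]}'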


class AskRequest(BaseModel):
    question: str
    max_tokens: Optional[int] = 256          # ← was 512; lowered default
    temperature: Optional[float] = 0.7


@app.post("/ask")
def ask(request: AskRequest):
    # Hard cap max_tokens to prevent OOM on long generations
    safe_tokens = min(request.max_tokens or 256, 256)

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user",   "content": request.question},
    ]
    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=safe_tokens,
        temperature=request.temperature,
        stream=False,
    )
    content = response["choices"][0]["message"]["content"]
    return {"question": request.question, "answer": content}