Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ from llama_cpp import Llama
|
|
| 6 |
import os
|
| 7 |
|
| 8 |
# ── Model loading ─────────────────────────────────────────────────────────────
|
| 9 |
-
MODEL_REPO = "newtechdevng/i_am_a_lawyer"
|
| 10 |
MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"
|
| 11 |
SYSTEM_PROMPT = (
|
| 12 |
"You are Ambuj, an expert AI assistant specialised in Indian law. "
|
|
@@ -19,8 +19,10 @@ print("Loading model …")
|
|
| 19 |
llm = Llama.from_pretrained(
|
| 20 |
repo_id=MODEL_REPO,
|
| 21 |
filename=MODEL_FILE,
|
| 22 |
-
n_ctx=4096,
|
| 23 |
-
n_threads=os.cpu_count()
|
|
|
|
|
|
|
| 24 |
verbose=False,
|
| 25 |
)
|
| 26 |
print("Model ready ✓")
|
|
@@ -28,7 +30,7 @@ print("Model ready ✓")
|
|
| 28 |
# ── FastAPI app ───────────────────────────────────────────────────────────────
|
| 29 |
app = FastAPI(
|
| 30 |
title="Indian Legal AI API",
|
| 31 |
-
description="API for the Ambuj
|
| 32 |
version="1.0.0",
|
| 33 |
)
|
| 34 |
|
|
@@ -41,7 +43,7 @@ class Message(BaseModel):
|
|
| 41 |
|
| 42 |
class ChatRequest(BaseModel):
|
| 43 |
messages: list[Message]
|
| 44 |
-
max_tokens: Optional[int] = 512
|
| 45 |
temperature: Optional[float] = 0.7
|
| 46 |
stream: Optional[bool] = False
|
| 47 |
|
|
@@ -73,10 +75,9 @@ def health():
|
|
| 73 |
|
| 74 |
@app.post("/chat")
|
| 75 |
def chat(request: ChatRequest):
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
"""
|
| 80 |
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 81 |
for m in request.messages:
|
| 82 |
if m.role not in ("user", "assistant", "system"):
|
|
@@ -87,7 +88,7 @@ def chat(request: ChatRequest):
|
|
| 87 |
def generate():
|
| 88 |
for chunk in llm.create_chat_completion(
|
| 89 |
messages=messages,
|
| 90 |
-
max_tokens=
|
| 91 |
temperature=request.temperature,
|
| 92 |
stream=True,
|
| 93 |
):
|
|
@@ -99,7 +100,7 @@ def chat(request: ChatRequest):
|
|
| 99 |
|
| 100 |
response = llm.create_chat_completion(
|
| 101 |
messages=messages,
|
| 102 |
-
max_tokens=
|
| 103 |
temperature=request.temperature,
|
| 104 |
stream=False,
|
| 105 |
)
|
|
@@ -109,24 +110,24 @@ def chat(request: ChatRequest):
|
|
| 109 |
|
| 110 |
class AskRequest(BaseModel):
|
| 111 |
question: str
|
| 112 |
-
max_tokens: Optional[int] = 512
|
| 113 |
temperature: Optional[float] = 0.7
|
| 114 |
|
| 115 |
|
| 116 |
@app.post("/ask")
|
| 117 |
def ask(request: AskRequest):
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
messages = [
|
| 122 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 123 |
{"role": "user", "content": request.question},
|
| 124 |
]
|
| 125 |
response = llm.create_chat_completion(
|
| 126 |
messages=messages,
|
| 127 |
-
max_tokens=
|
| 128 |
temperature=request.temperature,
|
| 129 |
stream=False,
|
| 130 |
)
|
| 131 |
content = response["choices"][0]["message"]["content"]
|
| 132 |
-
return {"question": request.question, "answer": content}
|
|
|
|
| 6 |
import os
|
| 7 |
|
| 8 |
# ── Model loading ─────────────────────────────────────────────────────────────
|
| 9 |
+
MODEL_REPO = "newtechdevng/i_am_a_lawyer"
|
| 10 |
MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"
|
| 11 |
SYSTEM_PROMPT = (
|
| 12 |
"You are Ambuj, an expert AI assistant specialised in Indian law. "
|
|
|
|
| 19 |
llm = Llama.from_pretrained(
|
| 20 |
repo_id=MODEL_REPO,
|
| 21 |
filename=MODEL_FILE,
|
| 22 |
+
n_ctx=512, # ← was 4096 (killed RAM); 512 is enough for legal Q&A
|
| 23 |
+
n_threads=2, # ← was os.cpu_count(); free tier has 2 vCPUs, use both safely
|
| 24 |
+
n_batch=64, # ← smaller prompt batch = less peak RAM
|
| 25 |
+
n_gpu_layers=0, # ← no GPU on free tier, keep at 0
|
| 26 |
verbose=False,
|
| 27 |
)
|
| 28 |
print("Model ready ✓")
|
|
|
|
| 30 |
# ── FastAPI app ───────────────────────────────────────────────────────────────
|
| 31 |
app = FastAPI(
|
| 32 |
title="Indian Legal AI API",
|
| 33 |
+
description="API for the Ambuj Indian Legal Llama model",
|
| 34 |
version="1.0.0",
|
| 35 |
)
|
| 36 |
|
|
|
|
| 43 |
|
| 44 |
class ChatRequest(BaseModel):
|
| 45 |
messages: list[Message]
|
| 46 |
+
max_tokens: Optional[int] = 256 # ← was 512; lowered default
|
| 47 |
temperature: Optional[float] = 0.7
|
| 48 |
stream: Optional[bool] = False
|
| 49 |
|
|
|
|
| 75 |
|
| 76 |
@app.post("/chat")
|
| 77 |
def chat(request: ChatRequest):
|
| 78 |
+
# Hard cap max_tokens to prevent OOM on long generations
|
| 79 |
+
safe_tokens = min(request.max_tokens or 256, 256)
|
| 80 |
+
|
|
|
|
| 81 |
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 82 |
for m in request.messages:
|
| 83 |
if m.role not in ("user", "assistant", "system"):
|
|
|
|
| 88 |
def generate():
|
| 89 |
for chunk in llm.create_chat_completion(
|
| 90 |
messages=messages,
|
| 91 |
+
max_tokens=safe_tokens,
|
| 92 |
temperature=request.temperature,
|
| 93 |
stream=True,
|
| 94 |
):
|
|
|
|
| 100 |
|
| 101 |
response = llm.create_chat_completion(
|
| 102 |
messages=messages,
|
| 103 |
+
max_tokens=safe_tokens,
|
| 104 |
temperature=request.temperature,
|
| 105 |
stream=False,
|
| 106 |
)
|
|
|
|
| 110 |
|
| 111 |
class AskRequest(BaseModel):
|
| 112 |
question: str
|
| 113 |
+
max_tokens: Optional[int] = 256 # ← was 512; lowered default
|
| 114 |
temperature: Optional[float] = 0.7
|
| 115 |
|
| 116 |
|
| 117 |
@app.post("/ask")
|
| 118 |
def ask(request: AskRequest):
|
| 119 |
+
# Hard cap max_tokens to prevent OOM on long generations
|
| 120 |
+
safe_tokens = min(request.max_tokens or 256, 256)
|
| 121 |
+
|
| 122 |
messages = [
|
| 123 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 124 |
{"role": "user", "content": request.question},
|
| 125 |
]
|
| 126 |
response = llm.create_chat_completion(
|
| 127 |
messages=messages,
|
| 128 |
+
max_tokens=safe_tokens,
|
| 129 |
temperature=request.temperature,
|
| 130 |
stream=False,
|
| 131 |
)
|
| 132 |
content = response["choices"][0]["message"]["content"]
|
| 133 |
+
return {"question": request.question, "answer": content}
|