newtechdevng commited on
Commit
6c74002
Β·
verified Β·
1 Parent(s): 1a9139c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -6,7 +6,7 @@ from llama_cpp import Llama
6
  import os
7
 
8
  # ── Model loading ──────────────────────────────────────────────────────────────
9
- MODEL_REPO = "newtechdevng/i_am_a_lawyer" # ← change to your repo
10
  MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"
11
  SYSTEM_PROMPT = (
12
  "You are Ambuj, an expert AI assistant specialised in Indian law. "
@@ -19,8 +19,10 @@ print("Loading model …")
19
  llm = Llama.from_pretrained(
20
  repo_id=MODEL_REPO,
21
  filename=MODEL_FILE,
22
- n_ctx=4096,
23
- n_threads=os.cpu_count() or 4,
 
 
24
  verbose=False,
25
  )
26
  print("Model ready βœ“")
@@ -28,7 +30,7 @@ print("Model ready βœ“")
28
  # ── FastAPI app ────────────────────────────────────────────────────────────────
29
  app = FastAPI(
30
  title="Indian Legal AI API",
31
- description="API for the Ambuj-Tripathi Indian Legal Llama model",
32
  version="1.0.0",
33
  )
34
 
@@ -41,7 +43,7 @@ class Message(BaseModel):
41
 
42
  class ChatRequest(BaseModel):
43
  messages: list[Message]
44
- max_tokens: Optional[int] = 512
45
  temperature: Optional[float] = 0.7
46
  stream: Optional[bool] = False
47
 
@@ -73,10 +75,9 @@ def health():
73
 
74
  @app.post("/chat")
75
  def chat(request: ChatRequest):
76
- """
77
- Full chat endpoint β€” pass a list of messages with roles.
78
- Optionally stream the response.
79
- """
80
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
81
  for m in request.messages:
82
  if m.role not in ("user", "assistant", "system"):
@@ -87,7 +88,7 @@ def chat(request: ChatRequest):
87
  def generate():
88
  for chunk in llm.create_chat_completion(
89
  messages=messages,
90
- max_tokens=request.max_tokens,
91
  temperature=request.temperature,
92
  stream=True,
93
  ):
@@ -99,7 +100,7 @@ def chat(request: ChatRequest):
99
 
100
  response = llm.create_chat_completion(
101
  messages=messages,
102
- max_tokens=request.max_tokens,
103
  temperature=request.temperature,
104
  stream=False,
105
  )
@@ -109,24 +110,24 @@ def chat(request: ChatRequest):
109
 
110
  class AskRequest(BaseModel):
111
  question: str
112
- max_tokens: Optional[int] = 512
113
  temperature: Optional[float] = 0.7
114
 
115
 
116
  @app.post("/ask")
117
  def ask(request: AskRequest):
118
- """
119
- Simple single-question shortcut β€” no need to format messages manually.
120
- """
121
  messages = [
122
  {"role": "system", "content": SYSTEM_PROMPT},
123
  {"role": "user", "content": request.question},
124
  ]
125
  response = llm.create_chat_completion(
126
  messages=messages,
127
- max_tokens=request.max_tokens,
128
  temperature=request.temperature,
129
  stream=False,
130
  )
131
  content = response["choices"][0]["message"]["content"]
132
- return {"question": request.question, "answer": content}
 
6
  import os
7
 
8
  # ── Model loading ──────────────────────────────────────────────────────────────
9
+ MODEL_REPO = "newtechdevng/i_am_a_lawyer"
10
  MODEL_FILE = "llama-3.2-1b-instruct.Q4_K_M.gguf"
11
  SYSTEM_PROMPT = (
12
  "You are Ambuj, an expert AI assistant specialised in Indian law. "
 
19
  llm = Llama.from_pretrained(
20
  repo_id=MODEL_REPO,
21
  filename=MODEL_FILE,
22
+ n_ctx=512, # ← was 4096 (killed RAM); 512 is enough for legal Q&A
23
+ n_threads=2, # ← was os.cpu_count(); free tier has 2 vCPUs, use both safely
24
+ n_batch=64, # ← smaller prompt batch = less peak RAM
25
+ n_gpu_layers=0, # ← no GPU on free tier, keep at 0
26
  verbose=False,
27
  )
28
  print("Model ready βœ“")
 
30
  # ── FastAPI app ────────────────────────────────────────────────────────────────
31
  app = FastAPI(
32
  title="Indian Legal AI API",
33
+ description="API for the Ambuj Indian Legal Llama model",
34
  version="1.0.0",
35
  )
36
 
 
43
 
44
  class ChatRequest(BaseModel):
45
  messages: list[Message]
46
+ max_tokens: Optional[int] = 256 # ← was 512; lowered default
47
  temperature: Optional[float] = 0.7
48
  stream: Optional[bool] = False
49
 
 
75
 
76
  @app.post("/chat")
77
  def chat(request: ChatRequest):
78
+ # Hard cap max_tokens to prevent OOM on long generations
79
+ safe_tokens = min(request.max_tokens or 256, 256)
80
+
 
81
  messages = [{"role": "system", "content": SYSTEM_PROMPT}]
82
  for m in request.messages:
83
  if m.role not in ("user", "assistant", "system"):
 
88
  def generate():
89
  for chunk in llm.create_chat_completion(
90
  messages=messages,
91
+ max_tokens=safe_tokens,
92
  temperature=request.temperature,
93
  stream=True,
94
  ):
 
100
 
101
  response = llm.create_chat_completion(
102
  messages=messages,
103
+ max_tokens=safe_tokens,
104
  temperature=request.temperature,
105
  stream=False,
106
  )
 
110
 
111
  class AskRequest(BaseModel):
112
  question: str
113
+ max_tokens: Optional[int] = 256 # ← was 512; lowered default
114
  temperature: Optional[float] = 0.7
115
 
116
 
117
  @app.post("/ask")
118
  def ask(request: AskRequest):
119
+ # Hard cap max_tokens to prevent OOM on long generations
120
+ safe_tokens = min(request.max_tokens or 256, 256)
121
+
122
  messages = [
123
  {"role": "system", "content": SYSTEM_PROMPT},
124
  {"role": "user", "content": request.question},
125
  ]
126
  response = llm.create_chat_completion(
127
  messages=messages,
128
+ max_tokens=safe_tokens,
129
  temperature=request.temperature,
130
  stream=False,
131
  )
132
  content = response["choices"][0]["message"]["content"]
133
+ return {"question": request.question, "answer": content}