CaffeinatedCoding committed on
Commit
a64025f
·
verified ·
1 Parent(s): f756c47

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. api/main.py +30 -9
  2. src/agent_v2.py +78 -19
  3. src/llm.py +103 -54
api/main.py CHANGED
@@ -37,10 +37,20 @@ def download_models():
37
 
38
  if not os.path.exists("models/ner_model"):
39
  logger.info("Downloading NER model...")
40
- snapshot_download(
41
- repo_id=repo_id, repo_type="model",
42
- allow_patterns="ner_model/*", local_dir="models", token=hf_token
43
- )
 
 
 
 
 
 
 
 
 
 
44
  logger.info("NER model downloaded")
45
  else:
46
  logger.info("NER model already exists")
@@ -48,10 +58,14 @@ def download_models():
48
  if not os.path.exists("models/faiss_index/index.faiss"):
49
  logger.info("Downloading FAISS index...")
50
  os.makedirs("models/faiss_index", exist_ok=True)
51
- hf_hub_download(repo_id=repo_id, filename="faiss_index/index.faiss",
52
- repo_type="model", local_dir="models", token=hf_token)
53
- hf_hub_download(repo_id=repo_id, filename="faiss_index/chunk_metadata.jsonl",
54
- repo_type="model", local_dir="models", token=hf_token)
 
 
 
 
55
  logger.info("FAISS index downloaded")
56
  else:
57
  logger.info("FAISS index already exists")
@@ -138,7 +152,14 @@ def serve_frontend():
138
 
139
  @app.get("/health")
140
  def health():
141
- return {"status": "ok", "service": "NyayaSetu", "version": "2.0.0", "agent": AGENT_VERSION}
 
 
 
 
 
 
 
142
 
143
 
144
  @app.post("/query", response_model=QueryResponse)
 
37
 
38
  if not os.path.exists("models/ner_model"):
39
  logger.info("Downloading NER model...")
40
+ os.makedirs("models/ner_model", exist_ok=True)
41
+ # NER model files — explicit downloads to avoid snapshot_download pattern bugs
42
+ ner_files = [
43
+ "config.json", "model.safetensors", "tokenizer.json",
44
+ "tokenizer_config.json", "training_args.bin", "training_results.json"
45
+ ]
46
+ for fname in ner_files:
47
+ try:
48
+ hf_hub_download(
49
+ repo_id=repo_id, filename=f"ner_model/{fname}",
50
+ repo_type="model", local_dir="models", token=hf_token
51
+ )
52
+ except Exception as e:
53
+ logger.warning(f"Could not download ner_model/{fname}: {e}")
54
  logger.info("NER model downloaded")
55
  else:
56
  logger.info("NER model already exists")
 
58
  if not os.path.exists("models/faiss_index/index.faiss"):
59
  logger.info("Downloading FAISS index...")
60
  os.makedirs("models/faiss_index", exist_ok=True)
61
+ # Download FAISS files explicitly to avoid snapshot_download pattern issues
62
+ faiss_files = ["index.faiss", "chunk_metadata.jsonl"]
63
+ for fname in faiss_files:
64
+ try:
65
+ hf_hub_download(repo_id=repo_id, filename=f"faiss_index/{fname}",
66
+ repo_type="model", local_dir="models", token=hf_token)
67
+ except Exception as fe:
68
+ logger.warning(f"Could not download faiss_index/{fname}: {fe}")
69
  logger.info("FAISS index downloaded")
70
  else:
71
  logger.info("FAISS index already exists")
 
152
 
153
@app.get("/health")
def health():
    """Liveness probe: service metadata plus the Groq circuit-breaker state."""
    # Imported here rather than at module top — presumably to avoid an import
    # cycle with src.agent_v2; confirm before hoisting.
    from src.agent_v2 import _circuit_breaker

    payload = {
        "status": "ok",
        "service": "NyayaSetu",
        "version": "2.0.0",
        "agent": AGENT_VERSION,
        "groq_circuit_breaker": _circuit_breaker.get_status(),
    }
    return payload
163
 
164
 
165
  @app.post("/query", response_model=QueryResponse)
src/agent_v2.py CHANGED
@@ -28,12 +28,57 @@ from src.ner import extract_entities, augment_query
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
- from groq import Groq
32
- from tenacity import retry, stop_after_attempt, wait_exponential
33
  from dotenv import load_dotenv
 
 
 
34
 
35
  load_dotenv()
36
- _client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # ── Session store ─────────────────────────────────────────
39
  sessions: Dict[str, Dict] = {}
@@ -116,8 +161,13 @@ def update_session(session_id: str, analysis: Dict, user_message: str, response:
116
 
117
 
118
  # ── Pass 1: Analyse ───────────────────────────────────────
119
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=4))
 
120
  def analyse(user_message: str, session: Dict) -> Dict:
 
 
 
 
121
  summary = session.get("summary", "")
122
  last_msgs = session.get("last_3_messages", [])
123
  cs = session["case_state"]
@@ -165,16 +215,14 @@ Rules:
165
  - Update hypothesis confidence based on new evidence
166
  - search_queries must be specific legal questions for vector search"""
167
 
168
- response = _client.chat.completions.create(
169
- model="llama-3.3-70b-versatile",
170
  messages=[
171
  {"role": "system", "content": ANALYSIS_PROMPT},
172
  {"role": "user", "content": user_content}
173
- ],
174
- temperature=0.1,
175
- max_tokens=900
176
  )
177
- raw = response.choices[0].message.content.strip()
 
178
  raw = raw.replace("```json", "").replace("```", "").strip()
179
 
180
  try:
@@ -229,8 +277,13 @@ def retrieve_parallel(search_queries: List[str], top_k: int = 5) -> List[Dict]:
229
 
230
 
231
  # ── Pass 3: Respond ───────────────────────────────────────
232
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=8))
 
233
  def respond(user_message: str, analysis: Dict, chunks: List[Dict], session: Dict) -> str:
 
 
 
 
234
  system_prompt = build_prompt(analysis)
235
  cs = session["case_state"]
236
  turn_count = cs.get("turn_count", 0)
@@ -325,16 +378,14 @@ Instructions:
325
  - Opposition war-gaming: if giving strategy, include what the other side will argue
326
  {radar_instruction}"""
327
 
328
- response = _client.chat.completions.create(
329
- model="llama-3.3-70b-versatile",
330
  messages=[
331
  {"role": "system", "content": system_prompt},
332
  {"role": "user", "content": user_content}
333
- ],
334
- temperature=0.3,
335
- max_tokens=1500
336
  )
337
- return response.choices[0].message.content
 
338
 
339
 
340
  # ── Main entry point ──────────────────────────────────────
@@ -346,7 +397,11 @@ def run_query_v2(user_message: str, session_id: str) -> Dict[str, Any]:
346
  try:
347
  analysis = analyse(user_message, session)
348
  except Exception as e:
349
- logger.error(f"Pass 1 failed: {e}")
 
 
 
 
350
  analysis = {
351
  "tone": "casual", "format_requested": "none",
352
  "subject": "legal query", "action_needed": "advice",
@@ -404,7 +459,11 @@ def run_query_v2(user_message: str, session_id: str) -> Dict[str, Any]:
404
  try:
405
  answer = respond(user_message, analysis, chunks, session)
406
  except Exception as e:
407
- logger.error(f"Pass 3 failed: {e}")
 
 
 
 
408
  if chunks:
409
  fallback = "\n\n".join(
410
  f"[{c.get('title', 'Source')}]\n{c.get('text', '')[:400]}"
 
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
 
32
  from dotenv import load_dotenv
33
+ import threading
34
+ import time
35
+ from src.llm import call_llm_raw
36
 
37
  load_dotenv()
38
+
39
# ── Circuit Breaker for Groq API ──────────────────────────
class CircuitBreaker:
    """Circuit breaker that stops hammering the Groq API while it is down.

    States:
      CLOSED    — normal operation; calls go through.
      OPEN      — `failure_threshold` consecutive failures seen; calls are
                  rejected until `recovery_timeout` seconds have passed.
      HALF-OPEN — timeout elapsed; trial calls are allowed. A success closes
                  the breaker, a failure re-opens it immediately.

    All state transitions are guarded by a lock so a single instance can be
    shared across request-handling threads.
    """

    def __init__(self, failure_threshold=5, recovery_timeout=60):
        self.failure_count = 0               # consecutive failures recorded
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout  # seconds before a trial is allowed
        self.last_failure_time = None        # time.time() of most recent failure
        self.is_open = False
        self.half_open = False               # True while trial requests are in flight
        self.lock = threading.Lock()

    def record_success(self):
        """Reset the breaker to CLOSED after a successful API call."""
        with self.lock:
            self.failure_count = 0
            self.is_open = False
            self.half_open = False

    def record_failure(self):
        """Count a failure; trip OPEN at the threshold or on a failed trial.

        BUGFIX: a failure while HALF-OPEN re-opens the breaker immediately.
        Previously `can_attempt` zeroed `failure_count` on recovery, so a
        still-down API got `failure_threshold` fresh failing calls per cycle.
        """
        with self.lock:
            self.failure_count += 1
            self.last_failure_time = time.time()
            if self.half_open or self.failure_count >= self.failure_threshold:
                self.is_open = True
                self.half_open = False
                logger.warning(f"Circuit breaker OPEN: {self.failure_count} failures detected")

    def can_attempt(self) -> bool:
        """Return True if a call may be attempted right now.

        When OPEN and the recovery timeout has elapsed, move to HALF-OPEN and
        allow a trial request WITHOUT resetting the failure count.
        """
        with self.lock:
            if not self.is_open:
                return True
            if time.time() - self.last_failure_time > self.recovery_timeout:
                logger.info("Circuit breaker attempting recovery...")
                self.is_open = False
                self.half_open = True
                return True
            return False

    def get_status(self) -> str:
        """Human-readable state string (surfaced by the /health endpoint)."""
        with self.lock:
            if self.is_open:
                return f"OPEN ({self.failure_count} failures)"
            return f"CLOSED ({self.failure_count} failures)"


_circuit_breaker = CircuitBreaker()
82
 
83
  # ── Session store ─────────────────────────────────────────
84
  sessions: Dict[str, Dict] = {}
 
161
 
162
 
163
  # ── Pass 1: Analyse ───────────────────────────────────────
164
+ # Retry up to 5 times with exponential backoff (1s to 16s) to handle transient failures
165
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=16, multiplier=1.5))
166
  def analyse(user_message: str, session: Dict) -> Dict:
167
+ if not _circuit_breaker.can_attempt():
168
+ logger.error(f"Circuit breaker OPEN - skipping Pass 1. Status: {_circuit_breaker.get_status()}")
169
+ raise Exception("Groq API circuit breaker is open - service unavailable")
170
+
171
  summary = session.get("summary", "")
172
  last_msgs = session.get("last_3_messages", [])
173
  cs = session["case_state"]
 
215
  - Update hypothesis confidence based on new evidence
216
  - search_queries must be specific legal questions for vector search"""
217
 
218
+ response = call_llm_raw(
 
219
  messages=[
220
  {"role": "system", "content": ANALYSIS_PROMPT},
221
  {"role": "user", "content": user_content}
222
+ ]
 
 
223
  )
224
+ _circuit_breaker.record_success() # API call succeeded
225
+ raw = response.strip()
226
  raw = raw.replace("```json", "").replace("```", "").strip()
227
 
228
  try:
 
277
 
278
 
279
  # ── Pass 3: Respond ───────────────────────────────────────
280
+ # Retry up to 5 times with exponential backoff (2s to 32s) — more aggressive than Pass 1
281
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(min=2, max=32, multiplier=1.5))
282
  def respond(user_message: str, analysis: Dict, chunks: List[Dict], session: Dict) -> str:
283
+ if not _circuit_breaker.can_attempt():
284
+ logger.error(f"Circuit breaker OPEN - skipping Pass 3. Status: {_circuit_breaker.get_status()}")
285
+ raise Exception("Groq API circuit breaker is open - service unavailable")
286
+
287
  system_prompt = build_prompt(analysis)
288
  cs = session["case_state"]
289
  turn_count = cs.get("turn_count", 0)
 
378
  - Opposition war-gaming: if giving strategy, include what the other side will argue
379
  {radar_instruction}"""
380
 
381
+ response = call_llm_raw(
 
382
  messages=[
383
  {"role": "system", "content": system_prompt},
384
  {"role": "user", "content": user_content}
385
+ ]
 
 
386
  )
387
+ _circuit_breaker.record_success() # API call succeeded
388
+ return response
389
 
390
 
391
  # ── Main entry point ──────────────────────────────────────
 
397
  try:
398
  analysis = analyse(user_message, session)
399
  except Exception as e:
400
+ error_type = type(e).__name__
401
+ logger.error(f"Pass 1 failed after retries: {error_type}: {e}. Circuit breaker: {_circuit_breaker.get_status()}")
402
+ # Record API failure if it was a connection error
403
+ if "APIConnectionError" in error_type or "RateLimitError" in error_type:
404
+ _circuit_breaker.record_failure()
405
  analysis = {
406
  "tone": "casual", "format_requested": "none",
407
  "subject": "legal query", "action_needed": "advice",
 
459
  try:
460
  answer = respond(user_message, analysis, chunks, session)
461
  except Exception as e:
462
+ error_type = type(e).__name__
463
+ logger.error(f"Pass 3 failed after retries: {error_type}: {e}. Circuit breaker: {_circuit_breaker.get_status()}")
464
+ # Record API failure if it was a connection error
465
+ if "APIConnectionError" in error_type or "RateLimitError" in error_type:
466
+ _circuit_breaker.record_failure()
467
  if chunks:
468
  fallback = "\n\n".join(
469
  f"[{c.get('title', 'Source')}]\n{c.get('text', '')[:400]}"
src/llm.py CHANGED
@@ -1,77 +1,126 @@
1
  """
2
- LLM module. Single Groq API call with tenacity retry.
 
 
3
 
4
- WHY Groq? Free tier, fastest inference (~500 tokens/sec).
5
- WHY temperature=0.1? Lower = more deterministic, less hallucination.
6
- WHY one call per query? Multi-step chains add latency and failure points.
 
7
  """
8
 
9
  import os
10
  import logging
11
- from groq import Groq
12
- from tenacity import retry, stop_after_attempt, wait_exponential
13
  from dotenv import load_dotenv
 
14
 
15
  load_dotenv()
16
-
17
  logger = logging.getLogger(__name__)
18
 
19
- api_key = os.getenv("GROQ_API_KEY")
20
- logger.info(f"GROQ_API_KEY loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- _client = Groq(
23
- api_key=api_key
24
- )
25
- logger.info("Groq client initialized successfully")
 
 
 
 
 
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def call_llm_raw(messages: list) -> str:
29
  """
30
- Call Groq with pre-built messages list.
31
  Used by V2 agent for Pass 1 and Pass 3.
32
  """
33
- try:
34
- response = _client.chat.completions.create(
35
- model="llama-3.3-70b-versatile",
36
- messages=messages,
37
- temperature=0.3,
38
- max_tokens=1500
39
- )
40
- return response.choices[0].message.content
41
- except Exception as e:
42
- logger.error(f"Groq API error in call_llm_raw: {type(e).__name__}: {str(e)}", exc_info=True)
43
- raise
44
 
45
 
46
- @retry(
47
- stop=stop_after_attempt(3),
48
- wait=wait_exponential(multiplier=1, min=2, max=8)
49
- )
50
  def call_llm(query: str, context: str) -> str:
51
  """
52
- Call Groq Llama-3. Used by V1 agent.
53
- Retries 3 times with exponential backoff.
54
  """
55
- try:
56
- user_message = f"""QUESTION: {query}
57
-
58
- SUPREME COURT JUDGMENT EXCERPTS:
59
- {context}
60
-
61
- Answer based only on the excerpts above. Cite judgment IDs.
62
- Use proper markdown formatting."""
63
-
64
- response = _client.chat.completions.create(
65
- model="llama-3.3-70b-versatile",
66
- messages=[
67
- {"role": "system", "content": "You are NyayaSetu, an Indian legal research assistant. Answer only from provided excerpts. Cite judgment IDs. End with: NOTE: This is not legal advice."},
68
- {"role": "user", "content": user_message}
69
- ],
70
- temperature=0.1,
71
- max_tokens=1500
72
- )
73
-
74
- return response.choices[0].message.content
75
- except Exception as e:
76
- logger.error(f"Groq API error in call_llm: {type(e).__name__}: {str(e)}", exc_info=True)
77
- raise
 
1
  """
2
+ LLM module. HuggingFace Inference API as primary.
3
+ Works natively from HF Spaces — same infrastructure.
4
+ Groq as local dev fallback.
5
 
6
+ WHY HF Inference API?
7
+ HF Spaces can always reach HuggingFace's own APIs.
8
+ No network routing issues. Uses existing HF_TOKEN.
9
+ Same Llama 3.3 70B model as Groq.
10
  """
11
 
12
  import os
13
  import logging
 
 
14
  from dotenv import load_dotenv
15
+ from tenacity import retry, stop_after_attempt, wait_exponential
16
 
17
  load_dotenv()
 
18
  logger = logging.getLogger(__name__)
19
 
20
# ── HuggingFace Inference API ─────────────────────────────
# Module-level client; stays None when HF_TOKEN is missing or init fails.
_hf_client = None

def _init_hf():
    """Initialise the HF Inference client. Returns True on success."""
    global _hf_client

    token = os.getenv("HF_TOKEN")
    if not token:
        logger.warning("HF_TOKEN not set — HF Inference API disabled")
        return False

    try:
        from huggingface_hub import InferenceClient
        _hf_client = InferenceClient(
            model="meta-llama/Llama-3.3-70B-Instruct",
            token=token,
        )
    except Exception as e:
        # Broad catch is deliberate: any init failure just disables this provider.
        logger.error(f"HF Inference API init failed: {e}")
        return False

    logger.info("HF Inference API ready (Llama-3.3-70B)")
    return True
40
+
41
# ── Groq fallback (works locally, may be blocked on HF Spaces) ──
# Module-level client; stays None when GROQ_API_KEY is missing or init fails.
_groq_client = None

def _init_groq():
    """Initialise Groq as the fallback provider. Returns True on success."""
    global _groq_client

    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        return False

    try:
        from groq import Groq
        _groq_client = Groq(api_key=api_key)
    except Exception as e:
        # Broad catch is deliberate: any init failure just disables this provider.
        logger.error(f"Groq init failed: {e}")
        return False

    logger.info("Groq ready as fallback")
    return True
57
+
58
+ _hf_ready = _init_hf()
59
+ _groq_ready = _init_groq()
60
+
61
 
62
def _call_hf(messages: list) -> str:
    """Send chat messages to the HF Inference API and return the reply text.

    `messages` is already an OpenAI-style chat list, which
    `InferenceClient.chat_completion` accepts directly — no conversion needed.
    """
    completion = _hf_client.chat_completion(
        messages=messages,
        max_tokens=1500,
        temperature=0.3,
    )
    return completion.choices[0].message.content
71
 
72
 
73
def _call_groq(messages: list) -> str:
    """Send chat messages to Groq (fallback provider) and return the reply text."""
    completion = _groq_client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=messages,
        temperature=0.3,
        max_tokens=1500,
    )
    return completion.choices[0].message.content
82
+
83
+
84
def _call_with_fallback(messages: list) -> str:
    """Try the HF Inference API first, then Groq; raise if both fail.

    Args:
        messages: OpenAI-style chat message list.

    Returns:
        The assistant reply text from whichever provider succeeded.

    Raises:
        RuntimeError: when no provider is configured or every provider
            errored — chained from the last underlying exception so the
            root cause survives in tracebacks. RuntimeError subclasses
            Exception, so existing `except Exception` callers still work.
    """
    last_error = None

    if _hf_ready and _hf_client:
        try:
            return _call_hf(messages)
        except Exception as e:
            last_error = e
            logger.warning(f"HF Inference failed: {e}, trying Groq")

    if _groq_ready and _groq_client:
        try:
            return _call_groq(messages)
        except Exception as e:
            last_error = e
            logger.error(f"Groq also failed: {e}")

    # BUGFIX: raise a specific exception type and preserve the causal chain
    # instead of a bare `Exception(...)` that discards the underlying error.
    raise RuntimeError("All LLM providers failed") from last_error
99
+
100
+
101
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=8))
def call_llm_raw(messages: list) -> str:
    """
    Call the LLM with a pre-built messages list, retrying up to 3 times
    with exponential backoff (2s–8s). Used by V2 agent for Pass 1 and Pass 3.
    """
    # Provider selection (HF first, Groq fallback) lives in _call_with_fallback.
    return _call_with_fallback(messages)
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=8))
def call_llm(query: str, context: str) -> str:
    """
    Answer `query` from the supplied judgment excerpts (`context`).
    Used by the V1 agent; retries with exponential backoff on provider errors.
    """
    system_msg = {
        "role": "system",
        "content": "You are NyayaSetu, an Indian legal research assistant. Answer only from provided excerpts. Cite judgment IDs. End with: NOTE: This is not legal advice.",
    }
    user_msg = {
        "role": "user",
        "content": f"QUESTION: {query}\n\nSOURCES:\n{context}\n\nAnswer based on sources. Cite judgment IDs.",
    }
    return _call_with_fallback([system_msg, user_msg])