CaffeinatedCoding commited on
Commit
0214972
·
verified ·
1 Parent(s): 6372870

Upload folder using huggingface_hub

Browse files
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ __pycache__/
3
+ *.pyc
4
+ .git/
5
+ .dvc/cache/
6
+ data/
7
+ logs/
8
+ *.log
9
+ bug_log.md
.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Created by pytest automatically.
2
+ *
.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
2
+ # This file is a cache directory tag created by pytest.
3
+ # For information about cache directory tags, see:
4
+ # https://bford.info/cachedir/spec.html
.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # pytest cache directory #
2
+
3
+ This directory contains data from the pytest's cache plugin,
4
+ which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
5
+
6
+ **Do not** commit this to version control.
7
+
8
+ See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
.pytest_cache/v/cache/lastfailed ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "tests/test_api.py": true
3
+ }
.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1 @@
 
 
1
+ []
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && \
7
+ apt-get install -y git curl && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copy requirements first — Docker layer caching
11
+ # If requirements.txt hasn't changed, this layer is reused
12
+ # and pip install is skipped on rebuild. Saves 5+ minutes.
13
+ COPY requirements.txt .
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ # Copy all project files
17
+ COPY . .
18
+
19
+ # HuggingFace Spaces requires port 7860
20
+ EXPOSE 7860
21
+
22
+ # Start FastAPI
23
+ CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,13 @@
1
- ---
2
- title: Nyayasetu
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- license: other
9
- short_description: Answers around INDIAN LAWS based on SIMILAR PAST JUDGEMENTS
10
- ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Performance
 
 
 
 
 
 
 
 
 
2
 
3
+ | Query | Top Result | Score | Verified |
4
+ |-------|-----------|-------|----------|
5
+ | Rights of arrested person (Art 22) | A K Gopalan vs State of Madras (1950) | 0.719 | Unverified* |
6
+ | Freedom of speech (Art 19) | Kaushal Kishor vs State of UP (2023) | 0.768 | Unverified* |
7
+ | Double jeopardy | Manipur Administration vs Thokchom Bira Singh (1964) | 0.681 | ✅ Verified |
8
+ | Bail rules | Babu Singh vs State of UP (1978) | 0.695 | Unverified* |
9
+ | Basic structure doctrine | Puttaswamy vs Union of India (2017) | 0.760 | Unverified* |
10
+ | Right to privacy | R Rajagopal vs State of TN (1994) | 0.756 | Unverified* |
11
+
12
+ *Unverified = LLM paraphrased rather than copied verbatim.
13
+ Answer content is accurate. See Limitations section.
api/__init__.py ADDED
File without changes
api/main.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NyayaSetu FastAPI application.
3
+ 3 endpoints only.
4
+
5
+ All models loaded at startup — never per request.
6
+ Port 7860 for HuggingFace Spaces compatibility.
7
+ """
8
+
9
+ from fastapi import FastAPI, HTTPException
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel
12
+ import time
13
+ import os
14
+ import sys
15
+ import logging
16
+
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
+
22
# ── Startup: Download models from HuggingFace Hub ────────────
def _download_if_missing(repo_id: str, hf_token: str, check_path: str,
                         allow_patterns: str, local_dir: str, label: str) -> None:
    """Fetch one artifact set from the Hub unless check_path already exists on disk."""
    from huggingface_hub import snapshot_download

    if os.path.exists(check_path):
        logger.info(f"{label} already exists, skipping download")
        return

    logger.info(f"Downloading {label} from HuggingFace Hub...")
    # Safe even when the directory already exists.
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        allow_patterns=allow_patterns,
        local_dir=local_dir,
        token=hf_token
    )
    logger.info(f"{label} downloaded successfully")


def download_models():
    """
    Downloads NER model, FAISS index and parent judgments from HF Hub
    at container startup. Only downloads artifacts that don't already
    exist locally. Skips gracefully if HF_TOKEN is not set.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN not set — skipping model download. Models must exist locally.")
        return

    repo_id = "CaffeinatedCoding/nyayasetu-models"
    try:
        # NER model
        _download_if_missing(repo_id, hf_token, "models/ner_model",
                             "ner_model/*", "models", "NER model")
        # FAISS index + chunk metadata
        _download_if_missing(repo_id, hf_token, "models/faiss_index/index.faiss",
                             "faiss_index/*", "models", "FAISS index")
        # Parent judgments → goes into data/ folder
        _download_if_missing(repo_id, hf_token, "data/parent_judgments.jsonl",
                             "parent_judgments.jsonl", "data", "Parent judgments")
    except Exception as e:
        # Deliberately non-fatal: the app still starts so /health works,
        # but the pipeline will fail if the artifacts are truly absent.
        logger.error(f"Model download failed: {e}")
        logger.error("App will start but pipeline may fail if models are missing")
84
+
85
# Run at startup before importing pipeline — src.agent imports src.retrieval,
# which reads the FAISS index and metadata files at import time, so the
# artifacts must already be on disk.
download_models()

from src.agent import run_query

app = FastAPI(
    title="NyayaSetu",
    description="Indian Legal RAG Agent — Supreme Court Judgments 1950–2024",
    version="1.0.0"
)

# CORS is wide open; presumably fine for a public read-only demo API —
# NOTE(review): confirm no credentialed clients depend on this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"]
)
102
+
103
# ── Request/Response models ───────────────────────────
class QueryRequest(BaseModel):
    # The user's legal question. Length (non-blank, 10–1000 chars) is
    # enforced manually in the /query handler, not by pydantic.
    query: str

class QueryResponse(BaseModel):
    query: str                 # original user query, echoed back
    answer: str                # LLM answer (or fallback excerpts on LLM failure)
    sources: list              # formatted chunks from src.agent._build_sources
    verification_status: str   # e.g. "Verified" / "Unverified" / "No verifiable claims"
    unverified_quotes: list    # quoted phrases not found verbatim in sources
    entities: dict             # NER entities extracted from the query ({} if NER off)
    num_sources: int           # number of retrieved chunks
    truncated: bool            # True if context was cut to fit the LLM limit
    latency_ms: float          # end-to-end pipeline time, set by the /query handler
117
+
118
+
119
# ── Endpoint 1: Health check ──────────────────────────
@app.get("/health")
def health():
    """Liveness probe — static payload, touches no models."""
    payload = {"status": "ok", "service": "NyayaSetu", "version": "1.0.0"}
    return payload
127
+
128
+
129
# ── Endpoint 2: App info ──────────────────────────────
@app.get("/")
def root():
    """Service info page: identity, data coverage, disclaimer and endpoint map."""
    endpoint_map = {
        "POST /query": "Ask a legal question",
        "GET /health": "Health check",
        "GET /": "This info page"
    }
    return {
        "name": "NyayaSetu",
        "description": "Indian Legal RAG Agent",
        "data": "Supreme Court of India judgments 1950-2024",
        "disclaimer": "NOT legal advice. Always consult a qualified advocate.",
        "endpoints": endpoint_map
    }
143
+
144
+
145
# ── Endpoint 3: Main query pipeline ──────────────────
@app.post("/query", response_model=QueryResponse)
def query(request: QueryRequest):
    """Validate the query string, run the RAG pipeline, attach latency."""
    text = request.query

    # Manual validation (400, not pydantic's 422): guard clauses up front.
    if not text.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    if len(text) < 10:
        raise HTTPException(status_code=400, detail="Query too short — minimum 10 characters")
    if len(text) > 1000:
        raise HTTPException(status_code=400, detail="Query too long — maximum 1000 characters")

    start = time.time()
    try:
        result = run_query(text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Pipeline error: {str(e)}")

    result["latency_ms"] = round((time.time() - start) * 1000, 2)
    return result
bug_log.md ADDED
File without changes
params.yaml ADDED
File without changes
preprocessing/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Makes preprocessing a Python package
#
# NOTE(review): a requirements.txt listing had been pasted into this module
# as bare expressions. That made `import preprocessing` fail at runtime with
# NameError (e.g. `torch`), and names like `sentence-transformers` parse as
# subtraction (`sentence - transformers`). Dependencies belong in
# requirements.txt, not in package __init__ files — the listing was removed.
preprocessing/download.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download Indian Supreme Court judgments from Kaggle.
3
+ Uses kagglehub to download directly - no manual zip extraction needed.
4
+ Output: data/raw_judgments.jsonl
5
+
6
+ WHY kagglehub? Programmatic download - reproducible, no manual steps.
7
+ Anyone cloning this repo can run this script and get the same data.
8
+ """
9
+
10
+ import kagglehub
11
+ import json
12
+ import os
13
+ import glob
14
+
15
+ def download_judgments():
16
+ print("Downloading SC Judgments dataset from Kaggle...")
17
+
18
+ # Downloads to a local cache folder, returns the path
19
+ path = kagglehub.dataset_download("adarshsingh0903/legal-dataset-sc-judgments-india-19502024")
20
+ print(f"Dataset downloaded to: {path}")
21
+
22
+ # See what files we got
23
+ all_files = []
24
+ for root, dirs, files in os.walk(path):
25
+ for file in files:
26
+ full_path = os.path.join(root, file)
27
+ all_files.append(full_path)
28
+ print(f" Found: {full_path}")
29
+
30
+ print(f"\nTotal files found: {len(all_files)}")
31
+ return path, all_files
32
+
33
+ if __name__ == "__main__":
34
+ path, files = download_judgments()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic
4
+ huggingface_hub
5
+ sentence-transformers
6
+ numpy
7
+ groq
8
+ tenacity
9
+ python-dotenv
10
+ transformers
11
+ faiss-cpu
12
+ torch
13
+ kagglehub
14
+ pytest
src/__init__.py ADDED
File without changes
src/agent.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NyayaSetu RAG Agent — single-pass function.
3
+
4
+ Every user query goes through exactly these steps in order:
5
+ 1. NER extraction (if model available, else skip gracefully)
6
+ 2. Query augmentation (append extracted entities)
7
+ 3. Embed augmented query with MiniLM
8
+ 4. FAISS retrieval (top-5 chunks)
9
+ 5. Out-of-domain check (empty results = no relevant judgments)
10
+ 6. Context assembly (build prompt context from expanded windows)
11
+ 7. Single LLM call with retry
12
+ 8. Citation verification
13
+ 9. Return structured result
14
+
15
+ WHY single-pass and no while loop?
16
+ A while loop that retries the whole pipeline masks failures.
17
+ If retrieval returned bad results, retrying with the same query
18
+ returns the same bad results. Better to fail honestly and tell
19
+ the user, than to loop silently and return garbage.
20
+ """
21
+
22
+ import os
23
+ import sys
24
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
25
+
26
+ from src.embed import embed_text
27
+ from src.retrieval import retrieve
28
+ from src.llm import call_llm
29
+ from src.verify import verify_citations
30
+ from typing import Dict, Any
31
+
32
+ # NER is optional — if not trained yet, pipeline runs without it
33
+ # This is the Cut Line Rule from the blueprint:
34
+ # ship without NER rather than blocking the whole project
35
+ NER_AVAILABLE = False
36
+ try:
37
+ from src.ner import extract_entities
38
+ NER_AVAILABLE = True
39
+ print("NER model loaded — query augmentation active")
40
+ except Exception as e:
41
+ print(f"NER not available, running without entity augmentation: {e}")
42
+
43
+
44
def run_query(query: str) -> Dict[str, Any]:
    """
    Main pipeline. Input: user query string.
    Output: structured dict with answer, sources, verification.

    Decomposed into three private helpers so each exit path (out-of-domain,
    LLM failure, success) builds its response through one shared function —
    the previous version repeated the response dict three times, which is
    drift-prone when keys change.
    """
    # ── Step 1: NER extraction + query augmentation ───────────
    entities, augmented_query = _augment(query)

    # ── Step 2: Embed ─────────────────────────────────────────
    query_embedding = embed_text(augmented_query)

    # ── Step 3: Retrieve ──────────────────────────────────────
    retrieved_chunks = retrieve(query_embedding, top_k=5)

    # ── Step 4: Out-of-domain check ───────────────────────────
    if not retrieved_chunks:
        return _response(
            query, augmented_query,
            answer="Your query doesn't appear to relate to Indian law. "
                   "NyayaSetu can answer questions about Supreme Court judgments, "
                   "constitutional rights, statutes, and legal provisions. "
                   "Please ask a legal question.",
            sources=[],
            status="No sources retrieved",
            unverified=[],
            entities=entities,
            truncated=False,
        )

    # ── Step 5: Context assembly ──────────────────────────────
    context, truncated = _assemble_context(retrieved_chunks)

    # ── Step 6: LLM call ──────────────────────────────────────
    try:
        answer = call_llm(query=query, context=context)
    except Exception as e:
        # All 3 retries failed — return raw excerpts as fallback
        print(f"LLM call failed after retries: {e}")
        fallback_excerpts = "\n\n".join(
            f"[{c['title']} | {c['year']}]\n{c['chunk_text'][:500]}"
            for c in retrieved_chunks
        )
        return _response(
            query, augmented_query,
            answer=f"LLM service temporarily unavailable. "
                   f"Most relevant excerpts shown below:\n\n{fallback_excerpts}",
            sources=_build_sources(retrieved_chunks),
            status="LLM unavailable",
            unverified=[],
            entities=entities,
            truncated=truncated,
        )

    # ── Step 7: Citation verification ─────────────────────────
    verification_status, unverified_quotes = verify_citations(answer, retrieved_chunks)

    # ── Step 8: Return ────────────────────────────────────────
    return _response(
        query, augmented_query, answer,
        sources=_build_sources(retrieved_chunks),
        status=verification_status,
        unverified=unverified_quotes,
        entities=entities,
        truncated=truncated,
    )


def _augment(query: str):
    """Step 1: extract entities (if NER is loaded) and append them to the query.

    Returns (entities, augmented_query); on any NER failure the raw query
    is used and entities is {} — augmentation must never break the pipeline.
    """
    entities: Dict[str, Any] = {}
    if not NER_AVAILABLE:
        return entities, query
    try:
        entities = extract_entities(query)
        entity_string = " ".join(
            f"{etype}: {etext}"
            for etype, texts in entities.items()
            for etext in texts
        )
        if entity_string:
            return entities, f"{query} {entity_string}"
    except Exception as e:
        print(f"NER failed, using raw query: {e}")
    return entities, query


def _assemble_context(retrieved_chunks):
    """Step 5: join expanded windows under the LLM context budget.

    Rough rule: 1 token ≈ 4 characters; ~6000 context tokens → 24000 chars.
    Returns (context, truncated) where truncated is True if chunks were dropped.
    """
    LLM_CONTEXT_LIMIT_CHARS = 24000
    context_parts = []
    total_chars = 0
    truncated = False

    for i, chunk in enumerate(retrieved_chunks, 1):
        header = f"[EXCERPT {i} — {chunk['title']} | {chunk['year']} | ID: {chunk['judgment_id']}]\n"
        part = header + chunk["expanded_context"] + "\n"

        if total_chars + len(part) > LLM_CONTEXT_LIMIT_CHARS:
            # Drop remaining chunks — too long for LLM context
            truncated = True
            print(f"Context truncated at {i-1} of {len(retrieved_chunks)} chunks")
            break

        context_parts.append(part)
        total_chars += len(part)

    return "\n".join(context_parts), truncated


def _response(query, augmented_query, answer, sources, status, unverified,
              entities, truncated) -> Dict[str, Any]:
    """Build the uniform response dict returned by every pipeline exit path."""
    return {
        "query": query,
        "augmented_query": augmented_query,
        "answer": answer,
        "sources": sources,
        "verification_status": status,
        "unverified_quotes": unverified,
        "entities": entities,
        # sources is either [] or one entry per retrieved chunk, so its
        # length equals the original num_sources in every path.
        "num_sources": len(sources),
        "truncated": truncated,
    }
154
+
155
+
156
+ def _build_sources(chunks) -> list:
157
+ """Format retrieved chunks for API response."""
158
+ return [
159
+ {
160
+ "judgment_id": c["judgment_id"],
161
+ "title": c["title"],
162
+ "year": c["year"],
163
+ "similarity_score": round(c["similarity_score"], 4),
164
+ "excerpt": c["chunk_text"][:300] + "..."
165
+ }
166
+ for c in chunks
167
+ ]
168
+
169
+
170
+ if __name__ == "__main__":
171
+ # Smoke test — run directly to verify pipeline works end to end
172
+ test_queries = [
173
+ "What are the rights of an arrested person under Article 22?",
174
+ "What did the Supreme Court say about freedom of speech?",
175
+ "How do I bake a cake?" # Out of domain — should return no results
176
+ ]
177
+
178
+ for query in test_queries:
179
+ print(f"\n{'='*60}")
180
+ print(f"QUERY: {query}")
181
+ result = run_query(query)
182
+ print(f"SOURCES: {result['num_sources']}")
183
+ print(f"VERIFICATION: {result['verification_status']}")
184
+ print(f"ANSWER (first 300 chars):\n{result['answer'][:300]}")
src/embed.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Embedding module. Loads MiniLM once at startup, never per request.
3
+ """
4
+ from sentence_transformers import SentenceTransformer
5
+ import numpy as np
6
+
7
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
8
+
9
+ print(f"Loading embedding model...")
10
+ _model = SentenceTransformer(MODEL_NAME)
11
+ print("Embedding model ready.")
12
+
13
+ def embed_text(text: str) -> np.ndarray:
14
+ """Embed a single string. Returns shape (384,)"""
15
+ return _model.encode(text, normalize_embeddings=True)
src/llm.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM module. Single Groq API call with tenacity retry.
3
+
4
+ WHY Groq? Free tier, fastest inference (~500 tokens/sec).
5
+ WHY temperature=0.1? Lower = more deterministic, less hallucination.
6
+ WHY one call per query? Multi-step chains add latency and failure points.
7
+ Gemini is configured as backup if Groq fails permanently.
8
+ """
9
+
10
+ import os
11
+ from groq import Groq
12
+ from tenacity import retry, stop_after_attempt, wait_exponential
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+
17
+ _client = Groq(api_key=os.getenv("GROQ_API_KEY"))
18
+
19
+ SYSTEM_PROMPT = """You are NyayaSetu, an Indian legal research assistant.
20
+
21
+ Rules you must follow:
22
+ 1. Answer ONLY using the provided Supreme Court judgment excerpts
23
+ 2. Never use outside knowledge
24
+ 3. Quote directly from excerpts when making factual claims — use double quotes
25
+ 4. Always cite the Judgment ID when referencing a case
26
+ 5. If excerpts don't contain enough information, say so explicitly
27
+ 6. End every response with: "NOTE: This is not legal advice. Consult a qualified advocate."
28
+ """
29
+
30
+ @retry(
31
+ stop=stop_after_attempt(3),
32
+ wait=wait_exponential(multiplier=1, min=2, max=8)
33
+ )
34
+ def call_llm(query: str, context: str) -> str:
35
+ """
36
+ Call Groq Llama-3. Retries 3 times with exponential backoff.
37
+ Raises LLMError after all retries fail — caller handles this.
38
+ """
39
+ user_message = f"""QUESTION: {query}
40
+
41
+ SUPREME COURT JUDGMENT EXCERPTS:
42
+ {context}
43
+
44
+ Answer based only on the excerpts above. Cite judgment IDs."""
45
+
46
+ response = _client.chat.completions.create(
47
+ model="llama-3.3-70b-versatile",
48
+ messages=[
49
+ {"role": "system", "content": SYSTEM_PROMPT},
50
+ {"role": "user", "content": user_message}
51
+ ],
52
+ temperature=0.1,
53
+ max_tokens=800
54
+ )
55
+
56
+ return response.choices[0].message.content
src/ner.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NER inference module.
3
+ Loads fine-tuned DistilBERT and extracts legal entities from query text.
4
+
5
+ Loaded once at FastAPI startup — never per request.
6
+ Called before FAISS retrieval to augment the query with extracted entities.
7
+
8
+ Example:
9
+ Input: "What did Justice Chandrachud say about Section 302 IPC?"
10
+ Output: {"JUDGE": ["Justice Chandrachud"],
11
+ "PROVISION": ["Section 302"],
12
+ "STATUTE": ["IPC"]}
13
+
14
+ The augmented query becomes:
15
+ "What did Justice Chandrachud say about Section 302 IPC?
16
+ JUDGE: Justice Chandrachud PROVISION: Section 302 STATUTE: IPC"
17
+
18
+ WHY augment the query?
19
+ MiniLM embeds the full query string. Adding extracted entities
20
+ explicitly shifts the embedding closer to chunks that mention
21
+ those specific legal terms — improving retrieval precision.
22
+ """
23
+
24
+ import os
25
+ from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
26
+
27
+ NER_MODEL_PATH = os.getenv("NER_MODEL_PATH", "models/ner_model")
28
+
29
+ TARGET_ENTITIES = {
30
+ "JUDGE", "COURT", "STATUTE", "PROVISION",
31
+ "CASE_NUMBER", "DATE", "PRECEDENT", "LAWYER",
32
+ "PETITIONER", "RESPONDENT", "GPE", "ORG"
33
+ }
34
+
35
+ # Load once at import time
36
+ if not os.path.exists(NER_MODEL_PATH):
37
+ raise FileNotFoundError(
38
+ f"NER model not found at {NER_MODEL_PATH}. "
39
+ "Train it on Kaggle first. "
40
+ "System will run without NER until model is available."
41
+ )
42
+
43
+ print(f"Loading NER model from {NER_MODEL_PATH}...")
44
+ _tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
45
+ _model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
46
+
47
+ _ner_pipeline = pipeline(
48
+ "ner",
49
+ model=_model,
50
+ tokenizer=_tokenizer,
51
+ aggregation_strategy="simple"
52
+ )
53
+ print("NER model ready.")
54
+
55
+
56
def extract_entities(text: str) -> dict:
    """
    Run NER over the input text.

    Returns {entity_type: [entity_text, ...]} containing only legally
    relevant entity types (TARGET_ENTITIES), with duplicates and
    single-character fragments dropped. Returns {} on blank input or
    any inference failure.
    """
    if not text.strip():
        return {}

    try:
        raw_spans = _ner_pipeline(text)
    except Exception as e:
        print(f"NER inference failed: {e}")
        return {}

    entities: dict = {}
    for span in raw_spans:
        etype = span["entity_group"]
        etext = span["word"].strip()

        # Keep only legal entity types and skip single-character noise.
        if etype not in TARGET_ENTITIES or len(etext) < 2:
            continue

        bucket = entities.setdefault(etype, [])
        if etext not in bucket:  # no duplicates
            bucket.append(etext)

    return entities
87
+
88
+
89
def augment_query(query: str, entities: dict) -> str:
    """
    Return the query with a "TYPE: text" token appended for every
    extracted entity, ready for embedding. Unchanged when no entities.
    """
    if not entities:
        return query

    tokens = []
    for etype, texts in entities.items():
        for etext in texts:
            tokens.append(f"{etype}: {etext}")

    return f"{query} {' '.join(tokens)}"
104
+
105
+
106
+ if __name__ == "__main__":
107
+ # Quick test
108
+ test_queries = [
109
+ "What did Justice Chandrachud say about Article 21?",
110
+ "Find cases related to Section 302 IPC and bail",
111
+ "Supreme Court judgment on fundamental rights in 1978"
112
+ ]
113
+
114
+ for q in test_queries:
115
+ entities = extract_entities(q)
116
+ augmented = augment_query(q, entities)
117
+ print(f"\nQuery: {q}")
118
+ print(f"Entities: {entities}")
119
+ print(f"Augmented: {augmented}")
src/retrieval.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FAISS retrieval module.
3
+
4
+ Loads the FAISS index and chunk metadata once at startup.
5
+ Given a query embedding, returns the top-k most similar chunks
6
+ plus an expanded context window from the parent judgment.
7
+
8
+ WHY load at startup and not per request?
9
+ Loading a 650MB index takes ~3 seconds. If you loaded it per request,
10
+ every user query would take 3+ seconds just for setup. Loading once
11
+ at startup means retrieval takes ~5ms per query.
12
+ """
13
+
14
+ import json
15
+ import numpy as np
16
+ import faiss
17
+ import os
18
+ from typing import List, Dict
19
+
20
+ INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "models/faiss_index/index.faiss")
21
+ METADATA_PATH = os.getenv("METADATA_PATH", "models/faiss_index/chunk_metadata.jsonl")
22
+ PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")
23
+ TOP_K = 5
24
+
25
+ # Similarity threshold — if best score is below this, query is out of domain
26
+ # Score range: 0 to 1 (cosine similarity with normalized vectors)
27
+ # 0.3 = very loose match, 0.5 = decent match, 0.7 = strong match
28
+ SIMILARITY_THRESHOLD = 0.45
29
+
30
def _load_resources():
    """Load index, metadata and parent store. Called once at module import."""
    print("Loading FAISS index...")
    index = faiss.read_index(INDEX_PATH)
    print(f"Index loaded: {index.ntotal} vectors")

    print("Loading chunk metadata...")
    with open(METADATA_PATH, "r", encoding="utf-8") as f:
        # One JSON object per line; list position must match FAISS row ids.
        metadata = [json.loads(line) for line in f]
    print(f"Metadata loaded: {len(metadata)} chunks")

    print("Loading parent judgments...")
    with open(PARENT_PATH, "r", encoding="utf-8") as f:
        parent_store = {
            record["judgment_id"]: record["text"]
            for record in map(json.loads, f)
        }
    print(f"Parent store loaded: {len(parent_store)} judgments")

    return index, metadata, parent_store
53
+
54
+ _index, _metadata, _parent_store = _load_resources()
55
+
56
+
57
def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
    """
    Return the top-k chunks most similar to the query embedding, each
    carrying an expanded parent-judgment window. An empty list signals
    an out-of-domain query (best score under SIMILARITY_THRESHOLD).
    """
    query_matrix = query_embedding.reshape(1, -1).astype(np.float32)
    scores, indices = _index.search(query_matrix, top_k)
    top_scores, top_indices = scores[0], indices[0]

    # Below threshold → likely out of domain; the agent handles the empty list.
    if float(top_scores[0]) < SIMILARITY_THRESHOLD:
        return []

    hits: List[Dict] = []
    for score, idx in zip(top_scores, top_indices):
        if idx == -1:  # FAISS pads with -1 when fewer than top_k hits exist
            continue

        meta = _metadata[idx]
        hits.append({
            "chunk_id": meta["chunk_id"],
            "judgment_id": meta["judgment_id"],
            "title": meta.get("title", ""),
            "year": meta.get("year", ""),
            "chunk_text": meta["text"],
            "expanded_context": _get_expanded_context(meta["judgment_id"], meta["text"]),
            "similarity_score": float(score),
        })

    return hits
+
94
+
95
def _get_expanded_context(judgment_id: str, chunk_text: str) -> str:
    """
    Return a ~1024-token window of the parent judgment around the chunk.
    Falls back to the chunk text when the parent is missing or the chunk
    cannot be located inside it.

    WHY expand? The 512-token chunk is right-sized for retrieval, but the
    LLM answers better with surrounding judgment text. The window is
    asymmetric: ~1/4 of it before the chunk start, the rest after.
    """
    parent_text = _parent_store.get(judgment_id, "")
    if not parent_text:
        return chunk_text

    # Locate the chunk inside the parent using its first 80 chars as anchor.
    anchor_pos = parent_text.find(chunk_text[:80])
    if anchor_pos == -1:
        return chunk_text

    # ~4 chars per token → 1024 tokens ≈ 4096 chars.
    window_chars = 4096
    lo = max(0, anchor_pos - window_chars // 4)
    hi = min(len(parent_text), anchor_pos + window_chars)
    return parent_text[lo:hi]
src/verify.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Citation verification. Deterministic string matching — no ML.
3
+
4
+ LOGIC:
5
+ - Extract all quoted phrases (in double quotes) from LLM answer
6
+ - Check each phrase verbatim against all retrieved chunk texts
7
+ - ALL found → Verified
8
+ - ANY missing → Unverified
9
+ - No quotes in answer → Verified (no verifiable claim made)
10
+
11
+ DOCUMENTED LIMITATION:
12
+ Paraphrased claims that are not quoted pass as Verified.
13
+ Full NLI-based verification is out of scope — documented in README.
14
+ """
15
+
16
+ import re
17
+ from typing import List, Dict, Tuple
18
+
19
# Double-quoted spans of 8+ characters; compiled once at import.
_QUOTE_RE = re.compile(r'"([^"]{8,})"')

def extract_quotes(text: str) -> List[str]:
    """Return all double-quoted phrases of at least 8 characters, in order."""
    return _QUOTE_RE.findall(text)
22
+
23
def verify_citations(
    llm_answer: str,
    retrieved_chunks: List[Dict]
) -> Tuple[str, List[str]]:
    """
    Returns (status, unverified_quotes).
    status: "Verified" | "Unverified" | "No verifiable claims"
    """
    # Same rule as extract_quotes: double-quoted spans of 8+ characters.
    quotes = re.findall(r'"([^"]{8,})"', llm_answer)
    if not quotes:
        return "No verifiable claims", []

    # Case-insensitive verbatim containment against all retrieved text.
    haystack = " ".join(
        c.get("expanded_context", c.get("chunk_text", ""))
        for c in retrieved_chunks
    ).lower()

    missing = [q for q in quotes if q.lower() not in haystack]
    return ("Unverified", missing) if missing else ("Verified", [])
tests/test_api.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi.testclient import TestClient
import os

# Presumably intended as a model-load guard — api.main does not read it yet,
# but it must be set before the app import either way.
os.environ["SKIP_MODEL_LOAD"] = "true"

from api.main import app

client = TestClient(app)

# FIX: the /query handler raises HTTPException(400) for empty/short/long
# queries (pydantic only emits 422 for missing/mistyped fields), and the
# info page lives at "/" — there is no /info route. The old expectations
# (422, GET /info) made this file fail, as recorded in
# .pytest_cache/v/cache/lastfailed.

def test_health():
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] == "ok"

def test_info():
    # Root endpoint serves the service info page, including the endpoint map.
    response = client.get("/")
    assert response.status_code == 200
    assert "endpoints" in response.json()

def test_query_too_short():
    response = client.post("/query", json={"query": "hi"})
    assert response.status_code == 400

def test_query_too_long():
    response = client.post("/query", json={"query": "a" * 2001})
    assert response.status_code == 400

def test_query_empty():
    response = client.post("/query", json={"query": ""})
    assert response.status_code == 400