Sai Pranav Reddy commited on
Commit
968e24d
·
0 Parent(s):

Clean lightweight deployment

Browse files
Files changed (44) hide show
  1. .gitattributes +36 -0
  2. Dockerfile +38 -0
  3. README.md +11 -0
  4. outputs/summarization/final/model.safetensors +3 -0
  5. outputs/summarization/final/training_args.bin +3 -0
  6. requirements.txt +27 -0
  7. src/api/__pycache__/main.cpython-310.pyc +0 -0
  8. src/api/main.py +223 -0
  9. src/evaluation/__pycache__/evaluator.cpython-310.pyc +0 -0
  10. src/extraction/__pycache__/batch_processor.cpython-310.pyc +0 -0
  11. src/extraction/__pycache__/pdf_extractor.cpython-310.pyc +0 -0
  12. src/extraction/batch_processor.py +459 -0
  13. src/extraction/pdf_extractor.py +522 -0
  14. src/indexing/build_faiss_index.py +85 -0
  15. src/indexing/create_embeddings.py +232 -0
  16. src/indexing/create_sqlite_index.py +196 -0
  17. src/indexing/paragraph_indexer.py +167 -0
  18. src/pipeline.py +84 -0
  19. src/qa/__pycache__/dataset.cpython-310.pyc +0 -0
  20. src/qa/__pycache__/model.cpython-310.pyc +0 -0
  21. src/qa/dataset.py +69 -0
  22. src/qa/inference.py +206 -0
  23. src/qa/model.py +8 -0
  24. src/qa/monitor_training.py +21 -0
  25. src/qa/train.py +42 -0
  26. src/rag/__pycache__/query_engine.cpython-310.pyc +0 -0
  27. src/rag/query_engine.py +326 -0
  28. src/rag/test_retriever.py +75 -0
  29. src/segmentation/__pycache__/judgement_segmenter.cpython-310.pyc +0 -0
  30. src/segmentation/annotate_paragraphs.py +146 -0
  31. src/segmentation/check.py +41 -0
  32. src/segmentation/judgement_segmenter.py +195 -0
  33. src/summarization/__pycache__/composer.cpython-310.pyc +0 -0
  34. src/summarization/__pycache__/inference.cpython-310.pyc +0 -0
  35. src/summarization/__pycache__/model.cpython-310.pyc +0 -0
  36. src/summarization/__pycache__/ranker.cpython-310.pyc +0 -0
  37. src/summarization/__pycache__/utils.cpython-310.pyc +0 -0
  38. src/summarization/composer.py +17 -0
  39. src/summarization/dataset.py +74 -0
  40. src/summarization/inference.py +129 -0
  41. src/summarization/model.py +25 -0
  42. src/summarization/ranker.py +40 -0
  43. src/summarization/train.py +41 -0
  44. src/summarization/utils.py +5 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.db filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory to /app
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ git \
11
+ git-lfs \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Copy the requirements file into the container
15
+ COPY requirements.txt .
16
+
17
+ # Install Python dependencies
18
+ RUN pip install --no-cache-dir -r requirements.txt
19
+
20
+ # Pre-download HF Models to cache them in the Docker image (prevents downloading on every boot)
21
+ RUN python -c "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM; AutoTokenizer.from_pretrained('nsi319/legal-pegasus'); AutoModelForSeq2SeqLM.from_pretrained('nsi319/legal-pegasus');"
22
+ RUN python -c "from sentence_transformers import SentenceTransformer, CrossEncoder; SentenceTransformer('BAAI/bge-base-en-v1.5'); CrossEncoder('BAAI/bge-reranker-base');"
23
+
24
+ # Copy the rest of the application code (including .git if not ignored)
25
+ COPY . .
26
+
27
+
28
+ # Download heavy databases from Dataset into their correct folders
29
+ RUN huggingface-cli download SaiPranav09/NyayLens-Data paragraphs.db --repo-type dataset --local-dir data/processed/indexed/
30
+ RUN huggingface-cli download SaiPranav09/NyayLens-Data faiss_index.bin --repo-type dataset --local-dir data/processed/faiss/
31
+ RUN huggingface-cli download SaiPranav09/NyayLens-Data model.safetensors --repo-type dataset --local-dir outputs/summarization/final/
32
+
33
+
34
+ # Expose port 7860 (Hugging Face standard)
35
+ EXPOSE 7860
36
+
37
+ # Command to run the application using Uvicorn
38
+ CMD ["uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: NyayLens API
3
+ emoji: 📚
4
+ colorFrom: purple
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ short_description: NyayLens-Legal AI Assistant
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
outputs/summarization/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f3beb2c13e1b5773787c8bb39605511469c4f77d96d2f38cca086c14d352f37
3
+ size 437956140
outputs/summarization/final/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbda47393ea48d83251e18b1d9e470b83cb2fa30ca734d2095ff10fc6b4fd2c6
3
+ size 5368
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NyayLens — Python Backend Dependencies
2
+ # Install with: pip install -r requirements.txt
3
+
4
+ # API
5
+ fastapi==0.115.12
6
+ uvicorn==0.34.0
7
+ python-multipart==0.0.20
8
+ pydantic==2.12.5
9
+
10
+ # RAG & Retrieval
11
+ faiss-cpu==1.9.0
12
+ sentence-transformers==5.2.0
13
+ groq==1.0.0
14
+
15
+ # Summarization
16
+ transformers==4.57.3
17
+ torch==2.6.0
18
+ sentencepiece==0.2.1
19
+ huggingface-hub==0.36.0
20
+
21
+ # PDF Extraction
22
+ pdfplumber==0.11.8
23
+
24
+ # Utilities
25
+ python-dotenv==1.0.1
26
+ safetensors==0.4.5
27
+ ragas==0.2.14
src/api/__pycache__/main.cpython-310.pyc ADDED
Binary file (6.17 kB). View file
 
src/api/main.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/api/main.py
2
+ import sys
3
+ import os
4
+ import io
5
+ import time
6
+ import uuid
7
+ import atexit
8
+ import shutil
9
+ import asyncio
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Request
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import JSONResponse
16
+ from pydantic import BaseModel, field_validator
17
+
18
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
19
+
20
+ from src.rag.query_engine import QueryEngine
21
+ from src.summarization.inference import summarize
22
+
23
+ # ── Constants ──────────────────────────────────────────────────────────────
24
+ MAX_UPLOAD_MB = 10
25
+ MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024 * 1024
26
+ UPLOAD_DIR = Path("data/uploads")
27
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
28
+ SUMMARIZE_TIMEOUT_S = 180 # 3 min max for summarization on CPU
29
+
30
+ # ── App ────────────────────────────────────────────────────────────────────
31
+ app = FastAPI(
32
+ title="NyayLens API",
33
+ description="Production API for Legal Chat, Document QA, and Summarization",
34
+ version="1.0.0",
35
+ )
36
+
37
+ app.add_middleware(
38
+ CORSMiddleware,
39
+ allow_origins=[
40
+ "https://nyay-lens.vercel.app", # Production Vercel URL
41
+ "http://localhost:5173", # Local Vite dev server
42
+ "http://127.0.0.1:5173"
43
+ ],
44
+ allow_credentials=True,
45
+ allow_methods=["*"],
46
+ allow_headers=["*"],
47
+ )
48
+
49
+ # ── Startup / Shutdown ─────────────────────────────────────────────────────
50
+ async def cleanup_loop():
51
+ """Background task to remove leftover files older than 2 hours."""
52
+ while True:
53
+ now = time.time()
54
+ for f in UPLOAD_DIR.glob("*"):
55
+ if f.is_file() and (now - f.stat().st_mtime) > 7200:
56
+ try:
57
+ f.unlink()
58
+ except Exception as e:
59
+ print(f"Cleanup error: {e}")
60
+ await asyncio.sleep(3600) # Check every hour
61
+
62
+ @app.on_event("startup")
63
+ async def startup():
64
+ global query_engine
65
+ print("Initializing NyayLens Backend...")
66
+ query_engine = QueryEngine()
67
+
68
+ # Start the infinite cleanup loop
69
+ asyncio.create_task(cleanup_loop())
70
+ print("✓ Backend ready. Background cleanup active.")
71
+
72
+ @app.on_event("shutdown")
73
+ def shutdown():
74
+ """Clean up all uploaded files on server shutdown."""
75
+ if UPLOAD_DIR.exists():
76
+ shutil.rmtree(UPLOAD_DIR)
77
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
78
+ print("✓ Uploads directory cleaned on shutdown.")
79
+
80
+ # ── Schema ─────────────────────────────────────────────────────────────────
81
+ class UnifiedRequest(BaseModel):
82
+ message: str
83
+ filepath: Optional[str] = None
84
+ top_k: int = 5
85
+ chat_history: Optional[list] = []
86
+
87
+ @field_validator("message")
88
+ @classmethod
89
+ def message_not_empty(cls, v):
90
+ if not v or not v.strip():
91
+ raise ValueError("Message cannot be empty")
92
+ if len(v) > 4000:
93
+ raise ValueError("Message too long (max 4000 characters)")
94
+ return v.strip()
95
+
96
+ # ── Health ─────────────────────────────────────────────────────────────────
97
+ @app.get("/")
98
+ @app.get("/api/health")
99
+ def health():
100
+ return {
101
+ "status": "online",
102
+ "service": "NyayLens API",
103
+ "version": "1.0.0",
104
+ "models": ["Legal-BERT", "Legal-PEGASUS", "Llama-3.1-8B (Groq)"],
105
+ "index": "FAISS 298K vectors",
106
+ }
107
+
108
+ # ── Upload ─────────────────────────────────────────────────────────────────
109
+ @app.post("/api/upload")
110
+ async def upload_document(file: UploadFile = File(...)):
111
+ """
112
+ Accepts .pdf and .txt files up to 10 MB.
113
+ PDFs are extracted to plain text via pdfplumber.
114
+ Returns a server filepath for subsequent /api/chat calls.
115
+ """
116
+ import pdfplumber
117
+
118
+ # 1. Validate extension
119
+ filename = file.filename or "upload"
120
+ ext = Path(filename).suffix.lower()
121
+ if ext not in {".pdf", ".txt"}:
122
+ raise HTTPException(status_code=400, detail="Only .pdf and .txt files are supported.")
123
+
124
+ # 2. Read with size guard
125
+ raw_bytes = await file.read()
126
+ if len(raw_bytes) > MAX_UPLOAD_BYTES:
127
+ raise HTTPException(
128
+ status_code=413,
129
+ detail=f"File too large. Maximum allowed size is {MAX_UPLOAD_MB} MB."
130
+ )
131
+ if len(raw_bytes) == 0:
132
+ raise HTTPException(status_code=400, detail="Uploaded file is empty.")
133
+
134
+ # 3. Unique name to avoid collisions
135
+ uid = uuid.uuid4().hex[:8]
136
+ safe_name = f"{uid}_{Path(filename).stem}"
137
+
138
+ # 4. Extract / save
139
+ if ext == ".pdf":
140
+ text_parts = []
141
+ try:
142
+ with pdfplumber.open(io.BytesIO(raw_bytes)) as pdf:
143
+ for page in pdf.pages:
144
+ t = page.extract_text()
145
+ if t:
146
+ text_parts.append(t.strip())
147
+ except Exception as e:
148
+ raise HTTPException(status_code=400, detail=f"PDF extraction failed: {e}")
149
+
150
+ if not text_parts:
151
+ raise HTTPException(
152
+ status_code=422,
153
+ detail="PDF contains no readable text. It may be a scanned image — please use a searchable PDF."
154
+ )
155
+
156
+ out_path = UPLOAD_DIR / f"{safe_name}.txt"
157
+ out_path.write_text("\n\n".join(text_parts), encoding="utf-8")
158
+ return {"filepath": str(out_path), "filename": filename, "pages": len(text_parts), "size_kb": round(len(raw_bytes)/1024, 1)}
159
+
160
+ else:
161
+ out_path = UPLOAD_DIR / f"{safe_name}.txt"
162
+ out_path.write_bytes(raw_bytes)
163
+ return {"filepath": str(out_path), "filename": filename, "size_kb": round(len(raw_bytes)/1024, 1)}
164
+
165
+
166
+ # ── Chat ───────────────────────────────────────────────────────────────────
167
+ @app.post("/api/chat")
168
+ def chat(request: UnifiedRequest):
169
+ """
170
+ Unified intent-aware chat endpoint.
171
+ Routes to: Summarization | Document QA | Global RAG
172
+ """
173
+ message_lower = request.message.lower()
174
+
175
+ print(f"\n[BACKEND] '{request.message[:80]}' | file={os.path.basename(request.filepath) if request.filepath else 'None'}")
176
+
177
+ # Validate filepath if provided
178
+ if request.filepath:
179
+ if not os.path.exists(request.filepath):
180
+ return JSONResponse(
181
+ status_code=404,
182
+ content={"answer": "The uploaded document could not be found on the server. Please re-upload the file.", "sources": []}
183
+ )
184
+
185
+ try:
186
+ # ── Route 1: Summarization (with timeout) ──────────────────────────
187
+ if "summarize" in message_lower or "summary" in message_lower:
188
+ if not request.filepath:
189
+ return {
190
+ "answer": "Please **upload a PDF or text file** first using the 📎 button, then ask me to summarize it.",
191
+ "sources": []
192
+ }
193
+ print("[BACKEND] → Summarization pipeline")
194
+ summary_dict = summarize(request.filepath)
195
+ return {
196
+ "answer": "__STRUCTURED_SUMMARY__",
197
+ "summary": summary_dict,
198
+ "sources": [{"judgment_id": os.path.basename(request.filepath), "score": 1.0}]
199
+ }
200
+
201
+ # ── Route 2: Document QA ────────────────────────────────────────────
202
+ if request.filepath:
203
+ print("[BACKEND] → Document QA")
204
+ return query_engine.query_with_document(request.message, request.filepath, chat_history=request.chat_history)
205
+
206
+ # ── Route 3: Global RAG ─────────────────────────────────────────────
207
+ print("[BACKEND] → Global RAG")
208
+ return query_engine.query(request.message, top_k=request.top_k, chat_history=request.chat_history)
209
+
210
+ except Exception as e:
211
+ print(f"[BACKEND ERROR] {e}")
212
+ raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
213
+
214
+
215
+ # ── Cleanup old uploads (files older than 2 hours) ─────────────────────────
216
+ @app.delete("/api/upload/{filename}")
217
+ def delete_upload(filename: str):
218
+ """Explicit delete for a specific upload."""
219
+ target = UPLOAD_DIR / filename
220
+ if target.exists() and target.is_file():
221
+ target.unlink()
222
+ return {"status": "deleted"}
223
+ raise HTTPException(status_code=404, detail="File not found.")
src/evaluation/__pycache__/evaluator.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
src/extraction/__pycache__/batch_processor.cpython-310.pyc ADDED
Binary file (6.6 kB). View file
 
src/extraction/__pycache__/pdf_extractor.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
src/extraction/batch_processor.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # Batch processor for large-scale PDF extraction
3
+ # Includes progress tracking, error handling, and resumability
4
+ # """
5
+
6
+ # from pathlib import Path
7
+ # from typing import List, Dict
8
+ # import logging
9
+ # from tqdm import tqdm
10
+ # import json
11
+ # from datetime import datetime
12
+ # from pdf_extractor import LegalJudgmentExtractor
13
+
14
+ # logging.basicConfig(
15
+ # level=logging.INFO,
16
+ # format='%(asctime)s - %(levelname)s - %(message)s'
17
+ # )
18
+ # logger = logging.getLogger(__name__)
19
+
20
+
21
+ # class BatchProcessor:
22
+ # """
23
+ # Batch processor for extracting text from thousands of legal judgments
24
+ # Features: Progress tracking, resumability, comprehensive reporting
25
+ # """
26
+
27
+ # def __init__(self,
28
+ # input_dir: Path,
29
+ # output_dir: Path,
30
+ # num_workers: int = 1,
31
+ # enable_ocr: bool = True):
32
+
33
+ # self.input_dir = Path(input_dir)
34
+ # self.output_dir = Path(output_dir)
35
+ # self.num_workers = num_workers
36
+
37
+ # # Create extractor
38
+ # self.extractor = LegalJudgmentExtractor(output_dir, enable_ocr=enable_ocr)
39
+
40
+ # # Progress tracking
41
+ # self.progress_file = output_dir / "processing_progress.json"
42
+ # self.stats_file = output_dir / "processing_stats.json"
43
+
44
+ # def get_all_pdfs(self) -> List[Path]:
45
+ # """Get all PDF files from input directory"""
46
+ # return sorted(self.input_dir.rglob("*.pdf"))
47
+
48
+ # def load_progress(self) -> set:
49
+ # """Load already processed files"""
50
+ # if self.progress_file.exists():
51
+ # with open(self.progress_file, 'r') as f:
52
+ # data = json.load(f)
53
+ # return set(data.get('processed_files', []))
54
+ # return set()
55
+
56
+ # def save_progress(self, processed_files: set):
57
+ # """Save processing progress"""
58
+ # with open(self.progress_file, 'w') as f:
59
+ # json.dump({
60
+ # 'processed_files': list(processed_files),
61
+ # 'last_updated': datetime.now().isoformat(),
62
+ # 'total_processed': len(processed_files)
63
+ # }, f, indent=2)
64
+
65
+ # def process_single_pdf(self, pdf_path: Path) -> Dict:
66
+ # """Process a single PDF"""
67
+ # try:
68
+ # success = self.extractor.process_pdf(pdf_path)
69
+ # return {
70
+ # 'filename': pdf_path.name,
71
+ # 'year': pdf_path.parent.name,
72
+ # 'success': success,
73
+ # 'error': None
74
+ # }
75
+ # except Exception as e:
76
+ # return {
77
+ # 'filename': pdf_path.name,
78
+ # 'year': pdf_path.parent.name,
79
+ # 'success': False,
80
+ # 'error': str(e)
81
+ # }
82
+
83
+ # def process_batch(self,
84
+ # start_year: int = None,
85
+ # end_year: int = None,
86
+ # limit: int = None,
87
+ # resume: bool = True):
88
+ # """
89
+ # Process PDFs in batch with progress tracking
90
+
91
+ # Args:
92
+ # start_year: Start from this year (inclusive)
93
+ # end_year: Process until this year (inclusive)
94
+ # limit: Maximum number of PDFs to process
95
+ # resume: Continue from last checkpoint
96
+ # """
97
+
98
+ # logger.info("Starting batch processing...")
99
+ # logger.info(f"Workers: {self.num_workers}")
100
+
101
+ # # Get all PDFs
102
+ # all_pdfs = self.get_all_pdfs()
103
+ # logger.info(f"Found {len(all_pdfs):,} PDFs")
104
+
105
+ # # Filter by year if specified
106
+ # if start_year or end_year:
107
+ # all_pdfs = [
108
+ # p for p in all_pdfs
109
+ # if (not start_year or int(p.parent.name) >= start_year) and
110
+ # (not end_year or int(p.parent.name) <= end_year)
111
+ # ]
112
+ # logger.info(f"Filtered to {len(all_pdfs):,} PDFs (years {start_year}-{end_year})")
113
+
114
+ # # Load progress and filter already processed
115
+ # if resume:
116
+ # processed = self.load_progress()
117
+ # all_pdfs = [p for p in all_pdfs if str(p) not in processed]
118
+ # logger.info(f"Resuming: {len(all_pdfs):,} PDFs remaining")
119
+ # else:
120
+ # processed = set()
121
+
122
+ # # Apply limit
123
+ # if limit:
124
+ # all_pdfs = all_pdfs[:limit]
125
+ # logger.info(f"Limited to {len(all_pdfs):,} PDFs")
126
+
127
+ # if not all_pdfs:
128
+ # logger.info("No PDFs to process!")
129
+ # return
130
+
131
+ # # Initialize stats
132
+ # stats = {
133
+ # 'total': len(all_pdfs),
134
+ # 'successful': 0,
135
+ # 'failed': 0,
136
+ # 'start_time': datetime.now().isoformat(),
137
+ # 'failed_files': []
138
+ # }
139
+
140
+ # # Process with progress bar
141
+ # with tqdm(total=len(all_pdfs), desc="Processing PDFs") as pbar:
142
+ # for pdf_path in all_pdfs:
143
+ # result = self.process_single_pdf(pdf_path)
144
+
145
+ # if result['success']:
146
+ # stats['successful'] += 1
147
+ # else:
148
+ # stats['failed'] += 1
149
+ # stats['failed_files'].append({
150
+ # 'file': result['filename'],
151
+ # 'error': result['error']
152
+ # })
153
+ # logger.warning(f"Failed: {result['filename']} - {result['error']}")
154
+
155
+ # # Update progress
156
+ # processed.add(str(pdf_path))
157
+
158
+ # # Save progress every 50 files
159
+ # if len(processed) % 50 == 0:
160
+ # self.save_progress(processed)
161
+
162
+ # pbar.update(1)
163
+ # pbar.set_postfix({
164
+ # 'Success': stats['successful'],
165
+ # 'Failed': stats['failed']
166
+ # })
167
+
168
+ # # Final save
169
+ # self.save_progress(processed)
170
+
171
+ # # Save statistics
172
+ # stats['end_time'] = datetime.now().isoformat()
173
+ # stats['success_rate'] = (stats['successful'] / stats['total'] * 100) if stats['total'] > 0 else 0
174
+
175
+ # with open(self.stats_file, 'w') as f:
176
+ # json.dump(stats, f, indent=2)
177
+
178
+ # # Summary
179
+ # logger.info("\n" + "="*60)
180
+ # logger.info("PROCESSING COMPLETE")
181
+ # logger.info("="*60)
182
+ # logger.info(f"Total processed: {stats['total']:,}")
183
+ # logger.info(f"Successful: {stats['successful']:,}")
184
+ # logger.info(f"Failed: {stats['failed']:,}")
185
+ # logger.info(f"Success rate: {stats['success_rate']:.2f}%")
186
+ # logger.info("="*60)
187
+
188
+
189
+ # def main():
190
+ # """Main execution with CLI arguments"""
191
+ # import argparse
192
+
193
+ # parser = argparse.ArgumentParser(description='Batch process legal judgment PDFs')
194
+ # parser.add_argument('--start-year', type=int, help='Start year (inclusive)')
195
+ # parser.add_argument('--end-year', type=int, help='End year (inclusive)')
196
+ # parser.add_argument('--limit', type=int, help='Maximum PDFs to process')
197
+ # parser.add_argument('--no-resume', action='store_true', help='Start fresh')
198
+ # parser.add_argument('--no-ocr', action='store_true', help='Disable OCR fallback')
199
+
200
+ # args = parser.parse_args()
201
+
202
+ # # Configuration
203
+ # INPUT_DIR = Path("data/raw")
204
+ # OUTPUT_DIR = Path("data/processed/extracted")
205
+
206
+ # processor = BatchProcessor(
207
+ # input_dir=INPUT_DIR,
208
+ # output_dir=OUTPUT_DIR,
209
+ # enable_ocr=not args.no_ocr
210
+ # )
211
+
212
+ # processor.process_batch(
213
+ # start_year=args.start_year,
214
+ # end_year=args.end_year,
215
+ # limit=args.limit,
216
+ # resume=not args.no_resume
217
+ # )
218
+
219
+
220
+ # if __name__ == "__main__":
221
+ # main()
222
+ """
223
+ Batch processor for large-scale PDF extraction
224
+ Includes progress tracking, error handling, and resumability
225
+ """
226
+
227
+ from pathlib import Path
228
+ from typing import List, Dict
229
+ import logging
230
+ from tqdm import tqdm
231
+ import json
232
+ from datetime import datetime
233
+ from pdf_extractor import LegalJudgmentExtractor
234
+
235
+ logging.basicConfig(
236
+ level=logging.INFO,
237
+ format='%(asctime)s - %(levelname)s - %(message)s'
238
+ )
239
+ logger = logging.getLogger(__name__)
240
+
241
+
242
+ class BatchProcessor:
243
+ """
244
+ Batch processor for extracting text from thousands of legal judgments
245
+ Features: Progress tracking, resumability, comprehensive reporting
246
+ """
247
+
248
+ def __init__(self,
249
+ input_dir: Path,
250
+ output_dir: Path,
251
+ num_workers: int = 1,
252
+ enable_ocr: bool = True):
253
+
254
+ self.input_dir = Path(input_dir)
255
+ self.output_dir = Path(output_dir)
256
+ self.num_workers = num_workers
257
+
258
+ # Create extractor
259
+ self.extractor = LegalJudgmentExtractor(output_dir, enable_ocr=enable_ocr)
260
+
261
+ # Progress tracking
262
+ self.progress_file = output_dir / "processing_progress.json"
263
+ self.stats_file = output_dir / "processing_stats.json"
264
+
265
+ def get_all_pdfs(self) -> List[Path]:
266
+ """Get all PDF files (case-insensitive)"""
267
+ # FIX: Search for both .pdf and .PDF
268
+ pdfs_lower = list(self.input_dir.rglob("*.pdf"))
269
+ pdfs_upper = list(self.input_dir.rglob("*.PDF"))
270
+ all_pdfs = pdfs_lower + pdfs_upper
271
+ return sorted(set(all_pdfs)) # Remove duplicates and sort
272
+
273
+ def load_progress(self) -> set:
274
+ """Load already processed files"""
275
+ if self.progress_file.exists():
276
+ with open(self.progress_file, 'r') as f:
277
+ data = json.load(f)
278
+ return set(data.get('processed_files', []))
279
+ return set()
280
+
281
+ def save_progress(self, processed_files: set):
282
+ """Save processing progress"""
283
+ with open(self.progress_file, 'w') as f:
284
+ json.dump({
285
+ 'processed_files': list(processed_files),
286
+ 'last_updated': datetime.now().isoformat(),
287
+ 'total_processed': len(processed_files)
288
+ }, f, indent=2)
289
+
290
+ def process_single_pdf(self, pdf_path: Path) -> Dict:
291
+ """Process a single PDF"""
292
+ try:
293
+ success = self.extractor.process_pdf(pdf_path)
294
+ return {
295
+ 'filename': pdf_path.name,
296
+ 'year': pdf_path.parent.name,
297
+ 'success': success,
298
+ 'error': None
299
+ }
300
+ except Exception as e:
301
+ return {
302
+ 'filename': pdf_path.name,
303
+ 'year': pdf_path.parent.name,
304
+ 'success': False,
305
+ 'error': str(e)
306
+ }
307
+
308
+ def process_batch(self,
309
+ start_year: int = None,
310
+ end_year: int = None,
311
+ limit: int = None,
312
+ resume: bool = True):
313
+ """
314
+ Process PDFs in batch with progress tracking
315
+
316
+ Args:
317
+ start_year: Start from this year (inclusive)
318
+ end_year: Process until this year (inclusive)
319
+ limit: Maximum number of PDFs to process
320
+ resume: Continue from last checkpoint
321
+ """
322
+
323
+ logger.info("Starting batch processing...")
324
+ logger.info(f"Input directory: {self.input_dir}")
325
+ logger.info(f"Output directory: {self.output_dir}")
326
+
327
+ # Get all PDFs
328
+ all_pdfs = self.get_all_pdfs()
329
+ logger.info(f"Found {len(all_pdfs):,} PDFs")
330
+
331
+ if len(all_pdfs) == 0:
332
+ logger.error("❌ No PDFs found! Check your data/raw directory.")
333
+ logger.error(f"Looking in: {self.input_dir}")
334
+ logger.error("Make sure PDFs are in year folders like: data/raw/1950/*.PDF")
335
+ return
336
+
337
+ # Filter by year if specified
338
+ if start_year or end_year:
339
+ filtered_pdfs = []
340
+ for p in all_pdfs:
341
+ try:
342
+ year = int(p.parent.name)
343
+ if (not start_year or year >= start_year) and (not end_year or year <= end_year):
344
+ filtered_pdfs.append(p)
345
+ except ValueError:
346
+ logger.warning(f"Skipping non-year folder: {p.parent.name}")
347
+
348
+ all_pdfs = filtered_pdfs
349
+ logger.info(f"Filtered to {len(all_pdfs):,} PDFs (years {start_year}-{end_year})")
350
+
351
+ # Load progress and filter already processed
352
+ if resume:
353
+ processed = self.load_progress()
354
+ all_pdfs = [p for p in all_pdfs if str(p) not in processed]
355
+ logger.info(f"Resuming: {len(all_pdfs):,} PDFs remaining")
356
+ else:
357
+ processed = set()
358
+
359
+ # Apply limit
360
+ if limit:
361
+ all_pdfs = all_pdfs[:limit]
362
+ logger.info(f"Limited to {len(all_pdfs):,} PDFs")
363
+
364
+ if not all_pdfs:
365
+ logger.info("No PDFs to process!")
366
+ return
367
+
368
+ # Initialize stats
369
+ stats = {
370
+ 'total': len(all_pdfs),
371
+ 'successful': 0,
372
+ 'failed': 0,
373
+ 'start_time': datetime.now().isoformat(),
374
+ 'failed_files': []
375
+ }
376
+
377
+ # Process with progress bar
378
+ with tqdm(total=len(all_pdfs), desc="Processing PDFs") as pbar:
379
+ for pdf_path in all_pdfs:
380
+ result = self.process_single_pdf(pdf_path)
381
+
382
+ if result['success']:
383
+ stats['successful'] += 1
384
+ else:
385
+ stats['failed'] += 1
386
+ stats['failed_files'].append({
387
+ 'file': result['filename'],
388
+ 'year': result['year'],
389
+ 'error': result['error']
390
+ })
391
+ logger.warning(f"Failed: {result['filename']} - {result['error']}")
392
+
393
+ # Update progress
394
+ processed.add(str(pdf_path))
395
+
396
+ # Save progress every 50 files
397
+ if len(processed) % 50 == 0:
398
+ self.save_progress(processed)
399
+
400
+ pbar.update(1)
401
+ pbar.set_postfix({
402
+ 'Success': stats['successful'],
403
+ 'Failed': stats['failed']
404
+ })
405
+
406
+ # Final save
407
+ self.save_progress(processed)
408
+
409
+ # Save statistics
410
+ stats['end_time'] = datetime.now().isoformat()
411
+ stats['success_rate'] = (stats['successful'] / stats['total'] * 100) if stats['total'] > 0 else 0
412
+
413
+ with open(self.stats_file, 'w') as f:
414
+ json.dump(stats, f, indent=2)
415
+
416
+ # Summary
417
+ logger.info("\n" + "="*60)
418
+ logger.info("PROCESSING COMPLETE")
419
+ logger.info("="*60)
420
+ logger.info(f"Total processed: {stats['total']:,}")
421
+ logger.info(f"Successful: {stats['successful']:,}")
422
+ logger.info(f"Failed: {stats['failed']:,}")
423
+ logger.info(f"Success rate: {stats['success_rate']:.2f}%")
424
+ logger.info("="*60)
425
+
426
+
427
+ def main():
428
+ """Main execution with CLI arguments"""
429
+ import argparse
430
+
431
+ parser = argparse.ArgumentParser(description='Batch process legal judgment PDFs')
432
+ parser.add_argument('--start-year', type=int, help='Start year (inclusive)')
433
+ parser.add_argument('--end-year', type=int, help='End year (inclusive)')
434
+ parser.add_argument('--limit', type=int, help='Maximum PDFs to process')
435
+ parser.add_argument('--no-resume', action='store_true', help='Start fresh')
436
+ parser.add_argument('--no-ocr', action='store_true', help='Disable OCR fallback')
437
+
438
+ args = parser.parse_args()
439
+
440
+ # Configuration
441
+ INPUT_DIR = Path("data/raw")
442
+ OUTPUT_DIR = Path("data/processed/extracted")
443
+
444
+ processor = BatchProcessor(
445
+ input_dir=INPUT_DIR,
446
+ output_dir=OUTPUT_DIR,
447
+ enable_ocr=not args.no_ocr
448
+ )
449
+
450
+ processor.process_batch(
451
+ start_year=args.start_year,
452
+ end_year=args.end_year,
453
+ limit=args.limit,
454
+ resume=not args.no_resume
455
+ )
456
+
457
+
458
+ if __name__ == "__main__":
459
+ main()
src/extraction/pdf_extractor.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Production PDF Extractor for Legal Judgments
3
+ Enhanced with robust error handling, quality checks, and paragraph preservation
4
+ """
5
+
6
+ import PyPDF2
7
+ import pdfplumber
8
+ from pathlib import Path
9
+ from typing import Dict, Optional, List, Tuple
10
+ import logging
11
+ from dataclasses import dataclass, asdict
12
+ import json
13
+ from datetime import datetime
14
+ import re
15
+
16
+ # OCR imports
17
+ try:
18
+ import pytesseract
19
+ from pdf2image import convert_from_path
20
+ OCR_AVAILABLE = True
21
+ except ImportError:
22
+ OCR_AVAILABLE = False
23
+ logging.warning("OCR libraries not installed. OCR fallback disabled.")
24
+
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ @dataclass
30
+ class ExtractionMetadata:
31
+ """Metadata for extracted judgment"""
32
+ filename: str
33
+ year: str
34
+ num_pages: int
35
+ text_length: int
36
+ extraction_method: str
37
+ has_text: bool
38
+ extraction_timestamp: str
39
+ file_size_bytes: int
40
+ ocr_used: bool
41
+ quality_score: float
42
+ paragraph_count: int
43
+ errors: List[str]
44
+ warnings: List[str]
45
+
46
+
47
+ class TextQualityChecker:
48
+ """Utility class for assessing extracted text quality"""
49
+
50
+ # Legal keywords to preserve even in short lines
51
+ LEGAL_KEYWORDS = {
52
+ 'held', 'order', 'appeal', 'writ', 'judgment', 'decree',
53
+ 'petition', 'application', 'allowed', 'dismissed', 'granted',
54
+ 'rejected', 'reserved', 'disposed', 'quashed', 'set aside',
55
+ 'affirmed', 'reversed', 'remanded', 'suo moto', 'ex parte',
56
+ 'interim', 'stay', 'injunction', 'bail', 'custody', 'liberty',
57
+ 'notice', 'respondent', 'petitioner', 'appellant', 'accused'
58
+ }
59
+
60
+ @staticmethod
61
+ def calculate_quality_score(text: str) -> Tuple[float, List[str]]:
62
+ """
63
+ Calculate quality score (0-1) for extracted text
64
+
65
+ Returns:
66
+ (score, issues_found)
67
+ """
68
+ if not text or len(text.strip()) < 100:
69
+ return 0.0, ["Text too short"]
70
+
71
+ issues = []
72
+ score = 1.0
73
+
74
+ # Check 1: Alphabetic character ratio
75
+ alpha_chars = sum(c.isalpha() for c in text)
76
+ total_chars = len(text.replace('\n', '').replace(' ', ''))
77
+
78
+ if total_chars > 0:
79
+ alpha_ratio = alpha_chars / total_chars
80
+ if alpha_ratio < 0.5:
81
+ score -= 0.3
82
+ issues.append(f"Low alphabetic ratio: {alpha_ratio:.2f}")
83
+
84
+ # Check 2: Average word length (gibberish detection)
85
+ words = text.split()
86
+ if words:
87
+ avg_word_len = sum(len(w) for w in words) / len(words)
88
+ if avg_word_len < 2 or avg_word_len > 15:
89
+ score -= 0.2
90
+ issues.append(f"Unusual avg word length: {avg_word_len:.1f}")
91
+
92
+ # Check 3: Check for repeated patterns (OCR errors)
93
+ lines = text.split('\n')
94
+ if len(lines) > 10:
95
+ unique_lines = len(set(line.strip() for line in lines if line.strip()))
96
+ repetition_ratio = unique_lines / len(lines)
97
+ if repetition_ratio < 0.3:
98
+ score -= 0.2
99
+ issues.append(f"High repetition: {repetition_ratio:.2f}")
100
+
101
+ # Check 4: Minimum sentence structure
102
+ sentence_markers = text.count('.') + text.count('?') + text.count('!')
103
+ if len(words) > 100 and sentence_markers < len(words) / 50:
104
+ score -= 0.1
105
+ issues.append("Lacks sentence structure")
106
+
107
+ return max(0.0, min(1.0, score)), issues
108
+
109
+ @staticmethod
110
+ def clean_ocr_text(text: str) -> str:
111
+ """
112
+ Normalize OCR-extracted text with legal-aware filtering
113
+ - Remove excessive whitespace
114
+ - Collapse multiple newlines
115
+ - Remove repeated headers
116
+ - Preserve important legal terms
117
+ """
118
+ # Collapse multiple spaces
119
+ text = re.sub(r' +', ' ', text)
120
+
121
+ # Collapse multiple newlines (keep max 2 for paragraph breaks)
122
+ text = re.sub(r'\n{3,}', '\n\n', text)
123
+
124
+ # Remove common OCR artifacts
125
+ text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f]', '', text)
126
+
127
+ # Legal-aware line filtering
128
+ lines = text.split('\n')
129
+ cleaned_lines = []
130
+
131
+ for line in lines:
132
+ stripped = line.strip()
133
+
134
+ # Skip empty lines
135
+ if not stripped:
136
+ continue
137
+
138
+ # Skip pure numbers (page numbers)
139
+ if stripped.isdigit():
140
+ continue
141
+
142
+ # PRESERVE if:
143
+ # 1. Line is substantial (>10 chars)
144
+ # 2. Contains legal keyword (even if short like "Held.")
145
+ # 3. Is alphabetic and reasonable length (>3 chars)
146
+ if (len(stripped) > 10 or
147
+ any(keyword in stripped.lower() for keyword in TextQualityChecker.LEGAL_KEYWORDS) or
148
+ (stripped.replace('.', '').replace(',', '').isalpha() and len(stripped) > 3)):
149
+ cleaned_lines.append(line)
150
+
151
+ text = '\n'.join(cleaned_lines)
152
+
153
+ # Remove repeated header patterns
154
+ lines = text.split('\n')
155
+ result = []
156
+ prev_line = None
157
+ repeat_count = 0
158
+
159
+ for line in lines:
160
+ if line.strip() == prev_line and prev_line:
161
+ repeat_count += 1
162
+ if repeat_count < 2: # Allow max 2 repetitions
163
+ result.append(line)
164
+ else:
165
+ repeat_count = 0
166
+ result.append(line)
167
+ prev_line = line.strip()
168
+
169
+ return '\n'.join(result).strip()
170
+
171
+
172
+ class LegalJudgmentExtractor:
173
+ """
174
+ Production-grade extractor with robust error handling and quality assurance
175
+ """
176
+
177
+ def __init__(self, output_dir: Path, enable_ocr: bool = True, ocr_max_pages: int = 50):
178
+ self.output_dir = Path(output_dir)
179
+ self.output_dir.mkdir(parents=True, exist_ok=True)
180
+ self.enable_ocr = enable_ocr and OCR_AVAILABLE
181
+ self.ocr_max_pages = ocr_max_pages
182
+
183
+ if enable_ocr and not OCR_AVAILABLE:
184
+ logger.warning("OCR requested but libraries not installed.")
185
+
186
+ # Create subdirectories
187
+ self.text_dir = self.output_dir / "texts"
188
+ self.metadata_dir = self.output_dir / "metadata"
189
+ self.failed_dir = self.output_dir / "failed"
190
+ self.ocr_log_file = self.output_dir / "ocr_cases.jsonl"
191
+
192
+ for dir_path in [self.text_dir, self.metadata_dir, self.failed_dir]:
193
+ dir_path.mkdir(parents=True, exist_ok=True)
194
+
195
+ def extract_year_from_path(self, pdf_path: Path) -> Tuple[str, List[str]]:
196
+ """
197
+ Safely extract year from path with validation
198
+
199
+ Returns:
200
+ (year, warnings)
201
+ """
202
+ warnings = []
203
+ year = pdf_path.parent.name
204
+
205
+ # Validate year
206
+ if not year.isdigit():
207
+ warnings.append(f"Invalid year from directory: {year}")
208
+
209
+ # Try to extract from filename
210
+ filename = pdf_path.stem
211
+ year_match = re.search(r'(19|20)\d{2}', filename)
212
+ if year_match:
213
+ year = year_match.group(0)
214
+ warnings.append(f"Year extracted from filename: {year}")
215
+ else:
216
+ year = "unknown"
217
+ warnings.append("Could not determine year")
218
+ else:
219
+ # Validate year range
220
+ year_int = int(year)
221
+ if year_int < 1950 or year_int > 2025:
222
+ warnings.append(f"Year {year} outside expected range (1950-2025)")
223
+
224
+ return year, warnings
225
+
226
+ def count_paragraphs(self, text: str) -> int:
227
+ """Count paragraph-like structures in text"""
228
+ # Split by double newlines
229
+ paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
230
+ # Filter out very short "paragraphs" (likely headers)
231
+ substantial_paragraphs = [p for p in paragraphs if len(p) > 50]
232
+ return len(substantial_paragraphs)
233
+
234
+ def extract_with_pypdf2(self, pdf_path: Path) -> Optional[str]:
235
+ """Primary extraction - preserves paragraph structure"""
236
+ try:
237
+ with open(pdf_path, 'rb') as file:
238
+ reader = PyPDF2.PdfReader(file)
239
+ text_parts = []
240
+
241
+ for page in reader.pages:
242
+ text = page.extract_text()
243
+ if text:
244
+ text_parts.append(text.strip())
245
+
246
+ # Join with double newline to preserve page breaks
247
+ full_text = "\n\n".join(text_parts)
248
+
249
+ # Quality check
250
+ score, _ = TextQualityChecker.calculate_quality_score(full_text)
251
+ return full_text if score > 0.3 else None
252
+
253
+ except Exception as e:
254
+ logger.debug(f"PyPDF2 failed for {pdf_path.name}: {e}")
255
+ return None
256
+
257
+ def extract_with_pdfplumber(self, pdf_path: Path) -> Optional[str]:
258
+ """Fallback extraction - better for complex layouts"""
259
+ try:
260
+ with pdfplumber.open(pdf_path) as pdf:
261
+ text_parts = []
262
+
263
+ for page in pdf.pages:
264
+ text = page.extract_text()
265
+ if text:
266
+ text_parts.append(text.strip())
267
+
268
+ full_text = "\n\n".join(text_parts)
269
+
270
+ score, _ = TextQualityChecker.calculate_quality_score(full_text)
271
+ return full_text if score > 0.3 else None
272
+
273
+ except Exception as e:
274
+ logger.debug(f"pdfplumber failed for {pdf_path.name}: {e}")
275
+ return None
276
+
277
+ def extract_with_ocr(self, pdf_path: Path, num_pages: int) -> Optional[str]:
278
+ """
279
+ OCR extraction with proper page limiting and text normalization
280
+
281
+ Args:
282
+ pdf_path: Path to PDF
283
+ num_pages: Total pages in PDF (for proper limiting)
284
+ """
285
+ if not self.enable_ocr:
286
+ return None
287
+
288
+ try:
289
+ logger.info(f"OCR extraction: {pdf_path.name}")
290
+
291
+ # Proper page limiting
292
+ last_page = min(self.ocr_max_pages, num_pages)
293
+
294
+ if num_pages > self.ocr_max_pages:
295
+ logger.warning(f"PDF has {num_pages} pages, OCR limited to first {self.ocr_max_pages}")
296
+
297
+ # Convert to images
298
+ images = convert_from_path(
299
+ pdf_path,
300
+ dpi=300,
301
+ first_page=1,
302
+ last_page=last_page
303
+ )
304
+
305
+ text_parts = []
306
+ for i, image in enumerate(images, 1):
307
+ logger.debug(f"OCR page {i}/{len(images)}")
308
+
309
+ text = pytesseract.image_to_string(image, lang='eng')
310
+ if text.strip():
311
+ text_parts.append(text)
312
+
313
+ full_text = "\n\n".join(text_parts)
314
+
315
+ # Normalize OCR text
316
+ full_text = TextQualityChecker.clean_ocr_text(full_text)
317
+
318
+ # Check quality
319
+ score, issues = TextQualityChecker.calculate_quality_score(full_text)
320
+
321
+ if score > 0.3:
322
+ # Log successful OCR to JSONL
323
+ self._log_ocr_case(pdf_path, num_pages, last_page, score)
324
+ logger.info(f"✓ OCR successful (quality: {score:.2f})")
325
+ return full_text
326
+ else:
327
+ logger.warning(f"OCR quality too low ({score:.2f}): {issues}")
328
+ return None
329
+
330
+ except Exception as e:
331
+ logger.warning(f"OCR failed for {pdf_path.name}: {e}")
332
+ return None
333
+
334
+ def _log_ocr_case(self, pdf_path: Path, total_pages: int, pages_processed: int, quality: float):
335
+ """Log OCR usage to JSONL file"""
336
+ log_entry = {
337
+ 'timestamp': datetime.now().isoformat(),
338
+ 'filename': pdf_path.name,
339
+ 'year': pdf_path.parent.name,
340
+ 'total_pages': total_pages,
341
+ 'pages_processed': pages_processed,
342
+ 'quality_score': quality
343
+ }
344
+
345
+ with open(self.ocr_log_file, 'a', encoding='utf-8') as f:
346
+ f.write(json.dumps(log_entry) + '\n')
347
+
348
+ def extract_pdf(self, pdf_path: Path) -> Dict:
349
+ """
350
+ Main extraction with fallback chain and quality assurance
351
+ """
352
+ errors = []
353
+ warnings = []
354
+ text = None
355
+ method = None
356
+ ocr_used = False
357
+ quality_score = 0.0
358
+
359
+ # Get metadata
360
+ file_size = pdf_path.stat().st_size
361
+
362
+ # Robust year extraction
363
+ year, year_warnings = self.extract_year_from_path(pdf_path)
364
+ warnings.extend(year_warnings)
365
+
366
+ # Count pages first (needed for OCR)
367
+ try:
368
+ with open(pdf_path, 'rb') as f:
369
+ reader = PyPDF2.PdfReader(f)
370
+ num_pages = len(reader.pages)
371
+ except Exception as e:
372
+ num_pages = 0
373
+ errors.append(f"Could not count pages: {e}")
374
+
375
+ # Extraction chain: PyPDF2 → pdfplumber → OCR
376
+ text = self.extract_with_pypdf2(pdf_path)
377
+ if text:
378
+ method = "pypdf2"
379
+ else:
380
+ errors.append("PyPDF2 insufficient")
381
+
382
+ text = self.extract_with_pdfplumber(pdf_path)
383
+ if text:
384
+ method = "pdfplumber"
385
+ else:
386
+ errors.append("pdfplumber failed")
387
+
388
+ if self.enable_ocr and num_pages > 0:
389
+ text = self.extract_with_ocr(pdf_path, num_pages)
390
+ if text:
391
+ method = "ocr"
392
+ ocr_used = True
393
+ warnings.append("OCR used - verify quality")
394
+ else:
395
+ errors.append("OCR failed")
396
+
397
+ # Calculate quality
398
+ paragraph_count = 0
399
+ if text:
400
+ quality_score, quality_issues = TextQualityChecker.calculate_quality_score(text)
401
+ paragraph_count = self.count_paragraphs(text)
402
+
403
+ if quality_score < 0.7:
404
+ warnings.extend(quality_issues)
405
+
406
+ # Create metadata
407
+ metadata = ExtractionMetadata(
408
+ filename=pdf_path.name,
409
+ year=year,
410
+ num_pages=num_pages,
411
+ text_length=len(text) if text else 0,
412
+ extraction_method=method if method else "failed",
413
+ has_text=text is not None,
414
+ extraction_timestamp=datetime.now().isoformat(),
415
+ file_size_bytes=file_size,
416
+ ocr_used=ocr_used,
417
+ quality_score=quality_score,
418
+ paragraph_count=paragraph_count,
419
+ errors=errors,
420
+ warnings=warnings
421
+ )
422
+
423
+ return {
424
+ 'text': text,
425
+ 'metadata': metadata
426
+ }
427
+
428
+ def save_extraction(self, pdf_path: Path, extraction_result: Dict) -> bool:
429
+ """Save with quality indicators"""
430
+
431
+ metadata = extraction_result['metadata']
432
+ text = extraction_result['text']
433
+
434
+ base_name = pdf_path.stem
435
+ year = metadata.year
436
+
437
+ # Save text
438
+ if text:
439
+ text_file = self.text_dir / f"{year}_{base_name}.txt"
440
+ try:
441
+ with open(text_file, 'w', encoding='utf-8') as f:
442
+ # Add quality header
443
+ f.write(f"{'='*70}\n")
444
+ f.write(f"File: {metadata.filename}\n")
445
+ f.write(f"Extraction: {metadata.extraction_method}\n")
446
+ f.write(f"Quality: {metadata.quality_score:.2f}\n")
447
+ f.write(f"Paragraphs: {metadata.paragraph_count}\n")
448
+
449
+ if metadata.ocr_used:
450
+ f.write("⚠️ OCR USED - Verify important details\n")
451
+
452
+ if metadata.warnings:
453
+ f.write(f"Warnings: {', '.join(metadata.warnings[:3])}\n")
454
+
455
+ f.write(f"{'='*70}\n\n")
456
+ f.write(text)
457
+
458
+ except Exception as e:
459
+ logger.error(f"Failed to save text: {e}")
460
+ return False
461
+
462
+ # Save metadata
463
+ metadata_file = self.metadata_dir / f"{year}_{base_name}.json"
464
+ try:
465
+ with open(metadata_file, 'w', encoding='utf-8') as f:
466
+ json.dump(asdict(metadata), f, indent=2)
467
+ except Exception as e:
468
+ logger.error(f"Failed to save metadata: {e}")
469
+ return False
470
+
471
+ # Log failures
472
+ if not text:
473
+ failed_log = self.failed_dir / "failed_extractions.jsonl"
474
+ with open(failed_log, 'a', encoding='utf-8') as f:
475
+ log_entry = {
476
+ 'timestamp': datetime.now().isoformat(),
477
+ 'file': str(pdf_path),
478
+ 'errors': metadata.errors
479
+ }
480
+ f.write(json.dumps(log_entry) + '\n')
481
+
482
+ return True
483
+
484
+ def process_pdf(self, pdf_path: Path) -> bool:
485
+ """Process single PDF"""
486
+ try:
487
+ result = self.extract_pdf(pdf_path)
488
+ return self.save_extraction(pdf_path, result)
489
+ except Exception as e:
490
+ logger.error(f"Unexpected error: {pdf_path.name}: {e}")
491
+ return False
492
+
493
+
494
+ if __name__ == "__main__":
495
+ # Test
496
+ print("="*70)
497
+ print("Testing Enhanced PDF Extractor")
498
+ print("="*70)
499
+
500
+ extractor = LegalJudgmentExtractor(
501
+ output_dir=Path("data/processed/extracted"),
502
+ enable_ocr=False
503
+ )
504
+
505
+ test_pdf = Path("data/raw/2025/A_John_Kennedy_vs_The_State_Of_Tamil_Nadu_on_24_March_2025_1.PDF")
506
+
507
+ if test_pdf.exists():
508
+ print(f"\nTesting: {test_pdf.name}")
509
+ success = extractor.process_pdf(test_pdf)
510
+ print(f"\n{'✓' if success else '✗'} Extraction {'successful' if success else 'failed'}")
511
+
512
+ # Show metadata
513
+ metadata_file = Path("data/processed/extracted/metadata") / f"2025_{test_pdf.stem}.json"
514
+ if metadata_file.exists():
515
+ with open(metadata_file, 'r') as f:
516
+ metadata = json.load(f)
517
+ print(f"\nMethod: {metadata['extraction_method']}")
518
+ print(f"Quality: {metadata['quality_score']:.2f}")
519
+ print(f"Paragraphs: {metadata['paragraph_count']}")
520
+ print(f"Text length: {metadata['text_length']:,} chars")
521
+ else:
522
+ print("Test PDF not found")
src/indexing/build_faiss_index.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NyayLens – FAISS Index Builder (Merged Embeddings)
3
+ """
4
+
5
+ import faiss
6
+ import numpy as np
7
+ import json
8
+ from pathlib import Path
9
+
10
+
11
+ def build_faiss_index():
12
+ print("=" * 70)
13
+ print("NyayLens – Building FAISS Index")
14
+ print("=" * 70)
15
+
16
+ embeddings_dir = Path("data/processed/embeddings")
17
+ output_dir = Path("data/processed/faiss")
18
+ output_dir.mkdir(parents=True, exist_ok=True)
19
+
20
+ embeddings_file = embeddings_dir / "paragraph_embeddings.npy"
21
+ ids_file = embeddings_dir / "paragraph_ids.json"
22
+ meta_file = embeddings_dir / "embedding_metadata.json"
23
+
24
+ # --- Load data ---
25
+ print("\nLoading embeddings...")
26
+ embeddings = np.load(embeddings_file)
27
+
28
+ print("Loading paragraph IDs...")
29
+ with open(ids_file, "r") as f:
30
+ paragraph_ids = json.load(f)
31
+
32
+ # --- Safety checks ---
33
+ assert embeddings.shape[0] == len(paragraph_ids), (
34
+ f"Mismatch: {embeddings.shape[0]} embeddings vs "
35
+ f"{len(paragraph_ids)} paragraph IDs"
36
+ )
37
+
38
+ print(f"✓ Loaded {embeddings.shape[0]:,} vectors")
39
+ print(f"✓ Embedding dimension: {embeddings.shape[1]}")
40
+
41
+ # --- Normalize (cosine similarity) ---
42
+ print("\nNormalizing embeddings...")
43
+ faiss.normalize_L2(embeddings)
44
+
45
+ # --- Build FAISS index ---
46
+ dim = embeddings.shape[1]
47
+
48
+ # Switch to HNSW for rapid Approximate Nearest Neighbor search
49
+ # M=32 is number of connections per layer, typical for dense models
50
+ index = faiss.IndexHNSWFlat(dim, 32, faiss.METRIC_INNER_PRODUCT)
51
+ index.hnsw.efConstruction = 40 # Depth of search during index build
52
+ index.hnsw.efSearch = 64 # Depth of search during inference
53
+
54
+ print(f"Adding {embeddings.shape[0]:,} vectors to HNSW index...")
55
+ index.add(embeddings)
56
+
57
+ print(f"✓ FAISS index contains {index.ntotal:,} vectors")
58
+
59
+ # --- Save index ---
60
+ index_file = output_dir / "faiss_index.bin"
61
+ faiss.write_index(index, str(index_file))
62
+
63
+ # --- Save FAISS metadata ---
64
+ with open(meta_file, "r") as f:
65
+ emb_meta = json.load(f)
66
+
67
+ faiss_meta = {
68
+ "index_type": "IndexHNSWFlat",
69
+ "metric": "cosine_similarity",
70
+ "dimension": dim,
71
+ "total_vectors": int(index.ntotal),
72
+ "embedding_model": emb_meta.get("model_name", "unknown"),
73
+ }
74
+
75
+ with open(output_dir / "faiss_metadata.json", "w") as f:
76
+ json.dump(faiss_meta, f, indent=2)
77
+
78
+ print(f"\n✓ FAISS index saved to: {index_file}")
79
+ print("=" * 70)
80
+ print("✓ FAISS Index Build Complete")
81
+ print("=" * 70)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ build_faiss_index()
src/indexing/create_embeddings.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate embeddings in CHUNKS - GPU-safe with resume capability
3
+ Processes 20K paragraphs at a time with cooling breaks
4
+ """
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ import json
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from tqdm import tqdm
11
+ import torch
12
+ import time
13
+
14
+ class ChunkedEmbeddingGenerator:
15
+ """Generate embeddings in safe chunks with resume capability"""
16
+
17
+ def __init__(self, model_name: str = "BAAI/bge-base-en-v1.5", chunk_size: int = 20000):
18
+ """
19
+ Initialize with sentence transformer
20
+ chunk_size: Process this many paragraphs at a time
21
+ """
22
+ print(f"Loading model: {model_name}")
23
+
24
+ # Check CUDA availability
25
+ if torch.cuda.is_available():
26
+ print(f"✓ CUDA available! GPU: {torch.cuda.get_device_name(0)}")
27
+ print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
28
+ self.device = 'cuda'
29
+ else:
30
+ print("⚠️ CUDA not available, using CPU")
31
+ self.device = 'cpu'
32
+
33
+ self.model = SentenceTransformer(model_name)
34
+ self.model.to(self.device)
35
+ self.chunk_size = chunk_size
36
+ print(f"✓ Model loaded on: {self.device}")
37
+ print(f"✓ Chunk size: {chunk_size:,} paragraphs")
38
+
39
+ def load_paragraphs(self, index_file: Path):
40
+ """Load paragraphs from JSONL"""
41
+ print("\nLoading paragraphs from index...")
42
+ paragraphs = []
43
+
44
+ with open(index_file, 'r', encoding='utf-8') as f:
45
+ total_lines = sum(1 for _ in f)
46
+
47
+ with open(index_file, 'r', encoding='utf-8') as f:
48
+ for line in tqdm(f, total=total_lines, desc="Loading index"):
49
+ para = json.loads(line)
50
+ paragraphs.append(para)
51
+
52
+ print(f"✓ Loaded {len(paragraphs):,} paragraphs")
53
+ return paragraphs
54
+
55
+ def process_chunk(self, texts, batch_size=64):
56
+ """Process one chunk of texts"""
57
+ embeddings_list = []
58
+
59
+ with tqdm(total=len(texts), desc="Processing chunk", unit="para") as pbar:
60
+ for i in range(0, len(texts), batch_size):
61
+ batch = texts[i:i+batch_size]
62
+
63
+ # Generate embeddings
64
+ batch_embeddings = self.model.encode(
65
+ batch,
66
+ convert_to_numpy=True,
67
+ normalize_embeddings=True,
68
+ show_progress_bar=False,
69
+ device=self.device
70
+ )
71
+
72
+ embeddings_list.append(batch_embeddings)
73
+ pbar.update(len(batch))
74
+
75
+ # Clear GPU cache every 10 batches
76
+ if self.device == 'cuda' and i % (batch_size * 10) == 0:
77
+ torch.cuda.empty_cache()
78
+
79
+ return np.vstack(embeddings_list)
80
+
81
+ def generate_embeddings_chunked(self, index_file: Path, output_dir: Path, batch_size: int = 64):
82
+ """Generate embeddings in chunks with resume capability"""
83
+
84
+ output_dir = Path(output_dir)
85
+ output_dir.mkdir(parents=True, exist_ok=True)
86
+
87
+ chunks_dir = output_dir / "chunks"
88
+ chunks_dir.mkdir(exist_ok=True)
89
+
90
+ # Load paragraphs
91
+ paragraphs = self.load_paragraphs(index_file)
92
+ total_paras = len(paragraphs)
93
+
94
+ # Calculate chunks
95
+ num_chunks = (total_paras + self.chunk_size - 1) // self.chunk_size
96
+ print(f"\n✓ Will process in {num_chunks} chunks of {self.chunk_size:,} paragraphs each")
97
+
98
+ # Check for existing chunks
99
+ existing_chunks = list(chunks_dir.glob("chunk_*.npy"))
100
+ completed_chunks = len(existing_chunks)
101
+
102
+ if completed_chunks > 0:
103
+ print(f"✓ Found {completed_chunks} existing chunks")
104
+ response = input(f"Resume from chunk {completed_chunks + 1}? (yes/no): ")
105
+ if response.lower() != 'yes':
106
+ print("Starting fresh...")
107
+ for f in existing_chunks:
108
+ f.unlink()
109
+ completed_chunks = 0
110
+
111
+ # Process chunks
112
+ start_time = time.time()
113
+
114
+ for chunk_idx in range(completed_chunks, num_chunks):
115
+ print(f"\n{'='*70}")
116
+ print(f"CHUNK {chunk_idx + 1}/{num_chunks}")
117
+ print(f"{'='*70}")
118
+
119
+ # Get chunk data
120
+ start_idx = chunk_idx * self.chunk_size
121
+ end_idx = min(start_idx + self.chunk_size, total_paras)
122
+ chunk_paras = paragraphs[start_idx:end_idx]
123
+
124
+ print(f"Processing paragraphs {start_idx:,} to {end_idx:,}")
125
+
126
+ # Extract texts
127
+ texts = [p['text'] for p in chunk_paras]
128
+ ids = [p['id'] for p in chunk_paras]
129
+
130
+ # Process chunk
131
+ chunk_start = time.time()
132
+ embeddings = self.process_chunk(texts, batch_size)
133
+ chunk_time = time.time() - chunk_start
134
+
135
+ print(f"✓ Chunk completed in {chunk_time/60:.1f} minutes")
136
+ print(f" Speed: {len(texts)/chunk_time:.1f} para/s")
137
+
138
+ # Save chunk
139
+ chunk_file = chunks_dir / f"chunk_{chunk_idx:03d}.npy"
140
+ ids_file = chunks_dir / f"chunk_{chunk_idx:03d}_ids.json"
141
+
142
+ np.save(chunk_file, embeddings)
143
+ with open(ids_file, 'w') as f:
144
+ json.dump(ids, f)
145
+
146
+ print(f"✓ Saved to: {chunk_file.name}")
147
+
148
+ # Clear GPU and add cooling break
149
+ if self.device == 'cuda':
150
+ torch.cuda.empty_cache()
151
+ if chunk_idx < num_chunks - 1: # Not last chunk
152
+ print("\n⏸️ Cooling break: 10 seconds...")
153
+ time.sleep(10)
154
+
155
+ # Combine all chunks
156
+ print(f"\n{'='*70}")
157
+ print("COMBINING CHUNKS...")
158
+ print(f"{'='*70}")
159
+
160
+ all_embeddings = []
161
+ all_ids = []
162
+
163
+ for chunk_idx in tqdm(range(num_chunks), desc="Loading chunks"):
164
+ chunk_file = chunks_dir / f"chunk_{chunk_idx:03d}.npy"
165
+ ids_file = chunks_dir / f"chunk_{chunk_idx:03d}_ids.json"
166
+
167
+ embeddings = np.load(chunk_file)
168
+ with open(ids_file, 'r') as f:
169
+ ids = json.load(f)
170
+
171
+ all_embeddings.append(embeddings)
172
+ all_ids.extend(ids)
173
+
174
+ final_embeddings = np.vstack(all_embeddings)
175
+
176
+ print(f"✓ Combined shape: {final_embeddings.shape}")
177
+
178
+ # Save final files
179
+ print("\nSaving final files...")
180
+
181
+ embeddings_file = output_dir / "paragraph_embeddings.npy"
182
+ np.save(embeddings_file, final_embeddings)
183
+ print(f"✓ Saved: {embeddings_file}")
184
+
185
+ ids_file = output_dir / "paragraph_ids.json"
186
+ with open(ids_file, 'w') as f:
187
+ json.dump(all_ids, f)
188
+ print(f"✓ Saved: {ids_file}")
189
+
190
+ # Save metadata
191
+ total_time = time.time() - start_time
192
+ metadata = {
193
+ 'model_name': "BAAI/bge-base-en-v1.5",
194
+ 'embedding_dim': int(final_embeddings.shape[1]),
195
+ 'total_paragraphs': len(all_ids),
196
+ 'device_used': self.device,
197
+ 'batch_size': batch_size,
198
+ 'chunk_size': self.chunk_size,
199
+ 'num_chunks': num_chunks,
200
+ 'total_time_minutes': total_time / 60
201
+ }
202
+
203
+ with open(output_dir / "embedding_metadata.json", 'w') as f:
204
+ json.dump(metadata, f, indent=2)
205
+
206
+ print(f"\n✓ Total time: {total_time/60:.1f} minutes")
207
+ print(f" Average speed: {len(all_ids)/total_time:.1f} para/s")
208
+
209
+ return final_embeddings, all_ids
210
+
211
+
212
+ if __name__ == "__main__":
213
+ print("="*70)
214
+ print("NyayLens - Chunked Embedding Generation (GPU-Safe)")
215
+ print("="*70)
216
+ print()
217
+
218
+ generator = ChunkedEmbeddingGenerator(
219
+ model_name="BAAI/bge-base-en-v1.5",
220
+ chunk_size=20000 # Process 20K at a time
221
+ )
222
+
223
+ embeddings, ids = generator.generate_embeddings_chunked(
224
+ index_file=Path("data/processed/indexed/paragraph_index.jsonl"),
225
+ output_dir=Path("data/processed/embeddings"),
226
+ batch_size=64 # Reduced for safety
227
+ )
228
+
229
+ print()
230
+ print("="*70)
231
+ print(f"✓ COMPLETE! Generated {len(embeddings):,} embeddings")
232
+ print("="*70)
src/indexing/create_sqlite_index.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """Create SQLite index for fast paragraph lookup"""
2
+
3
+ # import sqlite3
4
+ # import json
5
+ # from pathlib import Path
6
+ # from tqdm import tqdm
7
+
8
+ # def create_sqlite_index():
9
+ # print("Creating SQLite index...")
10
+
11
+ # db_path = Path("data/processed/indexed/paragraphs.db")
12
+ # db_path.parent.mkdir(parents=True, exist_ok=True)
13
+
14
+ # # Create database
15
+ # conn = sqlite3.connect(db_path)
16
+ # cursor = conn.cursor()
17
+
18
+ # # Create table
19
+ # cursor.execute("""
20
+ # CREATE TABLE IF NOT EXISTS paragraphs (
21
+ # id TEXT PRIMARY KEY,
22
+ # judgment_id TEXT,
23
+ # page_no INTEGER,
24
+ # text TEXT,
25
+ # char_count INTEGER,
26
+ # word_count INTEGER
27
+ # )
28
+ # """)
29
+
30
+ # cursor.execute("CREATE INDEX IF NOT EXISTS idx_judgment ON paragraphs(judgment_id)")
31
+
32
+ # # Load data
33
+ # index_file = Path("data/processed/indexed/paragraph_index.jsonl")
34
+
35
+ # with open(index_file, 'r', encoding='utf-8') as f:
36
+ # total = sum(1 for _ in f)
37
+
38
+ # with open(index_file, 'r', encoding='utf-8') as f:
39
+ # batch = []
40
+ # for line in tqdm(f, total=total, desc="Inserting"):
41
+ # p = json.loads(line)
42
+ # batch.append((
43
+ # p['id'], p['judgment_id'], p['page_no'],
44
+ # p['text'], p['char_count'], p['word_count']
45
+ # ))
46
+
47
+ # if len(batch) >= 1000:
48
+ # cursor.executemany(
49
+ # "INSERT OR REPLACE INTO paragraphs VALUES (?,?,?,?,?,?)",
50
+ # batch
51
+ # )
52
+ # batch = []
53
+
54
+ # if batch:
55
+ # cursor.executemany(
56
+ # "INSERT OR REPLACE INTO paragraphs VALUES (?,?,?,?,?,?)",
57
+ # batch
58
+ # )
59
+
60
+ # conn.commit()
61
+ # conn.close()
62
+
63
+ # print(f"✓ SQLite index created: {db_path}")
64
+
65
+ # if __name__ == "__main__":
66
+ # create_sqlite_index()
67
+ """
68
+ Create SQLite index with section annotations
69
+ Source: paragraph_index_with_sections.jsonl
70
+ """
71
+
72
+ import sqlite3
73
+ import json
74
+ from pathlib import Path
75
+ from tqdm import tqdm
76
+
77
+
78
+ INPUT_INDEX = Path("data/processed/indexed/paragraph_index_with_sections.jsonl")
79
+ DB_PATH = Path("data/processed/indexed/paragraphs.db")
80
+
81
+
82
+ def create_sqlite_index():
83
+ print("=" * 70)
84
+ print("NyayLens – Creating SQLite Index (with Sections)")
85
+ print("=" * 70)
86
+
87
+ DB_PATH.parent.mkdir(parents=True, exist_ok=True)
88
+
89
+ # Connect to SQLite
90
+ conn = sqlite3.connect(DB_PATH)
91
+ cursor = conn.cursor()
92
+
93
+ # Drop existing table (derived data → safe to rebuild)
94
+ cursor.execute("DROP TABLE IF EXISTS paragraphs")
95
+
96
+ # Create table
97
+ cursor.execute("""
98
+ CREATE TABLE paragraphs (
99
+ id TEXT PRIMARY KEY,
100
+ judgment_id TEXT,
101
+ page_no INTEGER,
102
+ text TEXT,
103
+ char_count INTEGER,
104
+ word_count INTEGER,
105
+ section TEXT,
106
+ section_conf REAL
107
+ )
108
+ """)
109
+
110
+ # Create FTS5 virtual table for fast full-text search (BM25)
111
+ cursor.execute("DROP TABLE IF EXISTS paragraphs_fts")
112
+ cursor.execute("""
113
+ CREATE VIRTUAL TABLE paragraphs_fts USING fts5(
114
+ id UNINDEXED,
115
+ text,
116
+ tokenize='porter unicode61'
117
+ )
118
+ """)
119
+
120
+ # Indexes for fast lookup
121
+ cursor.execute("CREATE INDEX idx_judgment_id ON paragraphs(judgment_id)")
122
+ cursor.execute("CREATE INDEX idx_section ON paragraphs(section)")
123
+ cursor.execute("CREATE INDEX idx_judgment_section ON paragraphs(judgment_id, section)")
124
+
125
+ conn.commit()
126
+
127
+ # Count total records
128
+ with open(INPUT_INDEX, "r", encoding="utf-8") as f:
129
+ total = sum(1 for _ in f)
130
+
131
+ print(f"✓ Inserting {total:,} paragraphs")
132
+
133
+ # Insert data in batches
134
+ batch = []
135
+ BATCH_SIZE = 1000
136
+
137
+ with open(INPUT_INDEX, "r", encoding="utf-8") as f:
138
+ for line in tqdm(f, total=total, desc="Inserting"):
139
+ p = json.loads(line)
140
+
141
+ batch.append((
142
+ p["id"],
143
+ p["judgment_id"],
144
+ p.get("page_no", -1),
145
+ p["text"],
146
+ p.get("char_count", len(p["text"])),
147
+ p.get("word_count", len(p["text"].split())),
148
+ p.get("section", "unknown"),
149
+ p.get("section_conf", 0.0),
150
+ ))
151
+
152
+ if len(batch) >= BATCH_SIZE:
153
+ cursor.executemany(
154
+ """
155
+ INSERT INTO paragraphs
156
+ (id, judgment_id, page_no, text, char_count, word_count, section, section_conf)
157
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
158
+ """,
159
+ batch
160
+ )
161
+
162
+ # Insert into FTS5 table
163
+ fts_batch = [(b[0], b[3]) for b in batch]
164
+ cursor.executemany(
165
+ "INSERT INTO paragraphs_fts (id, text) VALUES (?, ?)",
166
+ fts_batch
167
+ )
168
+
169
+ batch.clear()
170
+
171
+ if batch:
172
+ cursor.executemany(
173
+ """
174
+ INSERT INTO paragraphs
175
+ (id, judgment_id, page_no, text, char_count, word_count, section, section_conf)
176
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
177
+ """,
178
+ batch
179
+ )
180
+
181
+ fts_batch = [(b[0], b[3]) for b in batch]
182
+ cursor.executemany(
183
+ "INSERT INTO paragraphs_fts (id, text) VALUES (?, ?)",
184
+ fts_batch
185
+ )
186
+
187
+ conn.commit()
188
+ conn.close()
189
+
190
+ print("\n✓ SQLite index created successfully")
191
+ print(f"✓ Database path: {DB_PATH}")
192
+ print("=" * 70)
193
+
194
+
195
+ if __name__ == "__main__":
196
+ create_sqlite_index()
src/indexing/paragraph_indexer.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Paragraph-level indexer for NyayLens RAG
3
+ - Page-aware
4
+ - Content-stable paragraph IDs
5
+ - Legal-aware filtering
6
+ - Streaming JSONL output (memory safe)
7
+ """
8
+
9
+ import json
10
+ import hashlib
11
+ import re
12
+ from pathlib import Path
13
+ from typing import Dict, Iterable
14
+ from tqdm import tqdm
15
+
16
+
17
+ class ParagraphIndexer:
18
+ """Index legal judgments at paragraph level with stable IDs and metadata"""
19
+
20
+ # Legal keywords worth preserving even in short paragraphs
21
+ LEGAL_KEYWORDS = {
22
+ 'held', 'order', 'appeal', 'writ', 'judgment', 'decree',
23
+ 'petition', 'application', 'allowed', 'dismissed', 'granted',
24
+ 'rejected', 'disposed', 'quashed', 'set aside',
25
+ 'affirmed', 'reversed', 'remanded', 'bail', 'custody',
26
+ 'interim', 'stay', 'injunction', 'no costs'
27
+ }
28
+
29
+ PAGE_MARKER_PATTERN = re.compile(r'<<<PAGE:(\d+)>>>')
30
+
31
+ def __init__(self, texts_dir: Path, output_dir: Path):
32
+ self.texts_dir = Path(texts_dir)
33
+ self.output_dir = Path(output_dir)
34
+ self.output_dir.mkdir(parents=True, exist_ok=True)
35
+
36
+ self.index_file = self.output_dir / "paragraph_index.jsonl"
37
+ self.stats_file = self.output_dir / "index_stats.json"
38
+
39
+ @staticmethod
40
+ def _contains_legal_keyword(text: str) -> bool:
41
+ text_l = text.lower()
42
+ return any(
43
+ re.search(rf"\b{re.escape(kw)}\b", text_l)
44
+ for kw in ParagraphIndexer.LEGAL_KEYWORDS
45
+ )
46
+
47
+ @staticmethod
48
+ def _stable_paragraph_id(judgment_id: str, page_no: int, text: str) -> str:
49
+ """
50
+ Content-stable ID:
51
+ - same paragraph text => same ID
52
+ - survives re-indexing
53
+ """
54
+ h = hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
55
+ page_str = page_no if page_no is not None else "unk"
56
+ return f"{judgment_id}_p{page_str}_{h}"
57
+
58
+ def _strip_header(self, content: str) -> str:
59
+ """
60
+ Remove extractor quality header safely
61
+ """
62
+ sep = "=" * 70
63
+ if sep in content:
64
+ parts = content.split(sep, 2)
65
+ if len(parts) == 3:
66
+ return parts[2].strip()
67
+ return content.strip()
68
+
69
+ def _iter_paragraphs(self, content: str) -> Iterable[tuple]:
70
+ """
71
+ Yield paragraph records with page numbers
72
+ """
73
+ current_page = None
74
+ buffer = []
75
+
76
+ for line in content.splitlines():
77
+ page_match = self.PAGE_MARKER_PATTERN.match(line.strip())
78
+ if page_match:
79
+ # Flush buffer before page change
80
+ if buffer:
81
+ yield current_page, "\n".join(buffer).strip()
82
+ buffer = []
83
+ current_page = int(page_match.group(1))
84
+ continue
85
+
86
+ if not line.strip():
87
+ if buffer:
88
+ yield current_page, "\n".join(buffer).strip()
89
+ buffer = []
90
+ continue
91
+
92
+ buffer.append(line)
93
+
94
+ if buffer:
95
+ yield current_page, "\n".join(buffer).strip()
96
+
97
+ def index_judgment(self, text_file: Path, writer) -> int:
98
+ """
99
+ Index a single judgment file.
100
+ Returns number of paragraphs indexed.
101
+ """
102
+ with open(text_file, "r", encoding="utf-8") as f:
103
+ content = self._strip_header(f.read())
104
+
105
+ judgment_id = text_file.stem
106
+ para_count = 0
107
+
108
+ for page_no, para in self._iter_paragraphs(content):
109
+ if not para:
110
+ continue
111
+
112
+ # Keep substantial OR legally important short paragraphs
113
+ if len(para) < 50 and not self._contains_legal_keyword(para):
114
+ continue
115
+
116
+ record = {
117
+ "id": self._stable_paragraph_id(judgment_id, page_no if page_no is not None else -1, para),
118
+ "judgment_id": judgment_id,
119
+ "page_no": page_no if page_no is not None else -1,
120
+ "text": para,
121
+ "char_count": len(para),
122
+ "word_count": len(para.split())
123
+ }
124
+
125
+ writer.write(json.dumps(record, ensure_ascii=False) + "\n")
126
+ para_count += 1
127
+
128
+ return para_count
129
+
130
+ def build_full_index(self):
131
+ text_files = sorted(self.texts_dir.glob("*.txt"))
132
+ print(f"Indexing {len(text_files):,} judgments...")
133
+
134
+ total_paragraphs = 0
135
+
136
+ with open(self.index_file, "w", encoding="utf-8") as writer:
137
+ for text_file in tqdm(text_files, desc="Indexing"):
138
+ try:
139
+ total_paragraphs += self.index_judgment(text_file, writer)
140
+ except Exception as e:
141
+ print(f"❌ Failed indexing {text_file.name}: {e}")
142
+
143
+ stats = {
144
+ "total_judgments": len(text_files),
145
+ "total_paragraphs": total_paragraphs,
146
+ "avg_paragraphs_per_judgment":
147
+ total_paragraphs / len(text_files) if text_files else 0
148
+ }
149
+
150
+ with open(self.stats_file, "w", encoding="utf-8") as f:
151
+ json.dump(stats, f, indent=2)
152
+
153
+ print("\n✓ Paragraph indexing complete")
154
+ print(f" Total paragraphs: {total_paragraphs:,}")
155
+ print(f" Output: {self.index_file}")
156
+
157
+ return stats
158
+
159
+
160
+ if __name__ == "__main__":
161
+ indexer = ParagraphIndexer(
162
+ texts_dir=Path("data/processed/extracted/texts"),
163
+ output_dir=Path("data/processed/indexed")
164
+ )
165
+
166
+ stats = indexer.build_full_index()
167
+ print(f"\nAverage paragraphs per judgment: {stats['avg_paragraphs_per_judgment']:.1f}")
src/pipeline.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import logging
4
+ import time
5
+
6
+ from extraction.pdf_extractor import LegalJudgmentExtractor
7
+ from segmentation.judgement_segmenter import JudgmentSegmenter
8
+
9
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class NyayLensPipeline:
13
+ """Unified ingestion pipeline for NyayLens"""
14
+
15
+ def __init__(self, raw_dir="data/raw", processed_dir="data/processed"):
16
+ self.raw_dir = Path(raw_dir)
17
+ self.processed_dir = Path(processed_dir)
18
+
19
+ logger.info("Initializing NyayLens Pipeline components...")
20
+ self.extractor = LegalJudgmentExtractor(output_dir=self.processed_dir / "extracted")
21
+ self.segmenter = JudgmentSegmenter()
22
+
23
+ def ingest_pdf(self, pdf_path: Path):
24
+ """Process a single PDF end-to-end"""
25
+ logger.info(f"--- Starting Ingestion: {pdf_path.name} ---")
26
+ start_time = time.time()
27
+
28
+ # 1. Extraction
29
+ logger.info("Step 1: Extracting text...")
30
+ extraction_result = self.extractor.extract_pdf(pdf_path)
31
+ if not extraction_result['text']:
32
+ logger.error(f"Extraction failed for {pdf_path.name}")
33
+ return False
34
+
35
+ # Optional: Save extraction
36
+ self.extractor.save_extraction(pdf_path, extraction_result)
37
+
38
+ # 2. Segmentation
39
+ logger.info("Step 2: Segmenting judgment...")
40
+ # Simple paragraph split for segmentation
41
+ paragraphs = extraction_result['text'].split('\n\n')
42
+ sections = self.segmenter.segment(paragraphs)
43
+
44
+ logger.info(f"Found {len(sections)} distinct sections.")
45
+
46
+ # 3. Next steps would hook into `create_embeddings.py` and `create_sqlite_index.py`
47
+ logger.info("Step 3: Ready for Embeddings & Indexing (Batch process recommended)")
48
+
49
+ elapsed = time.time() - start_time
50
+ logger.info(f"--- Successfully processed {pdf_path.name} in {elapsed:.2f}s ---")
51
+ return True
52
+
53
+ def process_directory(self, limit: int = None):
54
+ """Process all PDFs in the raw directory"""
55
+ pdfs = list(self.raw_dir.glob("**/*.pdf")) + list(self.raw_dir.glob("**/*.PDF"))
56
+
57
+ if limit:
58
+ pdfs = pdfs[:limit]
59
+
60
+ logger.info(f"Found {len(pdfs)} PDFs to process.")
61
+
62
+ success = 0
63
+ for pdf in pdfs:
64
+ if self.ingest_pdf(pdf):
65
+ success += 1
66
+
67
+ logger.info(f"Batch completed. {success}/{len(pdfs)} successful.")
68
+
69
+ if __name__ == "__main__":
70
+ parser = argparse.ArgumentParser(description="NyayLens Unified Ingestion Pipeline")
71
+ parser.add_argument("--pdf", type=str, help="Path to a single PDF to ingest")
72
+ parser.add_argument("--batch", action="store_true", help="Process all PDFs in data/raw")
73
+ parser.add_argument("--limit", type=int, default=None, help="Limit number of PDFs in batch")
74
+
75
+ args = parser.parse_args()
76
+
77
+ pipeline = NyayLensPipeline()
78
+
79
+ if args.pdf:
80
+ pipeline.ingest_pdf(Path(args.pdf))
81
+ elif args.batch:
82
+ pipeline.process_directory(limit=args.limit)
83
+ else:
84
+ logger.warning("Please specify --pdf <path> or --batch. Use --help for options.")
src/qa/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (1.77 kB). View file
 
src/qa/__pycache__/model.cpython-310.pyc ADDED
Binary file (462 Bytes). View file
 
src/qa/dataset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+
3
+ MAX_LENGTH = 384
4
+ DOC_STRIDE = 128
5
+
6
+ def load_and_prepare_dataset(tokenizer):
7
+ dataset = load_dataset("squad") # auto-download
8
+
9
+ def preprocess(examples):
10
+ questions = [q.strip() for q in examples["question"]]
11
+ contexts = examples["context"]
12
+
13
+ tokenized = tokenizer(
14
+ questions,
15
+ contexts,
16
+ truncation="only_second",
17
+ max_length=MAX_LENGTH,
18
+ stride=DOC_STRIDE,
19
+ return_overflowing_tokens=True,
20
+ return_offsets_mapping=True,
21
+ padding="max_length",
22
+ )
23
+
24
+ sample_mapping = tokenized.pop("overflow_to_sample_mapping")
25
+ offset_mapping = tokenized.pop("offset_mapping")
26
+
27
+ start_positions = []
28
+ end_positions = []
29
+
30
+ for i, offsets in enumerate(offset_mapping):
31
+ input_ids = tokenized["input_ids"][i]
32
+ cls_index = input_ids.index(tokenizer.cls_token_id)
33
+
34
+ sample_idx = sample_mapping[i]
35
+ answer = examples["answers"][sample_idx]
36
+
37
+ if len(answer["answer_start"]) == 0:
38
+ start_positions.append(cls_index)
39
+ end_positions.append(cls_index)
40
+ else:
41
+ start_char = answer["answer_start"][0]
42
+ end_char = start_char + len(answer["text"][0])
43
+
44
+ token_start = token_end = None
45
+ for idx, (start, end) in enumerate(offsets):
46
+ if start <= start_char < end:
47
+ token_start = idx
48
+ if start < end_char <= end:
49
+ token_end = idx
50
+ break
51
+
52
+ if token_start is None or token_end is None:
53
+ start_positions.append(cls_index)
54
+ end_positions.append(cls_index)
55
+ else:
56
+ start_positions.append(token_start)
57
+ end_positions.append(token_end)
58
+
59
+ tokenized["start_positions"] = start_positions
60
+ tokenized["end_positions"] = end_positions
61
+ return tokenized
62
+
63
+ tokenized_dataset = dataset.map(
64
+ preprocess,
65
+ batched=True,
66
+ remove_columns=dataset["train"].column_names,
67
+ )
68
+
69
+ return tokenized_dataset
src/qa/inference.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import faiss
3
+ import json
4
+ import sqlite3
5
+ import re
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
8
+
9
+
10
+ class LegalQAEngine:
11
+ def __init__(self):
12
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ print(f"Using device: {self.device}")
14
+
15
+ # ---- Load QA model ----
16
+ self.tokenizer = AutoTokenizer.from_pretrained("outputs/qa_model/final")
17
+ self.qa_model = AutoModelForQuestionAnswering.from_pretrained(
18
+ "outputs/qa_model/final"
19
+ ).to(self.device)
20
+ self.qa_model.eval()
21
+
22
+ # ---- Load retriever ----
23
+ self.embedder = SentenceTransformer("BAAI/bge-base-en-v1.5", device=self.device)
24
+ self.index = faiss.read_index("data/processed/faiss/faiss_index.bin")
25
+
26
+ with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
27
+ self.para_ids = json.load(f)
28
+
29
+ self.db = sqlite3.connect("data/processed/indexed/paragraphs.db")
30
+ self.cursor = self.db.cursor()
31
+
32
+ print("✓ Enhanced QA inference system ready")
33
+
34
+ # ------------------------------------------------------------------
35
+ # TEXT NORMALIZATION (critical for PDF artifacts)
36
+ # ------------------------------------------------------------------
37
+ def _normalize(self, text: str) -> str:
38
+ text = text.lower()
39
+ text = re.sub(r"\s+", " ", text)
40
+ return text.strip()
41
+
42
+ # ------------------------------------------------------------------
43
+ # REFUTED CLAUSE DETECTION (Article 21 FIX)
44
+ # ------------------------------------------------------------------
45
+ def _is_refuted_clause(self, answer_text, paragraph_text):
46
+ para = self._normalize(paragraph_text)
47
+ ans = self._normalize(answer_text)
48
+
49
+ # Patterns like:
50
+ # "it is not correct to say, ..., that X"
51
+ # "it cannot be said, ..., that X"
52
+ refutation_regexes = [
53
+ r"not correct to say.*?that\s+(.*?)(?:\.|,)",
54
+ r"cannot be said.*?that\s+(.*?)(?:\.|,)",
55
+ ]
56
+
57
+ for pattern in refutation_regexes:
58
+ matches = re.findall(pattern, para)
59
+ for refuted_prop in matches:
60
+ # If answer is part of the refuted proposition → block
61
+ if ans in refuted_prop:
62
+ return True
63
+
64
+ return False
65
+
66
+
67
+ # ------------------------------------------------------------------
68
+ # RETRIEVAL
69
+ # ------------------------------------------------------------------
70
+ def retrieve_paragraphs(self, question, top_k=8):
71
+ q_emb = self.embedder.encode(
72
+ [question], normalize_embeddings=True, convert_to_numpy=True
73
+ )
74
+ scores, indices = self.index.search(q_emb, top_k)
75
+
76
+ results = []
77
+ for score, idx in zip(scores[0], indices[0]):
78
+ para_id = self.para_ids[idx]
79
+ self.cursor.execute(
80
+ "SELECT judgment_id, page_no, text FROM paragraphs WHERE id = ?",
81
+ (para_id,),
82
+ )
83
+ row = self.cursor.fetchone()
84
+ if row:
85
+ judgment_id, page_no, text = row
86
+ results.append(
87
+ {
88
+ "judgment_id": judgment_id,
89
+ "page_no": page_no,
90
+ "text": text,
91
+ "retrieval_score": float(score),
92
+ }
93
+ )
94
+ return results
95
+
96
+ # ------------------------------------------------------------------
97
+ # ANSWERING
98
+ # ------------------------------------------------------------------
99
+ def answer_question(self, question, top_k=8, max_answers=2):
100
+ paragraphs = self.retrieve_paragraphs(question, top_k)
101
+ candidates = []
102
+
103
+ for para in paragraphs:
104
+ inputs = self.tokenizer(
105
+ question,
106
+ para["text"],
107
+ return_tensors="pt",
108
+ truncation=True,
109
+ max_length=512,
110
+ ).to(self.device)
111
+
112
+ with torch.no_grad():
113
+ outputs = self.qa_model(**inputs)
114
+
115
+ start_logits = outputs.start_logits[0]
116
+ end_logits = outputs.end_logits[0]
117
+
118
+ token_type_ids = inputs["token_type_ids"][0].tolist()
119
+ question_end = token_type_ids.index(1)
120
+
121
+ top_starts = torch.topk(start_logits, k=5).indices
122
+ top_ends = torch.topk(end_logits, k=5).indices
123
+
124
+ for s in top_starts:
125
+ for e in top_ends:
126
+ if e < s or (e - s) > 80:
127
+ continue
128
+
129
+ # ❌ Block question echo
130
+ if s < question_end:
131
+ continue
132
+
133
+ answer_tokens = inputs["input_ids"][0][s : e + 1]
134
+ answer_text = self.tokenizer.decode(
135
+ answer_tokens, skip_special_tokens=True
136
+ ).strip()
137
+
138
+ words = answer_text.split()
139
+ if len(words) < 8:
140
+ continue
141
+
142
+ # ❌ Block refuted propositions
143
+ if self._is_refuted_clause(answer_text, para["text"]):
144
+ continue
145
+
146
+ score = start_logits[s].item() + end_logits[e].item()
147
+
148
+ # Doctrinal boost
149
+ if any(
150
+ k in answer_text.lower()
151
+ for k in ["the court", "held that", "it is clear that", "the law"]
152
+ ):
153
+ score += 1.5
154
+
155
+ candidates.append(
156
+ {
157
+ "answer": answer_text,
158
+ "confidence": score,
159
+ "judgment_id": para["judgment_id"],
160
+ "page_no": para["page_no"],
161
+ "paragraph": para["text"],
162
+ "retrieval_score": para["retrieval_score"],
163
+ }
164
+ )
165
+
166
+ # ---- Deduplicate answers ----
167
+ seen = set()
168
+ final = []
169
+ for c in sorted(candidates, key=lambda x: x["confidence"], reverse=True):
170
+ key = self._normalize(c["answer"])
171
+ if key not in seen:
172
+ seen.add(key)
173
+ final.append(c)
174
+
175
+ return final[:max_answers]
176
+
177
+
178
+ # ----------------------------------------------------------------------
179
+ # DEMO
180
+ # ----------------------------------------------------------------------
181
+ if __name__ == "__main__":
182
+ qa = LegalQAEngine()
183
+
184
+ questions = [
185
+ "What is the scope of Article 21?",
186
+ "What are the conditions for granting anticipatory bail?",
187
+ "What is the burden of proof in criminal law?",
188
+ ]
189
+
190
+ for q in questions:
191
+ print("\n" + "=" * 90)
192
+ print(f"QUESTION: {q}")
193
+ print("=" * 90)
194
+
195
+ answers = qa.answer_question(q)
196
+
197
+ for i, ans in enumerate(answers, 1):
198
+ print(f"\nANSWER {i}:")
199
+ print(ans["answer"])
200
+ print(
201
+ f"\nSOURCE: {ans['judgment_id']} | Page: {ans['page_no']}"
202
+ )
203
+ print(f"Retrieval score: {ans['retrieval_score']:.3f}")
204
+ print(f"Confidence score: {ans['confidence']:.2f}")
205
+ print("\nPARAGRAPH:")
206
+ print(ans["paragraph"][:700] + "...")
src/qa/model.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
2
+
3
+ MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
4
+
5
+ def load_model():
6
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
7
+ model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
8
+ return tokenizer, model
src/qa/monitor_training.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Monitor training progress"""
2
+ import json
3
+ from pathlib import Path
4
+ import time
5
+
6
+ log_file = Path("outputs/qa_model/trainer_state.json")
7
+
8
+ print("Monitoring training progress...\n")
9
+
10
+ while True:
11
+ if log_file.exists():
12
+ with open(log_file) as f:
13
+ state = json.load(f)
14
+
15
+ if 'log_history' in state:
16
+ latest = state['log_history'][-1]
17
+ print(f"Step {latest.get('step', 0):5d} | "
18
+ f"Loss: {latest.get('loss', 0):.4f} | "
19
+ f"Eval Loss: {latest.get('eval_loss', 'N/A')}")
20
+
21
+ time.sleep(10)
src/qa/train.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TrainingArguments, Trainer
2
+ from model import load_model
3
+ from dataset import load_and_prepare_dataset
4
+
5
+ def main():
6
+ tokenizer, model = load_model()
7
+ dataset = load_and_prepare_dataset(tokenizer)
8
+
9
+ training_args = TrainingArguments(
10
+ output_dir="outputs/qa_model",
11
+ eval_strategy="steps", # ✅ correct API
12
+ eval_steps=1000,
13
+ learning_rate=3e-5,
14
+ per_device_train_batch_size=8,
15
+ per_device_eval_batch_size=8,
16
+ num_train_epochs=2,
17
+ weight_decay=0.01,
18
+ fp16=True,
19
+ logging_steps=500,
20
+ save_steps=2000,
21
+ save_total_limit=2,
22
+ load_best_model_at_end=True,
23
+ metric_for_best_model="eval_loss",
24
+ report_to="none",
25
+ )
26
+
27
+ trainer = Trainer(
28
+ model=model,
29
+ args=training_args,
30
+ train_dataset=dataset["train"],
31
+ eval_dataset=dataset["validation"],
32
+ tokenizer=tokenizer,
33
+ )
34
+
35
+ trainer.train()
36
+
37
+ # Save final model + tokenizer
38
+ trainer.save_model("outputs/qa_model/final")
39
+ tokenizer.save_pretrained("outputs/qa_model/final")
40
+
41
+ if __name__ == "__main__":
42
+ main()
src/rag/__pycache__/query_engine.cpython-310.pyc ADDED
Binary file (8.49 kB). View file
 
src/rag/query_engine.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RAG Query Engine with LLM"""
2
+
3
+ # pyrefly: ignore [missing-import]
4
+ import faiss
5
+ import json
6
+ import sqlite3
7
+ import re
8
+ # pyrefly: ignore [missing-import]
9
+ from sentence_transformers import SentenceTransformer, CrossEncoder
10
+ import os
11
+ from groq import Groq
12
+ from dotenv import load_dotenv
13
+
14
+ from src.summarization.ranker import ImportanceRanker
15
+ from src.summarization.utils import split_sentences
16
+
17
+ load_dotenv()
18
+
19
+ class QueryEngine:
20
+
21
+ def __init__(self):
22
+ print("Loading RAG components...")
23
+
24
+ # FAISS + SQLite + Embeddings
25
+ self.index = faiss.read_index("data/processed/faiss/faiss_index.bin")
26
+
27
+ with open("data/processed/embeddings/paragraph_ids.json") as f:
28
+ self.para_ids = json.load(f)
29
+
30
+ self.model = SentenceTransformer("BAAI/bge-base-en-v1.5")
31
+ self.reranker = CrossEncoder("BAAI/bge-reranker-base")
32
+ self.importance_ranker = ImportanceRanker("outputs/summarization/final")
33
+
34
+ def _get_db(self):
35
+ return sqlite3.connect("data/processed/indexed/paragraphs.db")
36
+
37
+ # LLM Setup
38
+ api_key = os.getenv('GROQ_API_KEY')
39
+ if not api_key:
40
+ raise ValueError("GROQ_API_KEY not set")
41
+
42
+ self.llm = Groq(api_key=api_key)
43
+ self.llm_model = 'llama-3.1-8b-instant'
44
+
45
+ print(f"✓ Ready with {self.index.ntotal:,} vectors")
46
+ print(f"✓ LLM: Groq (Llama 3.1 8B)")
47
+
48
+ def search(self, query: str, top_k: int = 5):
49
+ """Hybrid Search: FAISS (Dense) + SQLite FTS5 (BM25) with RRF"""
50
+
51
+ # --- 1. Dense Search (FAISS) ---
52
+ query_vec = self.model.encode([query], normalize_embeddings=True)
53
+ dense_scores, dense_indices = self.index.search(query_vec, k=top_k * 2) # Fetch extra for fusion
54
+
55
+ dense_results = []
56
+ for rank, (score, idx) in enumerate(zip(dense_scores[0], dense_indices[0])):
57
+ para_id = self.para_ids[idx]
58
+ dense_results.append({'id': para_id, 'score': float(score), 'rank': rank + 1})
59
+
60
+ # --- 2. Keyword Search (SQLite FTS5 BM25) ---
61
+ db = self._get_db()
62
+ cursor = db.cursor()
63
+
64
+ # FTS5 requires a specific syntax. A raw string means AND.
65
+ # We want BM25 behavior (OR), so we clean punctuation and join words with OR.
66
+ import re
67
+ clean_query = re.sub(r'[^\w\s]', '', query)
68
+ fts_query = " OR ".join(clean_query.split())
69
+
70
+ try:
71
+ cursor.execute(f"""
72
+ SELECT id, bm25(paragraphs_fts) as bm25_score
73
+ FROM paragraphs_fts
74
+ WHERE paragraphs_fts MATCH ?
75
+ ORDER BY bm25_score LIMIT ?
76
+ """, (fts_query, top_k * 2))
77
+ fts_rows = cursor.fetchall()
78
+
79
+ keyword_results = []
80
+ for rank, row in enumerate(fts_rows):
81
+ keyword_results.append({'id': row[0], 'score': float(row[1]), 'rank': rank + 1})
82
+ except sqlite3.OperationalError:
83
+ # Fallback if query syntax is too complex for basic MATCH or FTS table missing
84
+ keyword_results = []
85
+
86
+ # --- 3. Reciprocal Rank Fusion (RRF) ---
87
+ # RRF Score = 1 / (k + rank) where k is usually 60
88
+ k = 60
89
+ rrf_scores = {}
90
+
91
+ # Add dense scores
92
+ for res in dense_results:
93
+ pid = res['id']
94
+ rrf_scores[pid] = rrf_scores.get(pid, 0.0) + (1.0 / (k + res['rank']))
95
+
96
+ # Add keyword scores
97
+ for res in keyword_results:
98
+ pid = res['id']
99
+ rrf_scores[pid] = rrf_scores.get(pid, 0.0) + (1.0 / (k + res['rank']))
100
+
101
+ # Sort by RRF score descending
102
+ # Fetch a larger pool of candidates for reranking
103
+ candidate_pool_size = top_k * 3
104
+ sorted_rrf = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:candidate_pool_size]
105
+
106
+ # --- 4. Fetch Details & Rerank (Cross-Encoder) ---
107
+ candidates = []
108
+ for pid, rrf_score in sorted_rrf:
109
+ cursor.execute(
110
+ "SELECT judgment_id, text, page_no FROM paragraphs WHERE id = ?",
111
+ (pid,)
112
+ )
113
+ row = cursor.fetchone()
114
+
115
+ if row:
116
+ candidates.append({
117
+ 'rrf_score': rrf_score,
118
+ 'judgment_id': row[0],
119
+ 'text': row[1],
120
+ 'page_no': row[2],
121
+ 'id': pid
122
+ })
123
+
124
+ if not candidates:
125
+ db.close()
126
+ return []
127
+
128
+ db.close()
129
+
130
+ # --- 5. Final Rerank (Cross-Encoder) ---
131
+
132
+ # Prepare inputs for cross-encoder: list of [query, document_text]
133
+ cross_inp = [[query, doc['text']] for doc in candidates]
134
+ rerank_scores = self.reranker.predict(cross_inp)
135
+
136
+ # Attach scores and sort
137
+ for i, score in enumerate(rerank_scores):
138
+ candidates[i]['score'] = float(score) # Use cross-encoder score as final score
139
+
140
+ candidates = sorted(candidates, key=lambda x: x['score'], reverse=True)
141
+
142
+ return candidates[:top_k]
143
+
144
+ def generate_answer(self, question: str, context: str, sources: list = [], chat_history: list = None):
145
+ """Generate answer using Groq LLM with strict Legal Guardrails.
146
+ Sources list is injected into the prompt so the LLM can ONLY cite
147
+ what was actually retrieved — no hallucinated references.
148
+ """
149
+ chat_history = chat_history or []
150
+ chat_history = chat_history[-6:] # Cap to last 3 turns
151
+ # Build a numbered source registry for the LLM
152
+ source_registry = ""
153
+ for i, s in enumerate(sources, 1):
154
+ source_registry += f"[{i}] {s.get('judgment_id', 'Unknown')}\n"
155
+
156
+ prompt = f"""You are a strict, brilliant legal research assistant specializing in Indian Supreme Court judgments.
157
+
158
+ GUARDRAIL: You MUST ONLY answer questions related to law, legal processes, or the provided context.
159
+ If the question is entirely unrelated to law (e.g., "how to bake a cake"), reply EXACTLY with:
160
+ "I am a legal AI assistant. I can only answer questions related to law."
161
+
162
+ CITATION RULES — THIS IS CRITICAL:
163
+ - You may ONLY cite sources from the APPROVED SOURCE LIST below.
164
+ - Do NOT cite any case from your training memory that is not in the APPROVED SOURCE LIST.
165
+ - If you cite a case not in this list, you are hallucinating and failing your task.
166
+ - Use [1], [2], [3] etc. to refer to sources from the list below.
167
+
168
+ APPROVED SOURCE LIST (cite ONLY these):
169
+ {source_registry}
170
+ CONTEXT (retrieved paragraphs):
171
+ {context}
172
+
173
+ QUESTION: {question}
174
+
175
+ INSTRUCTIONS:
176
+ - Provide a detailed, comprehensive legal answer in a professional conversational tone.
177
+ - Explain concepts clearly so a lawyer finds it extremely useful.
178
+ - Cite ONLY from the APPROVED SOURCE LIST above using [1], [2], [3] format.
179
+ - Use proper legal terminology.
180
+ - Do NOT invent case names, citations, or dates.
181
+ - TEMPORAL AWARENESS: Look at the years in the judgment titles (e.g. 2023_CaseName). Newer judgments (e.g. 2023) supersede older judgments (e.g. 2010). If the retrieved context contains conflicting rulings, you MUST prioritize the newer judgment and explicitly warn the user that the older precedent may have been superseded.
182
+
183
+ ANSWER:"""
184
+
185
+ messages = chat_history.copy()
186
+ messages.append({"role": "user", "content": prompt})
187
+
188
+ response = self.llm.chat.completions.create(
189
+ model=self.llm_model,
190
+ messages=messages,
191
+ temperature=0.2,
192
+ max_tokens=1024
193
+ )
194
+
195
+ return response.choices[0].message.content
196
+
197
+ def query(self, question: str, top_k: int = 5, chat_history: list = None):
198
+ """Main query method"""
199
+ print(f"\n{'='*70}")
200
+ print(f"QUERY: {question}")
201
+ print('='*70)
202
+
203
+ # Search
204
+ print("\nSearching FAISS index...")
205
+ results = self.search(question, top_k)
206
+
207
+ print(f"Found {len(results)} relevant paragraphs")
208
+
209
+ # Format context
210
+ context_parts = []
211
+ for i, r in enumerate(results, 1):
212
+ context_parts.append(
213
+ f"[{i}] {r['judgment_id']}\n{r['text']}"
214
+ )
215
+ context = "\n\n".join(context_parts)
216
+
217
+ # Generate answer — pass sources so LLM can only cite what was retrieved
218
+ print("Generating answer with LLM...")
219
+ answer = self.generate_answer(question, context, sources=results, chat_history=chat_history)
220
+
221
+ return {
222
+ 'question': question,
223
+ 'answer': answer,
224
+ 'sources': results
225
+ }
226
+
227
+ def query_with_document(self, question: str, filepath: str, chat_history: list = None):
228
+ """Queries a specific document. Falls back to global RAG if answer not found."""
229
+ chat_history = chat_history or []
230
+ try:
231
+ with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
232
+ doc_text = f.read()
233
+ except Exception as e:
234
+ return {"answer": f"Error reading document: {e}", "sources": []}
235
+
236
+ # JUDGMENTS are usually under 30k chars. Let's take as much as possible.
237
+ if len(doc_text) > 30000:
238
+ print("Document exceeds 30k chars. Applying Semantic Truncation...")
239
+ try:
240
+ sentences = [s for s in split_sentences(doc_text) if len(s.strip()) > 20]
241
+ scores = self.importance_ranker.score(sentences)
242
+ indexed = list(enumerate(zip(sentences, scores)))
243
+ sorted_by_score = sorted(indexed, key=lambda x: x[1][1], reverse=True)
244
+
245
+ selected_indices = []
246
+ current_chars = 0
247
+ for idx, (sentence, score) in sorted_by_score:
248
+ if current_chars + len(sentence) > 30000:
249
+ continue
250
+ selected_indices.append(idx)
251
+ current_chars += len(sentence)
252
+ if current_chars > 29000:
253
+ break
254
+
255
+ # Restore original chronological order
256
+ top_in_order = sorted([indexed[i] for i in selected_indices], key=lambda x: x[0])
257
+ doc_text = " ".join(s for _, (s, _) in top_in_order) + "\n\n... [TRUNCATED SEMANTICALLY FOR LLM] ..."
258
+ except Exception as e:
259
+ print(f"Semantic Truncation failed: {e}. Falling back to naive truncation.")
260
+ doc_text = doc_text[:30000] + "\n\n... [TRUNCATED DUE TO SIZE] ..."
261
+
262
+ print(f"--- DOCUMENT QA START ---")
263
+ print(f"File: {os.path.basename(filepath)}")
264
+ print(f"Size: {len(doc_text)} chars")
265
+ print(f"Question: {question}")
266
+
267
+ prompt = f"""You are a strict Legal Document Auditor.
268
+ Your ONLY source of information is the text provided below.
269
+
270
+ STRICT RULES:
271
+ 1. Answer the QUESTION using ONLY the DOCUMENT text.
272
+ 2. If the answer is not in the text, say "I cannot find this in the uploaded document."
273
+ 3. DO NOT cite external cases (like Venkata Reddy or V.C. Shukla) unless they are explicitly mentioned in the text below.
274
+ 4. If you use your own internal knowledge instead of the document, you are failing your task.
275
+
276
+ DOCUMENT TEXT:
277
+ {doc_text}
278
+
279
+ QUESTION: {question}
280
+
281
+ DETAILED ANSWER (citing specific paragraphs if possible):"""
282
+
283
+ messages = chat_history.copy()
284
+ messages.append({"role": "user", "content": prompt})
285
+
286
+ response = self.llm.chat.completions.create(
287
+ model=self.llm_model,
288
+ messages=messages,
289
+ temperature=0.1,
290
+ max_tokens=1024
291
+ )
292
+
293
+ answer = response.choices[0].message.content.strip()
294
+
295
+ return {
296
+ 'question': question,
297
+ 'answer': answer,
298
+ 'sources': [{'judgment_id': os.path.basename(filepath), 'score': 1.0}]
299
+ }
300
+
301
+ def close(self):
302
+ self.db.close()
303
+
304
+ # Test
305
+ if __name__ == "__main__":
306
+ engine = QueryEngine()
307
+
308
+ # Test queries
309
+ queries = [
310
+ "What are the conditions for granting anticipatory bail?",
311
+ "Explain the doctrine of legitimate expectation",
312
+ "What is the burden of proof in criminal cases?"
313
+ ]
314
+
315
+ for query in queries:
316
+ response = engine.query(query, top_k=3)
317
+
318
+ print(f"\nANSWER:\n{response['answer']}\n")
319
+
320
+ print("SOURCES:")
321
+ for i, src in enumerate(response['sources'], 1):
322
+ print(f" [{i}] {src['judgment_id']} (score: {src['score']:.3f})")
323
+
324
+ print("\n" + "="*70 + "\n")
325
+
326
+ engine.close()
src/rag/test_retriever.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test RAG retrieval system with SQLite
3
+ """
4
+
5
+ import faiss
6
+ import json
7
+ import sqlite3
8
+ from sentence_transformers import SentenceTransformer
9
+
10
+ def test_retrieval():
11
+ print("="*70)
12
+ print("Testing RAG Retrieval")
13
+ print("="*70)
14
+
15
+ # Load FAISS index
16
+ index = faiss.read_index("data/processed/faiss/faiss_index.bin")
17
+
18
+ # Load paragraph IDs
19
+ with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
20
+ para_ids = json.load(f)
21
+
22
+ # Connect to SQLite
23
+ conn = sqlite3.connect("data/processed/indexed/paragraphs.db")
24
+ cursor = conn.cursor()
25
+
26
+ # Load embedding model (MUST match indexing)
27
+ model = SentenceTransformer("BAAI/bge-base-en-v1.5")
28
+
29
+ print(f"✓ Ready with {index.ntotal:,} vectors\n")
30
+
31
+ queries = [
32
+ "right to privacy under Article 21",
33
+ "anticipatory bail conditions",
34
+ "burden of proof in criminal cases",
35
+ "doctrine of legitimate expectation"
36
+ ]
37
+
38
+ for query in queries:
39
+ print(f"\n{'='*70}")
40
+ print(f"QUERY: {query}")
41
+ print('='*70)
42
+
43
+ query_vec = model.encode(
44
+ [query],
45
+ normalize_embeddings=True,
46
+ convert_to_numpy=True
47
+ )
48
+
49
+ scores, indices = index.search(query_vec, k=3)
50
+
51
+ for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), 1):
52
+ if idx < 0 or idx >= len(para_ids):
53
+ continue
54
+
55
+ para_id = para_ids[idx]
56
+
57
+ cursor.execute(
58
+ "SELECT judgment_id, page_no, text FROM paragraphs WHERE id = ?",
59
+ (para_id,)
60
+ )
61
+ row = cursor.fetchone()
62
+
63
+ if not row:
64
+ continue
65
+
66
+ judgment_id, page_no, text = row
67
+
68
+ print(f"\n[{rank}] Score: {score:.4f}")
69
+ print(f"Source: {judgment_id} | Page: {page_no}")
70
+ print(f"Text: {text[:300]}...")
71
+
72
+ conn.close()
73
+
74
+ if __name__ == "__main__":
75
+ test_retrieval()
src/segmentation/__pycache__/judgement_segmenter.cpython-310.pyc ADDED
Binary file (4.54 kB). View file
 
src/segmentation/annotate_paragraphs.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # Annotate paragraphs with legal sections using JudgmentSegmenter
3
+ # Creates paragraph_index_with_sections.jsonl
4
+ # """
5
+
6
+ # import json
7
+ # from pathlib import Path
8
+ # from collections import defaultdict
9
+ # from tqdm import tqdm
10
+
11
+ # from judgement_segmenter import JudgmentSegmenter
12
+
13
+
14
+ # INPUT_INDEX = Path("data/processed/indexed/paragraph_index.jsonl")
15
+ # OUTPUT_INDEX = Path("data/processed/indexed/paragraph_index_with_sections.jsonl")
16
+
17
+
18
+ # def annotate_paragraphs():
19
+ # print("=" * 70)
20
+ # print("NyayLens – Annotating Paragraphs with Sections")
21
+ # print("=" * 70)
22
+
23
+ # # Load paragraphs grouped by judgment
24
+ # judgments = defaultdict(list)
25
+
26
+ # with open(INPUT_INDEX, "r", encoding="utf-8") as f:
27
+ # for line in f:
28
+ # p = json.loads(line)
29
+ # judgments[p["judgment_id"]].append(p)
30
+
31
+ # print(f"✓ Loaded {len(judgments):,} judgments")
32
+
33
+ # segmenter = JudgmentSegmenter()
34
+
35
+ # with open(OUTPUT_INDEX, "w", encoding="utf-8") as writer:
36
+ # for judgment_id, paras in tqdm(judgments.items(), desc="Annotating"):
37
+ # # Preserve original order
38
+ # paras = sorted(paras, key=lambda x: (x["page_no"], x["id"]))
39
+
40
+ # texts = [p["text"] for p in paras]
41
+
42
+ # sections = segmenter.segment(texts)
43
+
44
+ # # Default all to unknown
45
+ # section_labels = [
46
+ # ("unknown", 0.0) for _ in paras
47
+ # ]
48
+
49
+ # # Apply section labels
50
+ # for sec in sections:
51
+ # for i in range(sec.start_para_idx, sec.end_para_idx + 1):
52
+ # section_labels[i] = (sec.type, sec.confidence)
53
+
54
+ # # Write annotated paragraphs
55
+ # for p, (sec_type, sec_conf) in zip(paras, section_labels):
56
+ # p_out = dict(p)
57
+ # p_out["section"] = sec_type
58
+ # p_out["section_conf"] = sec_conf
59
+
60
+ # writer.write(json.dumps(p_out, ensure_ascii=False) + "\n")
61
+
62
+ # print("\n✓ Annotation complete")
63
+ # print(f"✓ Output written to: {OUTPUT_INDEX}")
64
+
65
+
66
+ # if __name__ == "__main__":
67
+ # annotate_paragraphs()
68
+ """
69
+ Annotate paragraphs with legal sections using JudgmentSegmenter
70
+ PRESERVES ORIGINAL IDs AND ORDER
71
+ """
72
+
73
+ import json
74
+ from pathlib import Path
75
+ from collections import defaultdict
76
+ from tqdm import tqdm
77
+
78
+ from judgement_segmenter import JudgmentSegmenter
79
+
80
+
81
+ INPUT_INDEX = Path("data/processed/indexed/paragraph_index.jsonl")
82
+ OUTPUT_INDEX = Path("data/processed/indexed/paragraph_index_with_sections.jsonl")
83
+
84
+
85
+ def annotate_paragraphs():
86
+ print("=" * 70)
87
+ print("NyayLens – Annotating Paragraphs with Sections")
88
+ print("=" * 70)
89
+
90
+ # Load paragraphs IN ORIGINAL ORDER
91
+ all_paragraphs = []
92
+ with open(INPUT_INDEX, "r", encoding="utf-8") as f:
93
+ for line in f:
94
+ all_paragraphs.append(json.loads(line))
95
+
96
+ print(f"✓ Loaded {len(all_paragraphs):,} paragraphs")
97
+
98
+ # Group by judgment (preserve index in group)
99
+ judgments = defaultdict(list)
100
+ for idx, p in enumerate(all_paragraphs):
101
+ judgments[p["judgment_id"]].append((idx, p)) # ← Store original index
102
+
103
+ segmenter = JudgmentSegmenter()
104
+
105
+ # Create array to store annotations (preserves original order)
106
+ annotations = [None] * len(all_paragraphs)
107
+
108
+ for judgment_id, indexed_paras in tqdm(judgments.items(), desc="Annotating"):
109
+ # Extract just the paragraphs
110
+ indices = [ip[0] for ip in indexed_paras]
111
+ paras = [ip[1] for ip in indexed_paras]
112
+
113
+ # Get texts
114
+ texts = [p["text"] for p in paras]
115
+
116
+ # Segment
117
+ sections = segmenter.segment(texts)
118
+
119
+ # Default labels
120
+ section_labels = [("unknown", 0.0) for _ in paras]
121
+
122
+ # Apply section labels
123
+ for sec in sections:
124
+ for i in range(sec.start_para_idx, sec.end_para_idx + 1):
125
+ if i < len(section_labels):
126
+ section_labels[i] = (sec.type, sec.confidence)
127
+
128
+ # Store annotations in ORIGINAL positions
129
+ for orig_idx, p, (sec_type, sec_conf) in zip(indices, paras, section_labels):
130
+ p_out = dict(p) # Copy original
131
+ p_out["section"] = sec_type
132
+ p_out["section_conf"] = sec_conf
133
+ annotations[orig_idx] = p_out
134
+
135
+ # Write in ORIGINAL order
136
+ print("\nWriting annotated paragraphs...")
137
+ with open(OUTPUT_INDEX, "w", encoding="utf-8") as writer:
138
+ for p_out in annotations:
139
+ writer.write(json.dumps(p_out, ensure_ascii=False) + "\n")
140
+
141
+ print(f"✓ Output written to: {OUTPUT_INDEX}")
142
+ print("=" * 70)
143
+
144
+
145
+ if __name__ == "__main__":
146
+ annotate_paragraphs()
src/segmentation/check.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test section-aware retrieval"""
2
+ import faiss
3
+ import json
4
+ import sqlite3
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+
8
+ # Load
9
+ index = faiss.read_index("data/processed/faiss/faiss_index.bin")
10
+ with open("data/processed/embeddings/paragraph_ids.json") as f:
11
+ para_ids = json.load(f)
12
+
13
+ db = sqlite3.connect("data/processed/indexed/paragraphs.db")
14
+ cursor = db.cursor()
15
+
16
+ model = SentenceTransformer("BAAI/bge-base-en-v1.5")
17
+
18
+ # Test query
19
+ query = "What were the facts of the case?"
20
+ query_vec = model.encode([query], normalize_embeddings=True)
21
+
22
+ # Search
23
+ scores, indices = index.search(query_vec, k=10)
24
+
25
+ print(f"Query: {query}\n")
26
+ print("Top results with sections:")
27
+
28
+ for i, (score, idx) in enumerate(zip(scores[0], indices[0]), 1):
29
+ para_id = para_ids[idx]
30
+
31
+ cursor.execute(
32
+ "SELECT judgment_id, section, section_conf, text FROM paragraphs WHERE id = ?",
33
+ (para_id,)
34
+ )
35
+ row = cursor.fetchone()
36
+
37
+ print(f"\n[{i}] Score: {score:.3f} | Section: {row[1]} (conf={row[2]:.2f})")
38
+ print(f" Case: {row[0]}")
39
+ print(f" Text: {row[3][:100]}...")
40
+
41
+ db.close()
src/segmentation/judgement_segmenter.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced Judgment Segmenter (FIXED)
3
+ Segments judgments into: Facts, Issues, Arguments, Analysis, Decision
4
+ """
5
+
6
+ import re
7
+ import os
8
+ import logging
9
+ from typing import List, Dict, Tuple
10
+ from dataclasses import dataclass
11
+
12
+ try:
13
+ from transformers import pipeline
14
+ TRANSFORMERS_AVAILABLE = True
15
+ except ImportError:
16
+ TRANSFORMERS_AVAILABLE = False
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class Section:
23
+ type: str # facts/issues/arguments/analysis/decision/unknown
24
+ text: str
25
+ start_para_idx: int
26
+ end_para_idx: int
27
+ confidence: float
28
+
29
+
30
+ class JudgmentSegmenter:
31
+
32
+ MARKERS = {
33
+ 'facts': [
34
+ r'\bbrief\s+facts?\b',
35
+ r'\bfactual\s+(matrix|background)\b',
36
+ r'\bcircumstances\s+of\s+the\s+case\b',
37
+ r'\bbackground\b',
38
+ ],
39
+ 'issues': [
40
+ r'\bissues?\s+(for|of)\s+(consideration|determination)\b',
41
+ r'\bsubstantial\s+questions?\b',
42
+ r'\bpoints?\s+for\s+consideration\b',
43
+ r'\bquestions?\s+framed\b',
44
+ ],
45
+ 'arguments': [
46
+ r'\blearned\s+counsel\b',
47
+ r'\bsubmissions?\b',
48
+ r'\b(argued|submitted|contended)\b',
49
+ r'\bon\s+behalf\s+of\b',
50
+ ],
51
+ 'analysis': [
52
+ r'\bwe\s+have\s+(considered|examined|analysed)\b',
53
+ r'\bthe\s+court\s+(finds|observes|notes|holds)\b',
54
+ r'\bin\s+our\s+(view|opinion)\b',
55
+ r'\bit\s+is\s+clear\s+that\b',
56
+ ],
57
+ 'decision': [
58
+ r'\b(appeal|petition|writ)\s+is\s+(allowed|dismissed)\b',
59
+ r'\baccordingly\b',
60
+ r'\bwe\s+direct\b',
61
+ r'\bheld\s*:\b',
62
+ r'\border\b',
63
+ ]
64
+ }
65
+
66
+ def __init__(self, model_path: str = "models/segmentation_model"):
67
+ """Initialize segmenter, preferring ML model if available, else Regex fallback"""
68
+ self.use_ml = False
69
+ self.classifier = None
70
+
71
+ if TRANSFORMERS_AVAILABLE and os.path.exists(model_path):
72
+ try:
73
+ logger.info(f"Loading ML Segmentation model from {model_path}...")
74
+ self.classifier = pipeline("text-classification", model=model_path, device=-1)
75
+ self.use_ml = True
76
+ logger.info("✓ ML Segmenter loaded successfully.")
77
+ except Exception as e:
78
+ logger.warning(f"Failed to load ML model, falling back to Regex: {e}")
79
+ else:
80
+ logger.info("ML model not found or transformers not installed. Using Regex fallback.")
81
+
82
+ def detect_section(self, para: str, position_ratio: float) -> Tuple[str, float]:
83
+ """
84
+ Detect section type for a paragraph
85
+ Returns: (section_type, confidence)
86
+ """
87
+ para_lower = para.lower()
88
+ best_type = 'unknown'
89
+ best_conf = 0.0
90
+
91
+ for sec_type, patterns in self.MARKERS.items():
92
+ for pattern in patterns:
93
+ if re.search(pattern, para_lower):
94
+ conf = 0.6
95
+
96
+ # Position-based bias
97
+ if sec_type == 'facts' and position_ratio < 0.30:
98
+ conf += 0.2
99
+ elif sec_type == 'decision' and position_ratio > 0.70:
100
+ conf += 0.3
101
+
102
+ # Strong anchor near paragraph start
103
+ if re.search(pattern, para_lower[:120]):
104
+ conf += 0.2
105
+
106
+ conf = min(conf, 1.0)
107
+
108
+ if conf > best_conf:
109
+ best_type = sec_type
110
+ best_conf = conf
111
+
112
+ return best_type, best_conf
113
+
114
+ def detect_section_ml(self, para: str) -> Tuple[str, float]:
115
+ """Detect using HuggingFace classifier"""
116
+ if not para.strip() or not self.classifier:
117
+ return "unknown", 0.0
118
+
119
+ # Truncate to max length to avoid tokenization errors
120
+ truncated = para[:512]
121
+ result = self.classifier(truncated)[0]
122
+
123
+ # Assume labels are like LABEL_FACTS, LABEL_ISSUES or directly facts, issues
124
+ label = result['label'].lower().replace('label_', '')
125
+ score = result['score']
126
+
127
+ # Enforce confidence threshold
128
+ if score < 0.5:
129
+ return "unknown", score
130
+
131
+ return label, score
132
+
133
+ def segment(self, paragraph_texts: List[str]) -> List[Section]:
134
+ """
135
+ Segment judgment based on paragraph list (INDEX-ALIGNED)
136
+ """
137
+ if not paragraph_texts:
138
+ return []
139
+
140
+ sections: List[Section] = []
141
+
142
+ current_type = 'unknown'
143
+ current_paras = []
144
+ current_conf = 0.0
145
+ start_idx = 0
146
+
147
+ total = len(paragraph_texts)
148
+
149
+ for i, para in enumerate(paragraph_texts):
150
+ position_ratio = i / max(total, 1)
151
+
152
+ if self.use_ml:
153
+ sec_type, conf = self.detect_section_ml(para)
154
+ else:
155
+ sec_type, conf = self.detect_section(para, position_ratio)
156
+
157
+ # Fallback: early unknown paragraphs are likely facts
158
+ if sec_type == 'unknown' and position_ratio < 0.30 and i > 0:
159
+ sec_type = 'facts'
160
+ conf = 0.4
161
+
162
+ # Section boundary
163
+ if conf > 0.4 and sec_type != current_type:
164
+ if current_paras:
165
+ sections.append(
166
+ Section(
167
+ type=current_type,
168
+ text="\n\n".join(current_paras),
169
+ start_para_idx=start_idx,
170
+ end_para_idx=i - 1,
171
+ confidence=round(current_conf, 2)
172
+ )
173
+ )
174
+
175
+ current_type = sec_type
176
+ current_paras = [para]
177
+ current_conf = conf
178
+ start_idx = i
179
+ else:
180
+ current_paras.append(para)
181
+ current_conf = max(current_conf, conf)
182
+
183
+ # Final section
184
+ if current_paras:
185
+ sections.append(
186
+ Section(
187
+ type=current_type,
188
+ text="\n\n".join(current_paras),
189
+ start_para_idx=start_idx,
190
+ end_para_idx=total - 1,
191
+ confidence=round(current_conf, 2)
192
+ )
193
+ )
194
+
195
+ return sections
src/summarization/__pycache__/composer.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
src/summarization/__pycache__/inference.cpython-310.pyc ADDED
Binary file (4.67 kB). View file
 
src/summarization/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.14 kB). View file
 
src/summarization/__pycache__/ranker.cpython-310.pyc ADDED
Binary file (1.45 kB). View file
 
src/summarization/__pycache__/utils.cpython-310.pyc ADDED
Binary file (314 Bytes). View file
 
src/summarization/composer.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/summarization/composer.py
2
+ def compose(sentences, scores, top_k=5):
3
+ """
4
+ Select top-k sentences by Legal-BERT score, then RESTORE their original
5
+ document order before returning. This gives PEGASUS a coherent narrative
6
+ instead of a randomly ordered bag of sentences.
7
+ """
8
+ # Tag each sentence with its original index
9
+ indexed = list(enumerate(zip(sentences, scores)))
10
+
11
+ # Pick top-k by score
12
+ top = sorted(indexed, key=lambda x: x[1][1], reverse=True)[:top_k]
13
+
14
+ # Re-sort by original document position for narrative coherence
15
+ top_in_order = sorted(top, key=lambda x: x[0])
16
+
17
+ return [s for _, (s, _) in top_in_order]
src/summarization/dataset.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/summarization/dataset.py
2
+ import re
3
+ from pathlib import Path
4
+ from datasets import Dataset
5
+ from transformers import AutoTokenizer
6
+ from tqdm import tqdm
7
+
8
+ IMPORTANT_PATTERNS = [
9
+ r"\bheld\b",
10
+ r"\bwe conclude\b",
11
+ r"\btherefore\b",
12
+ r"\bappeal is (allowed|dismissed)\b",
13
+ r"\bsubstantial question\b",
14
+ r"\baccordingly\b",
15
+ ]
16
+
17
+ def sentence_split(text):
18
+ return re.split(r'(?<=[.!?])\s+', text)
19
+
20
+ def is_important(sentence):
21
+ s = sentence.lower()
22
+ return any(re.search(p, s) for p in IMPORTANT_PATTERNS)
23
+
24
+ def build_dataset(text_dir, tokenizer_name, limit=None):
25
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
26
+ samples = []
27
+
28
+ files = list(Path(text_dir).glob("*.txt"))
29
+ if limit:
30
+ files = files[:limit]
31
+
32
+ for file in tqdm(files, desc="Processing judgments"):
33
+ judgment_id = file.stem
34
+ text = file.read_text(encoding="utf-8", errors="ignore")
35
+
36
+ sentences = sentence_split(text)
37
+ for sent in sentences:
38
+ sent = sent.strip()
39
+ if len(sent) < 40:
40
+ continue
41
+
42
+ samples.append({
43
+ "text": sent,
44
+ "label": int(is_important(sent)),
45
+ "judgment_id": judgment_id
46
+ })
47
+
48
+ dataset = Dataset.from_list(samples)
49
+
50
+ def tokenize(batch):
51
+ return tokenizer(
52
+ batch["text"],
53
+ truncation=True,
54
+ padding="max_length",
55
+ max_length=256
56
+ )
57
+
58
+ return dataset.map(tokenize, batched=True)
59
+
60
+ if __name__ == "__main__":
61
+ print("Started dataset building...")
62
+ # Using a limit of 1000 for training, can be increased later
63
+ # 1000 judgments will yield ~50k-100k sentences, good for fine-tuning
64
+ ds = build_dataset(
65
+ "data/processed/extracted/texts",
66
+ "nlpaueb/legal-bert-base-uncased",
67
+ limit=1000
68
+ )
69
+ print("Tokenizing dataset... this may take a moment.")
70
+ print(f"Total sentences extracted: {len(ds)}")
71
+
72
+ print("Saving to Disk...")
73
+ ds.save_to_disk("data/processed/summarization_dataset")
74
+ print("✓ Dataset ready at data/processed/summarization_dataset")
src/summarization/inference.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/summarization/inference.py
2
+ import sys
3
+ import os
4
+ import re
5
+ from pathlib import Path
6
+
7
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
8
+
9
+ from src.summarization.ranker import ImportanceRanker
10
+ from src.summarization.utils import split_sentences
11
+ from src.segmentation.judgement_segmenter import JudgmentSegmenter
12
+ from transformers import PegasusTokenizer, PegasusForConditionalGeneration
13
+ import torch
14
+
15
+ # ── Model ──────────────────────────────────────────────────────────────────────
16
+ MODEL_NAME = "nsi319/legal-pegasus"
17
+ print(f"\nLoading Abstractive Model ({MODEL_NAME})...")
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ pegasus_tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
20
+ pegasus_model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
21
+ print(f"✓ Legal-PEGASUS loaded on {device.upper()}")
22
+
23
+
24
+ def _pegasus_generate(text: str, max_length: int = 300, min_length: int = 100) -> str:
25
+ """Run Legal-PEGASUS on a block of text and return the decoded summary."""
26
+ inputs = pegasus_tokenizer(
27
+ [text],
28
+ max_length=1024,
29
+ truncation=True,
30
+ padding=True,
31
+ return_tensors="pt"
32
+ ).to(device)
33
+
34
+ outputs = pegasus_model.generate(
35
+ inputs["input_ids"],
36
+ max_length=max_length,
37
+ min_length=min_length,
38
+ num_beams=4, # Reduced from 8 for 2x speedup on CPU
39
+ length_penalty=1.2,
40
+ no_repeat_ngram_size=3,
41
+ repetition_penalty=1.3,
42
+ early_stopping=True,
43
+ )
44
+ decoded = pegasus_tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ return re.sub(r'\s+', ' ', decoded.replace("<n>", " ")).strip()
46
+
47
+
48
+ def summarize(judgment_file: str) -> dict:
49
+ """
50
+ Speed-Optimized Two-Pass Pipeline:
51
+ 1. Case Overview: Legal-BERT (Extraction) -> Legal-PEGASUS (Abstraction) [1 Pass]
52
+ 2. Detailed Sections: Legal-BERT (Extraction) -> Direct Output [No Abstraction pass to save 5+ minutes]
53
+ """
54
+ text = Path(judgment_file).read_text(encoding="utf-8", errors="ignore")
55
+
56
+ # ── Step 1: Global sentence extraction ─────────────────────────────────────
57
+ all_sentences = [s for s in split_sentences(text) if len(s.strip()) > 40]
58
+ if not all_sentences:
59
+ return {"overview": "Could not extract readable text."}
60
+
61
+ ranker = ImportanceRanker("outputs/summarization/final")
62
+ scores = ranker.score(all_sentences)
63
+
64
+ # Token-Aware Global Overview Extract (Limit to ~950 tokens for Pegasus)
65
+ indexed = list(enumerate(zip(all_sentences, scores)))
66
+ sorted_by_score = sorted(indexed, key=lambda x: x[1][1], reverse=True)
67
+
68
+ selected_indices = []
69
+ current_tokens = 0
70
+ MAX_TOKENS = 950
71
+
72
+ for idx, (sentence, score) in sorted_by_score:
73
+ tokens = len(pegasus_tokenizer.encode(sentence, add_special_tokens=False))
74
+ if current_tokens + tokens > MAX_TOKENS:
75
+ continue
76
+
77
+ selected_indices.append(idx)
78
+ current_tokens += tokens
79
+ if current_tokens >= MAX_TOKENS - 20:
80
+ break
81
+
82
+ # Restore chronological order
83
+ top_in_order = sorted([indexed[i] for i in selected_indices], key=lambda x: x[0])
84
+ global_extract = " ".join(s for _, (s, _) in top_in_order)
85
+
86
+ # ── Pass 1: Abstractive Overview (The only heavy pass) ────────────────────
87
+ print("Generating Case Overview (Abstractive)...")
88
+ overview = _pegasus_generate(global_extract, max_length=250, min_length=80)
89
+
90
+ # ── Pass 2: Extractive Section Breakdown (Instant) ────────────────────────
91
+ segmenter = JudgmentSegmenter()
92
+ paragraphs = [p.strip() for p in text.split("\n\n") if len(p.strip()) > 20]
93
+ sections = segmenter.segment(paragraphs)
94
+
95
+ final_summary = {"overview": overview}
96
+
97
+ print("Generating Section Breakdowns (Extractive - Instant)...")
98
+ for section in sections:
99
+ sec_type = section.type.lower()
100
+ if sec_type == 'unknown': continue
101
+
102
+ sentences = [s for s in split_sentences(section.text) if len(s.strip()) > 40]
103
+ if not sentences: continue
104
+
105
+ sec_scores = ranker.score(sentences)
106
+ # Select top 3 per section for readability
107
+ s_indexed = list(enumerate(zip(sentences, sec_scores)))
108
+ top_k = sorted(s_indexed, key=lambda x: x[1][1], reverse=True)[:3]
109
+ top_k_ordered = sorted(top_k, key=lambda x: x[0])
110
+
111
+ # We use original sentences here to save 5-10 minutes of CPU time
112
+ final_summary[sec_type] = " ".join(s for _, (s, _) in top_k_ordered)
113
+
114
+ return final_summary
115
+
116
+
117
+ if __name__ == "__main__":
118
+ file = list(Path("data/processed/extracted/texts").glob("*.txt"))[0]
119
+ print(f"\nProcessing {file.name}...")
120
+ result = summarize(file)
121
+
122
+ print("\n\nCOMPREHENSIVE LEGAL SUMMARY (Global Legal-BERT + Legal-PEGASUS)\n" + "=" * 80)
123
+ print("\n[CASE OVERVIEW]")
124
+ print(result.get("overview", ""))
125
+ for sec in ['facts', 'issues', 'arguments', 'analysis', 'decision']:
126
+ if sec in result:
127
+ print(f"\n[{sec.upper()}]")
128
+ print(result[sec])
129
+ print("\n" + "=" * 80)
src/summarization/model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/summarization/model.py
2
+ from transformers import AutoModel
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ class SentenceRanker(nn.Module):
7
+ def __init__(self, model_name):
8
+ super().__init__()
9
+ self.encoder = AutoModel.from_pretrained(model_name)
10
+ self.classifier = nn.Linear(self.encoder.config.hidden_size, 1)
11
+
12
+ def forward(self, input_ids, attention_mask, labels=None, **kwargs):
13
+ out = self.encoder(
14
+ input_ids=input_ids,
15
+ attention_mask=attention_mask,
16
+ **kwargs
17
+ )
18
+ cls = out.last_hidden_state[:, 0]
19
+ logits = self.classifier(cls).squeeze(-1)
20
+
21
+ loss = None
22
+ if labels is not None:
23
+ loss = nn.BCEWithLogitsLoss()(logits, labels.float())
24
+
25
+ return {"loss": loss, "logits": logits}
src/summarization/ranker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/summarization/ranker.py
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+ from src.summarization.model import SentenceRanker
5
+
6
+ class ImportanceRanker:
7
+ def __init__(self, model_dir, base_model="nlpaueb/legal-bert-base-uncased"):
8
+ # Load the tokenizer from the base model
9
+ self.tokenizer = AutoTokenizer.from_pretrained(base_model)
10
+
11
+ # Initialize the custom architecture with base model
12
+ self.model = SentenceRanker(base_model)
13
+
14
+ # Load fine-tuned weights
15
+ import os
16
+ from safetensors.torch import load_file
17
+
18
+ weights_path = os.path.join(model_dir, "model.safetensors")
19
+ if os.path.exists(weights_path):
20
+ state_dict = load_file(weights_path)
21
+ self.model.load_state_dict(state_dict)
22
+ else:
23
+ print(f"Warning: Could not find {weights_path}")
24
+
25
+ self.model.eval()
26
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ self.model.to(self.device)
28
+
29
+ def score(self, sentences):
30
+ inputs = self.tokenizer(
31
+ sentences,
32
+ truncation=True,
33
+ padding=True,
34
+ return_tensors="pt"
35
+ ).to(self.device)
36
+
37
+ with torch.no_grad():
38
+ logits = self.model(**inputs)["logits"]
39
+
40
+ return logits.sigmoid().cpu().tolist()
src/summarization/train.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/summarization/train.py
2
+
3
+ from datasets import load_from_disk
4
+ from transformers import Trainer, TrainingArguments
5
+ from model import SentenceRanker
6
+ import torch
7
+
8
+ MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
9
+
10
+ def main():
11
+ # Load dataset
12
+ dataset = load_from_disk("data/processed/summarization_dataset")
13
+ dataset = dataset.train_test_split(test_size=0.1)
14
+
15
+ model = SentenceRanker(MODEL_NAME)
16
+
17
+ training_args = TrainingArguments(
18
+ output_dir="outputs/summarization",
19
+ per_device_train_batch_size=16,
20
+ per_device_eval_batch_size=16,
21
+ num_train_epochs=2,
22
+ learning_rate=2e-5,
23
+ logging_steps=500,
24
+ save_steps=2000,
25
+ save_total_limit=2,
26
+ report_to="none",
27
+ fp16=torch.cuda.is_available()
28
+ )
29
+
30
+ trainer = Trainer(
31
+ model=model,
32
+ args=training_args,
33
+ train_dataset=dataset["train"],
34
+ eval_dataset=dataset["test"]
35
+ )
36
+
37
+ trainer.train()
38
+ trainer.save_model("outputs/summarization/final")
39
+
40
+ if __name__ == "__main__":
41
+ main()
src/summarization/utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # src/summarization/utils.py
2
+ import re
3
+
4
+ def split_sentences(text):
5
+ return re.split(r'(?<=[.!?])\s+', text)