Spaces:
Running
Running
Sai Pranav Reddy commited on
Commit ·
968e24d
0
Parent(s):
Clean lightweight deployment
Browse files- .gitattributes +36 -0
- Dockerfile +38 -0
- README.md +11 -0
- outputs/summarization/final/model.safetensors +3 -0
- outputs/summarization/final/training_args.bin +3 -0
- requirements.txt +27 -0
- src/api/__pycache__/main.cpython-310.pyc +0 -0
- src/api/main.py +223 -0
- src/evaluation/__pycache__/evaluator.cpython-310.pyc +0 -0
- src/extraction/__pycache__/batch_processor.cpython-310.pyc +0 -0
- src/extraction/__pycache__/pdf_extractor.cpython-310.pyc +0 -0
- src/extraction/batch_processor.py +459 -0
- src/extraction/pdf_extractor.py +522 -0
- src/indexing/build_faiss_index.py +85 -0
- src/indexing/create_embeddings.py +232 -0
- src/indexing/create_sqlite_index.py +196 -0
- src/indexing/paragraph_indexer.py +167 -0
- src/pipeline.py +84 -0
- src/qa/__pycache__/dataset.cpython-310.pyc +0 -0
- src/qa/__pycache__/model.cpython-310.pyc +0 -0
- src/qa/dataset.py +69 -0
- src/qa/inference.py +206 -0
- src/qa/model.py +8 -0
- src/qa/monitor_training.py +21 -0
- src/qa/train.py +42 -0
- src/rag/__pycache__/query_engine.cpython-310.pyc +0 -0
- src/rag/query_engine.py +326 -0
- src/rag/test_retriever.py +75 -0
- src/segmentation/__pycache__/judgement_segmenter.cpython-310.pyc +0 -0
- src/segmentation/annotate_paragraphs.py +146 -0
- src/segmentation/check.py +41 -0
- src/segmentation/judgement_segmenter.py +195 -0
- src/summarization/__pycache__/composer.cpython-310.pyc +0 -0
- src/summarization/__pycache__/inference.cpython-310.pyc +0 -0
- src/summarization/__pycache__/model.cpython-310.pyc +0 -0
- src/summarization/__pycache__/ranker.cpython-310.pyc +0 -0
- src/summarization/__pycache__/utils.cpython-310.pyc +0 -0
- src/summarization/composer.py +17 -0
- src/summarization/dataset.py +74 -0
- src/summarization/inference.py +129 -0
- src/summarization/model.py +25 -0
- src/summarization/ranker.py +40 -0
- src/summarization/train.py +41 -0
- src/summarization/utils.py +5 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.db filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use an official Python runtime as a parent image
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
# Set the working directory to /app
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Install system dependencies
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
build-essential \
|
| 10 |
+
git \
|
| 11 |
+
git-lfs \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
# Copy the requirements file into the container
|
| 15 |
+
COPY requirements.txt .
|
| 16 |
+
|
| 17 |
+
# Install Python dependencies
|
| 18 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 19 |
+
|
| 20 |
+
# Pre-download HF Models to cache them in the Docker image (prevents downloading on every boot)
|
| 21 |
+
RUN python -c "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM; AutoTokenizer.from_pretrained('nsi319/legal-pegasus'); AutoModelForSeq2SeqLM.from_pretrained('nsi319/legal-pegasus');"
|
| 22 |
+
RUN python -c "from sentence_transformers import SentenceTransformer, CrossEncoder; SentenceTransformer('BAAI/bge-base-en-v1.5'); CrossEncoder('BAAI/bge-reranker-base');"
|
| 23 |
+
|
| 24 |
+
# Copy the rest of the application code (including .git if not ignored)
|
| 25 |
+
COPY . .
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Download heavy databases from Dataset into their correct folders
|
| 29 |
+
RUN huggingface-cli download SaiPranav09/NyayLens-Data paragraphs.db --repo-type dataset --local-dir data/processed/indexed/
|
| 30 |
+
RUN huggingface-cli download SaiPranav09/NyayLens-Data faiss_index.bin --repo-type dataset --local-dir data/processed/faiss/
|
| 31 |
+
RUN huggingface-cli download SaiPranav09/NyayLens-Data model.safetensors --repo-type dataset --local-dir outputs/summarization/final/
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Expose port 7860 (Hugging Face standard)
|
| 35 |
+
EXPOSE 7860
|
| 36 |
+
|
| 37 |
+
# Command to run the application using Uvicorn
|
| 38 |
+
CMD ["uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: NyayLens API
|
| 3 |
+
emoji: 📚
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
short_description: NyayLens-Legal AI Assistant
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
outputs/summarization/final/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f3beb2c13e1b5773787c8bb39605511469c4f77d96d2f38cca086c14d352f37
|
| 3 |
+
size 437956140
|
outputs/summarization/final/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bbda47393ea48d83251e18b1d9e470b83cb2fa30ca734d2095ff10fc6b4fd2c6
|
| 3 |
+
size 5368
|
requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NyayLens — Python Backend Dependencies
|
| 2 |
+
# Install with: pip install -r requirements.txt
|
| 3 |
+
|
| 4 |
+
# API
|
| 5 |
+
fastapi==0.115.12
|
| 6 |
+
uvicorn==0.34.0
|
| 7 |
+
python-multipart==0.0.20
|
| 8 |
+
pydantic==2.12.5
|
| 9 |
+
|
| 10 |
+
# RAG & Retrieval
|
| 11 |
+
faiss-cpu==1.9.0
|
| 12 |
+
sentence-transformers==5.2.0
|
| 13 |
+
groq==1.0.0
|
| 14 |
+
|
| 15 |
+
# Summarization
|
| 16 |
+
transformers==4.57.3
|
| 17 |
+
torch==2.6.0
|
| 18 |
+
sentencepiece==0.2.1
|
| 19 |
+
huggingface-hub==0.36.0
|
| 20 |
+
|
| 21 |
+
# PDF Extraction
|
| 22 |
+
pdfplumber==0.11.8
|
| 23 |
+
|
| 24 |
+
# Utilities
|
| 25 |
+
python-dotenv==1.0.1
|
| 26 |
+
safetensors==0.4.5
|
| 27 |
+
ragas==0.2.14
|
src/api/__pycache__/main.cpython-310.pyc
ADDED
|
Binary file (6.17 kB). View file
|
|
|
src/api/main.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/api/main.py
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import io
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
import atexit
|
| 8 |
+
import shutil
|
| 9 |
+
import asyncio
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Request
|
| 14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
from fastapi.responses import JSONResponse
|
| 16 |
+
from pydantic import BaseModel, field_validator
|
| 17 |
+
|
| 18 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
| 19 |
+
|
| 20 |
+
from src.rag.query_engine import QueryEngine
|
| 21 |
+
from src.summarization.inference import summarize
|
| 22 |
+
|
| 23 |
+
# ── Constants ──────────────────────────────────────────────────────────────
|
| 24 |
+
MAX_UPLOAD_MB = 10
|
| 25 |
+
MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024 * 1024
|
| 26 |
+
UPLOAD_DIR = Path("data/uploads")
|
| 27 |
+
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
| 28 |
+
SUMMARIZE_TIMEOUT_S = 180 # 3 min max for summarization on CPU
|
| 29 |
+
|
| 30 |
+
# ── App ────────────────────────────────────────────────────────────────────
|
| 31 |
+
app = FastAPI(
|
| 32 |
+
title="NyayLens API",
|
| 33 |
+
description="Production API for Legal Chat, Document QA, and Summarization",
|
| 34 |
+
version="1.0.0",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
app.add_middleware(
|
| 38 |
+
CORSMiddleware,
|
| 39 |
+
allow_origins=[
|
| 40 |
+
"https://nyay-lens.vercel.app", # Production Vercel URL
|
| 41 |
+
"http://localhost:5173", # Local Vite dev server
|
| 42 |
+
"http://127.0.0.1:5173"
|
| 43 |
+
],
|
| 44 |
+
allow_credentials=True,
|
| 45 |
+
allow_methods=["*"],
|
| 46 |
+
allow_headers=["*"],
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# ── Startup / Shutdown ─────────────────────────────────────────────────────
|
| 50 |
+
async def cleanup_loop():
    """Background task that removes leftover uploads older than 2 hours.

    Runs forever: sweeps ``UPLOAD_DIR`` once, then sleeps for an hour.
    Per-file errors are reported and skipped so that a single bad entry
    can never terminate the loop.
    """
    while True:
        now = time.time()
        for f in UPLOAD_DIR.glob("*"):
            # stat() must live inside the try: a file removed between the
            # glob and the stat (e.g. by an explicit delete request) would
            # otherwise raise and kill this task for the rest of the process.
            try:
                if f.is_file() and (now - f.stat().st_mtime) > 7200:
                    f.unlink()
            except FileNotFoundError:
                # Already gone -- nothing left to clean up for this entry.
                continue
            except Exception as e:
                print(f"Cleanup error: {e}")
        await asyncio.sleep(3600)  # Check every hour
|
| 61 |
+
|
| 62 |
+
@app.on_event("startup")
async def startup():
    """Create the shared query engine and launch the upload cleanup task.

    Binds the module-level ``query_engine`` used by the chat endpoint and
    schedules ``cleanup_loop`` on the running event loop.
    """
    global query_engine
    print("Initializing NyayLens Backend...")
    query_engine = QueryEngine()

    # Start the infinite cleanup loop. The event loop only keeps a weak
    # reference to tasks, so hold a strong one on app.state; otherwise the
    # task may be garbage-collected while it is still running.
    app.state.cleanup_task = asyncio.create_task(cleanup_loop())
    print("✓ Backend ready. Background cleanup active.")
|
| 71 |
+
|
| 72 |
+
@app.on_event("shutdown")
def shutdown():
    """Wipe the uploads directory when the server stops.

    The directory is removed together with its contents (if it exists) and
    then recreated empty.
    """
    uploads_present = UPLOAD_DIR.exists()
    if uploads_present:
        shutil.rmtree(UPLOAD_DIR)

    # Always leave an empty directory behind.
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    print("✓ Uploads directory cleaned on shutdown.")
|
| 79 |
+
|
| 80 |
+
# ── Schema ─────────────────────────────────────────────────────────────────
|
| 81 |
+
class UnifiedRequest(BaseModel):
    """Request body accepted by the unified chat endpoint."""

    # User's message; validated and stripped by ``message_not_empty``.
    message: str
    # Server-side path of a previously uploaded document, if any.
    filepath: Optional[str] = None
    # Number of results requested from global retrieval.
    top_k: int = 5
    # Earlier conversation turns, forwarded to the query engine as-is.
    chat_history: Optional[list] = []

    @field_validator("message")
    @classmethod
    def message_not_empty(cls, v):
        """Reject blank or over-long messages; return the stripped text."""
        stripped = v.strip() if v else ""
        if not stripped:
            raise ValueError("Message cannot be empty")
        # The length limit applies to the message as sent, before stripping.
        if len(v) > 4000:
            raise ValueError("Message too long (max 4000 characters)")
        return stripped
|
| 95 |
+
|
| 96 |
+
# ── Health ─────────────────────────────────────────────────────────────────
|
| 97 |
+
@app.get("/")
@app.get("/api/health")
def health():
    """Return a static status payload describing the service."""
    model_names = ["Legal-BERT", "Legal-PEGASUS", "Llama-3.1-8B (Groq)"]
    payload = {
        "status": "online",
        "service": "NyayLens API",
        "version": "1.0.0",
        "models": model_names,
        "index": "FAISS 298K vectors",
    }
    return payload
|
| 107 |
+
|
| 108 |
+
# ── Upload ─────────────────────────────────────────────────────────────────
|
| 109 |
+
def _extract_pdf_pages(raw_bytes):
    """Return the stripped text of every page of the PDF that yields text."""
    import pdfplumber

    text_parts = []
    with pdfplumber.open(io.BytesIO(raw_bytes)) as pdf:
        for page in pdf.pages:
            t = page.extract_text()
            if t:
                text_parts.append(t.strip())
    return text_parts


@app.post("/api/upload")
async def upload_document(file: UploadFile = File(...)):
    """
    Accepts .pdf and .txt files up to 10 MB.
    PDFs are extracted to plain text via pdfplumber.
    Returns a server filepath for subsequent /api/chat calls.

    Raises:
        HTTPException 400: unsupported extension, empty file, or PDF that
            cannot be parsed.
        HTTPException 413: file larger than ``MAX_UPLOAD_MB``.
        HTTPException 422: PDF with no extractable text.
    """
    import re

    # 1. Validate extension
    filename = file.filename or "upload"
    ext = Path(filename).suffix.lower()
    if ext not in {".pdf", ".txt"}:
        raise HTTPException(status_code=400, detail="Only .pdf and .txt files are supported.")

    # 2. Read with size guard. Read at most one byte past the limit so an
    #    oversized upload is never loaded into memory in full.
    raw_bytes = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(raw_bytes) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum allowed size is {MAX_UPLOAD_MB} MB."
        )
    if len(raw_bytes) == 0:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # 3. Unique name to avoid collisions. The stem comes from the client, so
    #    restrict it to a conservative character set and a bounded length
    #    before it becomes part of a path on disk.
    uid = uuid.uuid4().hex[:8]
    stem = re.sub(r"[^A-Za-z0-9._-]+", "_", Path(filename).stem)[:80] or "upload"
    safe_name = f"{uid}_{stem}"
    out_path = UPLOAD_DIR / f"{safe_name}.txt"
    size_kb = round(len(raw_bytes) / 1024, 1)

    # 4. Extract / save
    if ext == ".pdf":
        try:
            # PDF parsing is CPU-bound; run it in a worker thread so the
            # event loop keeps serving other requests meanwhile.
            text_parts = await asyncio.to_thread(_extract_pdf_pages, raw_bytes)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"PDF extraction failed: {e}") from e

        if not text_parts:
            raise HTTPException(
                status_code=422,
                detail="PDF contains no readable text. It may be a scanned image — please use a searchable PDF."
            )

        out_path.write_text("\n\n".join(text_parts), encoding="utf-8")
        # "pages" counts only the pages that produced text.
        return {"filepath": str(out_path), "filename": filename, "pages": len(text_parts), "size_kb": size_kb}

    # Plain-text uploads are stored byte-for-byte.
    out_path.write_bytes(raw_bytes)
    return {"filepath": str(out_path), "filename": filename, "size_kb": size_kb}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ── Chat ───────────────────────────────────────────────────────────────────
|
| 167 |
+
@app.post("/api/chat")
def chat(request: UnifiedRequest):
    """
    Unified intent-aware chat endpoint.
    Routes to: Summarization | Document QA | Global RAG

    A ``filepath`` is honoured only when it resolves to a regular file
    inside ``UPLOAD_DIR``; anything else gets the same 404 response as a
    missing upload, so the endpoint cannot be used to read arbitrary files
    on the server.

    Raises:
        HTTPException 500: any error raised by the underlying pipelines.
    """
    message_lower = request.message.lower()

    print(f"\n[BACKEND] '{request.message[:80]}' | file={os.path.basename(request.filepath) if request.filepath else 'None'}")

    # Validate filepath if provided
    filepath = None
    if request.filepath:
        try:
            resolved = Path(request.filepath).resolve()
            inside_uploads = UPLOAD_DIR.resolve() in resolved.parents
            usable = inside_uploads and resolved.is_file()
        except (OSError, ValueError):
            # Malformed paths (e.g. embedded NUL bytes) count as "not found".
            usable = False
        if not usable:
            return JSONResponse(
                status_code=404,
                content={"answer": "The uploaded document could not be found on the server. Please re-upload the file.", "sources": []}
            )
        filepath = str(resolved)

    try:
        # ── Route 1: Summarization ─────────────────────────────────────────
        # NOTE: SUMMARIZE_TIMEOUT_S is not enforced here; summarize() runs
        # to completion.
        if "summarize" in message_lower or "summary" in message_lower:
            if not filepath:
                return {
                    "answer": "Please **upload a PDF or text file** first using the 📎 button, then ask me to summarize it.",
                    "sources": []
                }
            print("[BACKEND] → Summarization pipeline")
            summary_dict = summarize(filepath)
            return {
                "answer": "__STRUCTURED_SUMMARY__",
                "summary": summary_dict,
                "sources": [{"judgment_id": os.path.basename(filepath), "score": 1.0}]
            }

        # ── Route 2: Document QA ────────────────────────────────────────────
        if filepath:
            print("[BACKEND] → Document QA")
            return query_engine.query_with_document(request.message, filepath, chat_history=request.chat_history)

        # ── Route 3: Global RAG ─────────────────────────────────────────────
        print("[BACKEND] → Global RAG")
        return query_engine.query(request.message, top_k=request.top_k, chat_history=request.chat_history)

    except Exception as e:
        print(f"[BACKEND ERROR] {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ── Delete a specific upload (explicit, client-requested) ──────────────────
|
| 216 |
+
@app.delete("/api/upload/{filename}")
def delete_upload(filename: str):
    """Explicit delete for a specific upload.

    Only regular files located directly inside ``UPLOAD_DIR`` can be
    removed; any name that resolves elsewhere is reported as not found.

    Raises:
        HTTPException 404: the file does not exist or is outside the
            uploads directory.
    """
    try:
        target = (UPLOAD_DIR / filename).resolve()
        # Refuse anything that escapes the uploads directory after
        # resolution (e.g. ".." segments or symlinks).
        if target.parent == UPLOAD_DIR.resolve() and target.is_file():
            target.unlink()
            return {"status": "deleted"}
    except (FileNotFoundError, ValueError):
        # Removed concurrently, or a malformed name: treat as not found.
        pass
    raise HTTPException(status_code=404, detail="File not found.")
|
src/evaluation/__pycache__/evaluator.cpython-310.pyc
ADDED
|
Binary file (7.86 kB). View file
|
|
|
src/extraction/__pycache__/batch_processor.cpython-310.pyc
ADDED
|
Binary file (6.6 kB). View file
|
|
|
src/extraction/__pycache__/pdf_extractor.cpython-310.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
src/extraction/batch_processor.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# Batch processor for large-scale PDF extraction
|
| 3 |
+
# Includes progress tracking, error handling, and resumability
|
| 4 |
+
# """
|
| 5 |
+
|
| 6 |
+
# from pathlib import Path
|
| 7 |
+
# from typing import List, Dict
|
| 8 |
+
# import logging
|
| 9 |
+
# from tqdm import tqdm
|
| 10 |
+
# import json
|
| 11 |
+
# from datetime import datetime
|
| 12 |
+
# from pdf_extractor import LegalJudgmentExtractor
|
| 13 |
+
|
| 14 |
+
# logging.basicConfig(
|
| 15 |
+
# level=logging.INFO,
|
| 16 |
+
# format='%(asctime)s - %(levelname)s - %(message)s'
|
| 17 |
+
# )
|
| 18 |
+
# logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# class BatchProcessor:
|
| 22 |
+
# """
|
| 23 |
+
# Batch processor for extracting text from thousands of legal judgments
|
| 24 |
+
# Features: Progress tracking, resumability, comprehensive reporting
|
| 25 |
+
# """
|
| 26 |
+
|
| 27 |
+
# def __init__(self,
|
| 28 |
+
# input_dir: Path,
|
| 29 |
+
# output_dir: Path,
|
| 30 |
+
# num_workers: int = 1,
|
| 31 |
+
# enable_ocr: bool = True):
|
| 32 |
+
|
| 33 |
+
# self.input_dir = Path(input_dir)
|
| 34 |
+
# self.output_dir = Path(output_dir)
|
| 35 |
+
# self.num_workers = num_workers
|
| 36 |
+
|
| 37 |
+
# # Create extractor
|
| 38 |
+
# self.extractor = LegalJudgmentExtractor(output_dir, enable_ocr=enable_ocr)
|
| 39 |
+
|
| 40 |
+
# # Progress tracking
|
| 41 |
+
# self.progress_file = output_dir / "processing_progress.json"
|
| 42 |
+
# self.stats_file = output_dir / "processing_stats.json"
|
| 43 |
+
|
| 44 |
+
# def get_all_pdfs(self) -> List[Path]:
|
| 45 |
+
# """Get all PDF files from input directory"""
|
| 46 |
+
# return sorted(self.input_dir.rglob("*.pdf"))
|
| 47 |
+
|
| 48 |
+
# def load_progress(self) -> set:
|
| 49 |
+
# """Load already processed files"""
|
| 50 |
+
# if self.progress_file.exists():
|
| 51 |
+
# with open(self.progress_file, 'r') as f:
|
| 52 |
+
# data = json.load(f)
|
| 53 |
+
# return set(data.get('processed_files', []))
|
| 54 |
+
# return set()
|
| 55 |
+
|
| 56 |
+
# def save_progress(self, processed_files: set):
|
| 57 |
+
# """Save processing progress"""
|
| 58 |
+
# with open(self.progress_file, 'w') as f:
|
| 59 |
+
# json.dump({
|
| 60 |
+
# 'processed_files': list(processed_files),
|
| 61 |
+
# 'last_updated': datetime.now().isoformat(),
|
| 62 |
+
# 'total_processed': len(processed_files)
|
| 63 |
+
# }, f, indent=2)
|
| 64 |
+
|
| 65 |
+
# def process_single_pdf(self, pdf_path: Path) -> Dict:
|
| 66 |
+
# """Process a single PDF"""
|
| 67 |
+
# try:
|
| 68 |
+
# success = self.extractor.process_pdf(pdf_path)
|
| 69 |
+
# return {
|
| 70 |
+
# 'filename': pdf_path.name,
|
| 71 |
+
# 'year': pdf_path.parent.name,
|
| 72 |
+
# 'success': success,
|
| 73 |
+
# 'error': None
|
| 74 |
+
# }
|
| 75 |
+
# except Exception as e:
|
| 76 |
+
# return {
|
| 77 |
+
# 'filename': pdf_path.name,
|
| 78 |
+
# 'year': pdf_path.parent.name,
|
| 79 |
+
# 'success': False,
|
| 80 |
+
# 'error': str(e)
|
| 81 |
+
# }
|
| 82 |
+
|
| 83 |
+
# def process_batch(self,
|
| 84 |
+
# start_year: int = None,
|
| 85 |
+
# end_year: int = None,
|
| 86 |
+
# limit: int = None,
|
| 87 |
+
# resume: bool = True):
|
| 88 |
+
# """
|
| 89 |
+
# Process PDFs in batch with progress tracking
|
| 90 |
+
|
| 91 |
+
# Args:
|
| 92 |
+
# start_year: Start from this year (inclusive)
|
| 93 |
+
# end_year: Process until this year (inclusive)
|
| 94 |
+
# limit: Maximum number of PDFs to process
|
| 95 |
+
# resume: Continue from last checkpoint
|
| 96 |
+
# """
|
| 97 |
+
|
| 98 |
+
# logger.info("Starting batch processing...")
|
| 99 |
+
# logger.info(f"Workers: {self.num_workers}")
|
| 100 |
+
|
| 101 |
+
# # Get all PDFs
|
| 102 |
+
# all_pdfs = self.get_all_pdfs()
|
| 103 |
+
# logger.info(f"Found {len(all_pdfs):,} PDFs")
|
| 104 |
+
|
| 105 |
+
# # Filter by year if specified
|
| 106 |
+
# if start_year or end_year:
|
| 107 |
+
# all_pdfs = [
|
| 108 |
+
# p for p in all_pdfs
|
| 109 |
+
# if (not start_year or int(p.parent.name) >= start_year) and
|
| 110 |
+
# (not end_year or int(p.parent.name) <= end_year)
|
| 111 |
+
# ]
|
| 112 |
+
# logger.info(f"Filtered to {len(all_pdfs):,} PDFs (years {start_year}-{end_year})")
|
| 113 |
+
|
| 114 |
+
# # Load progress and filter already processed
|
| 115 |
+
# if resume:
|
| 116 |
+
# processed = self.load_progress()
|
| 117 |
+
# all_pdfs = [p for p in all_pdfs if str(p) not in processed]
|
| 118 |
+
# logger.info(f"Resuming: {len(all_pdfs):,} PDFs remaining")
|
| 119 |
+
# else:
|
| 120 |
+
# processed = set()
|
| 121 |
+
|
| 122 |
+
# # Apply limit
|
| 123 |
+
# if limit:
|
| 124 |
+
# all_pdfs = all_pdfs[:limit]
|
| 125 |
+
# logger.info(f"Limited to {len(all_pdfs):,} PDFs")
|
| 126 |
+
|
| 127 |
+
# if not all_pdfs:
|
| 128 |
+
# logger.info("No PDFs to process!")
|
| 129 |
+
# return
|
| 130 |
+
|
| 131 |
+
# # Initialize stats
|
| 132 |
+
# stats = {
|
| 133 |
+
# 'total': len(all_pdfs),
|
| 134 |
+
# 'successful': 0,
|
| 135 |
+
# 'failed': 0,
|
| 136 |
+
# 'start_time': datetime.now().isoformat(),
|
| 137 |
+
# 'failed_files': []
|
| 138 |
+
# }
|
| 139 |
+
|
| 140 |
+
# # Process with progress bar
|
| 141 |
+
# with tqdm(total=len(all_pdfs), desc="Processing PDFs") as pbar:
|
| 142 |
+
# for pdf_path in all_pdfs:
|
| 143 |
+
# result = self.process_single_pdf(pdf_path)
|
| 144 |
+
|
| 145 |
+
# if result['success']:
|
| 146 |
+
# stats['successful'] += 1
|
| 147 |
+
# else:
|
| 148 |
+
# stats['failed'] += 1
|
| 149 |
+
# stats['failed_files'].append({
|
| 150 |
+
# 'file': result['filename'],
|
| 151 |
+
# 'error': result['error']
|
| 152 |
+
# })
|
| 153 |
+
# logger.warning(f"Failed: {result['filename']} - {result['error']}")
|
| 154 |
+
|
| 155 |
+
# # Update progress
|
| 156 |
+
# processed.add(str(pdf_path))
|
| 157 |
+
|
| 158 |
+
# # Save progress every 50 files
|
| 159 |
+
# if len(processed) % 50 == 0:
|
| 160 |
+
# self.save_progress(processed)
|
| 161 |
+
|
| 162 |
+
# pbar.update(1)
|
| 163 |
+
# pbar.set_postfix({
|
| 164 |
+
# 'Success': stats['successful'],
|
| 165 |
+
# 'Failed': stats['failed']
|
| 166 |
+
# })
|
| 167 |
+
|
| 168 |
+
# # Final save
|
| 169 |
+
# self.save_progress(processed)
|
| 170 |
+
|
| 171 |
+
# # Save statistics
|
| 172 |
+
# stats['end_time'] = datetime.now().isoformat()
|
| 173 |
+
# stats['success_rate'] = (stats['successful'] / stats['total'] * 100) if stats['total'] > 0 else 0
|
| 174 |
+
|
| 175 |
+
# with open(self.stats_file, 'w') as f:
|
| 176 |
+
# json.dump(stats, f, indent=2)
|
| 177 |
+
|
| 178 |
+
# # Summary
|
| 179 |
+
# logger.info("\n" + "="*60)
|
| 180 |
+
# logger.info("PROCESSING COMPLETE")
|
| 181 |
+
# logger.info("="*60)
|
| 182 |
+
# logger.info(f"Total processed: {stats['total']:,}")
|
| 183 |
+
# logger.info(f"Successful: {stats['successful']:,}")
|
| 184 |
+
# logger.info(f"Failed: {stats['failed']:,}")
|
| 185 |
+
# logger.info(f"Success rate: {stats['success_rate']:.2f}%")
|
| 186 |
+
# logger.info("="*60)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# def main():
|
| 190 |
+
# """Main execution with CLI arguments"""
|
| 191 |
+
# import argparse
|
| 192 |
+
|
| 193 |
+
# parser = argparse.ArgumentParser(description='Batch process legal judgment PDFs')
|
| 194 |
+
# parser.add_argument('--start-year', type=int, help='Start year (inclusive)')
|
| 195 |
+
# parser.add_argument('--end-year', type=int, help='End year (inclusive)')
|
| 196 |
+
# parser.add_argument('--limit', type=int, help='Maximum PDFs to process')
|
| 197 |
+
# parser.add_argument('--no-resume', action='store_true', help='Start fresh')
|
| 198 |
+
# parser.add_argument('--no-ocr', action='store_true', help='Disable OCR fallback')
|
| 199 |
+
|
| 200 |
+
# args = parser.parse_args()
|
| 201 |
+
|
| 202 |
+
# # Configuration
|
| 203 |
+
# INPUT_DIR = Path("data/raw")
|
| 204 |
+
# OUTPUT_DIR = Path("data/processed/extracted")
|
| 205 |
+
|
| 206 |
+
# processor = BatchProcessor(
|
| 207 |
+
# input_dir=INPUT_DIR,
|
| 208 |
+
# output_dir=OUTPUT_DIR,
|
| 209 |
+
# enable_ocr=not args.no_ocr
|
| 210 |
+
# )
|
| 211 |
+
|
| 212 |
+
# processor.process_batch(
|
| 213 |
+
# start_year=args.start_year,
|
| 214 |
+
# end_year=args.end_year,
|
| 215 |
+
# limit=args.limit,
|
| 216 |
+
# resume=not args.no_resume
|
| 217 |
+
# )
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# if __name__ == "__main__":
|
| 221 |
+
# main()
|
| 222 |
+
"""
|
| 223 |
+
Batch processor for large-scale PDF extraction
|
| 224 |
+
Includes progress tracking, error handling, and resumability
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
from pathlib import Path
|
| 228 |
+
from typing import List, Dict
|
| 229 |
+
import logging
|
| 230 |
+
from tqdm import tqdm
|
| 231 |
+
import json
|
| 232 |
+
from datetime import datetime
|
| 233 |
+
from pdf_extractor import LegalJudgmentExtractor
|
| 234 |
+
|
| 235 |
+
logging.basicConfig(
|
| 236 |
+
level=logging.INFO,
|
| 237 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 238 |
+
)
|
| 239 |
+
logger = logging.getLogger(__name__)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class BatchProcessor:
    """
    Batch processor for extracting text from thousands of legal judgments

    Features: Progress tracking, resumability, comprehensive reporting

    Two bookkeeping files are written under ``output_dir``:
    ``processing_progress.json`` (checkpoint of attempted files, used for
    resuming) and ``processing_stats.json`` (summary of the latest run).
    """

    def __init__(self,
                 input_dir: Path,
                 output_dir: Path,
                 num_workers: int = 1,
                 enable_ocr: bool = True):
        """
        Args:
            input_dir: Root directory that is searched recursively for PDFs.
            output_dir: Directory for extracted output and bookkeeping files.
            num_workers: Stored on the instance; the methods of this class
                process files sequentially and do not read it.
            enable_ocr: Forwarded to ``LegalJudgmentExtractor``.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.num_workers = num_workers

        # Create extractor
        self.extractor = LegalJudgmentExtractor(self.output_dir, enable_ocr=enable_ocr)

        # Progress tracking. Built from the normalized Path so that callers
        # may pass plain strings as well as Path objects.
        self.progress_file = self.output_dir / "processing_progress.json"
        self.stats_file = self.output_dir / "processing_stats.json"

    def get_all_pdfs(self) -> List[Path]:
        """Get all PDF files (case-insensitive), sorted by path"""
        # Comparing the lower-cased suffix covers .pdf, .PDF and mixed-case
        # variants such as .Pdf in a single pass without duplicates.
        return sorted(
            candidate
            for candidate in self.input_dir.rglob("*")
            if candidate.is_file() and candidate.suffix.lower() == ".pdf"
        )

    def load_progress(self) -> set:
        """
        Load already processed files

        Returns:
            Set of path strings stored in the progress file; an empty set
            when the file is missing or cannot be read/parsed.
        """
        if not self.progress_file.exists():
            return set()

        try:
            with open(self.progress_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # An unreadable checkpoint should not block every later run.
            logger.warning(f"Could not read progress file {self.progress_file}: {e}. Starting fresh.")
            return set()

        return set(data.get('processed_files', []))

    def save_progress(self, processed_files: set):
        """Save processing progress"""
        # Write to a sibling temp file and swap it in, so an interruption in
        # the middle of a write cannot leave a truncated checkpoint behind.
        tmp_file = self.progress_file.with_name(self.progress_file.name + ".tmp")
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump({
                'processed_files': sorted(processed_files),
                'last_updated': datetime.now().isoformat(),
                'total_processed': len(processed_files)
            }, f, indent=2)
        tmp_file.replace(self.progress_file)

    def process_single_pdf(self, pdf_path: Path) -> Dict:
        """
        Process a single PDF

        Returns:
            Dict with 'filename', 'year' (name of the parent directory),
            'success' and 'error' (exception text, or None).
        """
        try:
            success = self.extractor.process_pdf(pdf_path)
            return {
                'filename': pdf_path.name,
                'year': pdf_path.parent.name,
                'success': success,
                'error': None
            }
        except Exception as e:
            return {
                'filename': pdf_path.name,
                'year': pdf_path.parent.name,
                'success': False,
                'error': str(e)
            }

    def process_batch(self,
                      start_year: int = None,
                      end_year: int = None,
                      limit: int = None,
                      resume: bool = True):
        """
        Process PDFs in batch with progress tracking

        Every attempted file, successful or not, is added to the progress
        checkpoint, so failed files are not retried on a resumed run.

        Args:
            start_year: Start from this year (inclusive)
            end_year: Process until this year (inclusive)
            limit: Maximum number of PDFs to process
            resume: Continue from last checkpoint
        """

        logger.info("Starting batch processing...")
        logger.info(f"Input directory: {self.input_dir}")
        logger.info(f"Output directory: {self.output_dir}")

        # Get all PDFs
        all_pdfs = self.get_all_pdfs()
        logger.info(f"Found {len(all_pdfs):,} PDFs")

        if len(all_pdfs) == 0:
            logger.error("❌ No PDFs found! Check your data/raw directory.")
            logger.error(f"Looking in: {self.input_dir}")
            logger.error("Make sure PDFs are in year folders like: data/raw/1950/*.PDF")
            return

        # Filter by year if specified
        if start_year or end_year:
            filtered_pdfs = []
            for p in all_pdfs:
                try:
                    year = int(p.parent.name)
                    if (not start_year or year >= start_year) and (not end_year or year <= end_year):
                        filtered_pdfs.append(p)
                except ValueError:
                    logger.warning(f"Skipping non-year folder: {p.parent.name}")

            all_pdfs = filtered_pdfs
            logger.info(f"Filtered to {len(all_pdfs):,} PDFs (years {start_year}-{end_year})")

        # Load progress and filter already processed
        if resume:
            processed = self.load_progress()
            all_pdfs = [p for p in all_pdfs if str(p) not in processed]
            logger.info(f"Resuming: {len(all_pdfs):,} PDFs remaining")
        else:
            processed = set()

        # Apply limit
        if limit:
            all_pdfs = all_pdfs[:limit]
            logger.info(f"Limited to {len(all_pdfs):,} PDFs")

        if not all_pdfs:
            logger.info("No PDFs to process!")
            return

        # Initialize stats
        stats = {
            'total': len(all_pdfs),
            'successful': 0,
            'failed': 0,
            'start_time': datetime.now().isoformat(),
            'failed_files': []
        }

        # Process with progress bar
        try:
            with tqdm(total=len(all_pdfs), desc="Processing PDFs") as pbar:
                for pdf_path in all_pdfs:
                    result = self.process_single_pdf(pdf_path)

                    if result['success']:
                        stats['successful'] += 1
                    else:
                        stats['failed'] += 1
                        stats['failed_files'].append({
                            'file': result['filename'],
                            'year': result['year'],
                            'error': result['error']
                        })
                        logger.warning(f"Failed: {result['filename']} - {result['error']}")

                    # Update progress
                    processed.add(str(pdf_path))

                    # Save progress every 50 files
                    if len(processed) % 50 == 0:
                        self.save_progress(processed)

                    pbar.update(1)
                    pbar.set_postfix({
                        'Success': stats['successful'],
                        'Failed': stats['failed']
                    })
        finally:
            # Final save. Runs on interruption (e.g. Ctrl+C) as well, so the
            # work done since the last 50-file checkpoint is not lost.
            self.save_progress(processed)

        # Save statistics
        stats['end_time'] = datetime.now().isoformat()
        stats['success_rate'] = (stats['successful'] / stats['total'] * 100) if stats['total'] > 0 else 0

        with open(self.stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)

        # Summary
        logger.info("\n" + "="*60)
        logger.info("PROCESSING COMPLETE")
        logger.info("="*60)
        logger.info(f"Total processed: {stats['total']:,}")
        logger.info(f"Successful: {stats['successful']:,}")
        logger.info(f"Failed: {stats['failed']:,}")
        logger.info(f"Success rate: {stats['success_rate']:.2f}%")
        logger.info("="*60)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def main(argv=None):
    """
    Main execution with CLI arguments

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            reads them from ``sys.argv`` as before.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Batch process legal judgment PDFs')
    # Directories were previously fixed in code; the defaults keep the same
    # locations so existing invocations behave identically.
    parser.add_argument('--input-dir', type=Path, default=Path("data/raw"),
                        help='Directory containing the year folders of PDFs')
    parser.add_argument('--output-dir', type=Path, default=Path("data/processed/extracted"),
                        help='Directory for extracted text and metadata')
    parser.add_argument('--start-year', type=int, help='Start year (inclusive)')
    parser.add_argument('--end-year', type=int, help='End year (inclusive)')
    parser.add_argument('--limit', type=int, help='Maximum PDFs to process')
    parser.add_argument('--no-resume', action='store_true', help='Start fresh')
    parser.add_argument('--no-ocr', action='store_true', help='Disable OCR fallback')

    args = parser.parse_args(argv)

    processor = BatchProcessor(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        enable_ocr=not args.no_ocr
    )

    processor.process_batch(
        start_year=args.start_year,
        end_year=args.end_year,
        limit=args.limit,
        resume=not args.no_resume
    )


if __name__ == "__main__":
    main()
|
src/extraction/pdf_extractor.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Production PDF Extractor for Legal Judgments
|
| 3 |
+
Enhanced with robust error handling, quality checks, and paragraph preservation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import PyPDF2
|
| 7 |
+
import pdfplumber
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Optional, List, Tuple
|
| 10 |
+
import logging
|
| 11 |
+
from dataclasses import dataclass, asdict
|
| 12 |
+
import json
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import re
|
| 15 |
+
|
| 16 |
+
# OCR imports
|
| 17 |
+
try:
|
| 18 |
+
import pytesseract
|
| 19 |
+
from pdf2image import convert_from_path
|
| 20 |
+
OCR_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
OCR_AVAILABLE = False
|
| 23 |
+
logging.warning("OCR libraries not installed. OCR fallback disabled.")
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class ExtractionMetadata:
    """Record describing the outcome of extracting one judgment PDF."""

    # --- source file ---
    filename: str               # name of the PDF file
    year: str                   # year label derived from the path ("unknown" if undetermined)
    num_pages: int              # page count (0 when it could not be read)

    # --- extraction outcome ---
    text_length: int            # number of characters extracted (0 on failure)
    extraction_method: str      # "pypdf2", "pdfplumber", "ocr" or "failed"
    has_text: bool              # whether any extractor produced text
    extraction_timestamp: str   # ISO-8601 timestamp of the extraction
    file_size_bytes: int        # size of the PDF on disk
    ocr_used: bool              # True when the text came from the OCR fallback

    # --- quality indicators ---
    quality_score: float        # score in the range 0-1
    paragraph_count: int        # number of paragraph-like blocks in the text
    errors: List[str]           # failures recorded along the fallback chain
    warnings: List[str]         # non-fatal notes (year issues, quality issues, OCR use)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TextQualityChecker:
    """Utility class for assessing extracted text quality"""

    # Legal keywords to preserve even in short lines
    LEGAL_KEYWORDS = {
        'held', 'order', 'appeal', 'writ', 'judgment', 'decree',
        'petition', 'application', 'allowed', 'dismissed', 'granted',
        'rejected', 'reserved', 'disposed', 'quashed', 'set aside',
        'affirmed', 'reversed', 'remanded', 'suo moto', 'ex parte',
        'interim', 'stay', 'injunction', 'bail', 'custody', 'liberty',
        'notice', 'respondent', 'petitioner', 'appellant', 'accused'
    }

    @staticmethod
    def calculate_quality_score(text: str) -> Tuple[float, List[str]]:
        """
        Calculate quality score (0-1) for extracted text

        Starts at 1.0 and subtracts a penalty for each failed heuristic.
        Text with fewer than 100 non-padding characters scores 0.0.

        Returns:
            (score, issues_found)
        """
        if not text or len(text.strip()) < 100:
            return 0.0, ["Text too short"]

        issues = []
        score = 1.0

        # Check 1: Alphabetic character ratio
        alpha_chars = sum(c.isalpha() for c in text)
        total_chars = len(text.replace('\n', '').replace(' ', ''))

        if total_chars > 0:
            alpha_ratio = alpha_chars / total_chars
            if alpha_ratio < 0.5:
                score -= 0.3
                issues.append(f"Low alphabetic ratio: {alpha_ratio:.2f}")

        # Check 2: Average word length (gibberish detection)
        words = text.split()
        if words:
            avg_word_len = sum(len(w) for w in words) / len(words)
            if avg_word_len < 2 or avg_word_len > 15:
                score -= 0.2
                issues.append(f"Unusual avg word length: {avg_word_len:.1f}")

        # Check 3: Check for repeated patterns (OCR errors)
        # Only non-empty lines take part: blank separator lines are not
        # content, and counting them in the denominator would flag
        # well-spaced text as repetitive.
        content_lines = [line.strip() for line in text.split('\n') if line.strip()]
        if len(content_lines) > 10:
            repetition_ratio = len(set(content_lines)) / len(content_lines)
            if repetition_ratio < 0.3:
                score -= 0.2
                issues.append(f"High repetition: {repetition_ratio:.2f}")

        # Check 4: Minimum sentence structure
        sentence_markers = text.count('.') + text.count('?') + text.count('!')
        if len(words) > 100 and sentence_markers < len(words) / 50:
            score -= 0.1
            issues.append("Lacks sentence structure")

        return max(0.0, min(1.0, score)), issues

    @staticmethod
    def clean_ocr_text(text: str) -> str:
        """
        Normalize OCR-extracted text with legal-aware filtering
        - Remove excessive whitespace
        - Collapse multiple newlines (one blank line is kept between paragraphs)
        - Remove repeated headers
        - Preserve important legal terms

        Returns:
            The cleaned text with leading/trailing whitespace stripped.
        """
        # Collapse multiple spaces
        text = re.sub(r' +', ' ', text)

        # Collapse multiple newlines (keep max 2 for paragraph breaks)
        text = re.sub(r'\n{3,}', '\n\n', text)

        # Remove common OCR artifacts
        text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f]', '', text)

        keywords = TextQualityChecker.LEGAL_KEYWORDS
        result = []
        prev_line = None        # stripped text of the last kept content line
        repeat_count = 0
        pending_break = False   # a blank line was seen since the last kept line

        for line in text.split('\n'):
            stripped = line.strip()

            # Blank lines are remembered rather than dropped, so paragraph
            # breaks survive; runs of blanks collapse into a single one.
            if not stripped:
                pending_break = bool(result)
                continue

            # Skip pure numbers (page numbers)
            if stripped.isdigit():
                continue

            # PRESERVE if:
            # 1. Line is substantial (>10 chars)
            # 2. Contains legal keyword (even if short like "Held.")
            # 3. Is alphabetic and reasonable length (>3 chars)
            lowered = stripped.lower()
            keep = (len(stripped) > 10 or
                    any(keyword in lowered for keyword in keywords) or
                    (stripped.replace('.', '').replace(',', '').isalpha() and len(stripped) > 3))
            if not keep:
                continue

            # Remove repeated header patterns: at most two consecutive
            # occurrences of the same content line are kept.
            if stripped == prev_line:
                repeat_count += 1
                if repeat_count >= 2:
                    continue
            else:
                repeat_count = 0
                prev_line = stripped

            if pending_break:
                result.append('')
                pending_break = False
            result.append(line)

        return '\n'.join(result).strip()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class LegalJudgmentExtractor:
    """
    Production-grade extractor with robust error handling and quality assurance

    Text is extracted through a fallback chain (PyPDF2, then pdfplumber,
    then OCR). Results are written below ``output_dir``: texts/ for the
    extracted text, metadata/ for per-file JSON metadata, failed/ for the
    failure log, plus ocr_cases.jsonl listing files that needed OCR.
    """

    def __init__(self, output_dir: Path, enable_ocr: bool = True, ocr_max_pages: int = 50):
        """
        Args:
            output_dir: Root directory for all output (created if missing).
            enable_ocr: Allow the OCR fallback; effective only when the OCR
                libraries could be imported.
            ocr_max_pages: Maximum number of leading pages passed to OCR.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.enable_ocr = enable_ocr and OCR_AVAILABLE
        self.ocr_max_pages = ocr_max_pages

        if enable_ocr and not OCR_AVAILABLE:
            logger.warning("OCR requested but libraries not installed.")

        # Create subdirectories
        self.text_dir = self.output_dir / "texts"
        self.metadata_dir = self.output_dir / "metadata"
        self.failed_dir = self.output_dir / "failed"
        self.ocr_log_file = self.output_dir / "ocr_cases.jsonl"

        for dir_path in [self.text_dir, self.metadata_dir, self.failed_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

    def extract_year_from_path(self, pdf_path: Path) -> Tuple[str, List[str]]:
        """
        Safely extract year from path with validation

        The parent directory name is used when it is numeric; otherwise a
        standalone four-digit 19xx/20xx number in the file name is used.

        Returns:
            (year, warnings) - year is "unknown" when it cannot be determined
        """
        warnings = []
        year = pdf_path.parent.name

        # Validate year
        if not year.isdigit():
            warnings.append(f"Invalid year from directory: {year}")

            # Try to extract from filename. The lookarounds stop a match in
            # the middle of a longer digit run (e.g. a case number).
            year_match = re.search(r'(?<!\d)(?:19|20)\d{2}(?!\d)', pdf_path.stem)
            if year_match:
                year = year_match.group(0)
                warnings.append(f"Year extracted from filename: {year}")
            else:
                year = "unknown"
                warnings.append("Could not determine year")
        else:
            # Validate year range; the upper bound follows the calendar so
            # judgments from later years are not flagged.
            current_year = datetime.now().year
            if not 1950 <= int(year) <= current_year:
                warnings.append(f"Year {year} outside expected range (1950-{current_year})")

        return year, warnings

    def count_paragraphs(self, text: str) -> int:
        """Count paragraph-like structures in text"""
        # Split by double newlines
        paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
        # Filter out very short "paragraphs" (likely headers)
        substantial_paragraphs = [p for p in paragraphs if len(p) > 50]
        return len(substantial_paragraphs)

    def extract_with_pypdf2(self, pdf_path: Path) -> Optional[str]:
        """
        Primary extraction - preserves paragraph structure

        Returns:
            The extracted text, or None when extraction raised or the
            quality score is 0.3 or lower.
        """
        try:
            with open(pdf_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                text_parts = []

                for page in reader.pages:
                    text = page.extract_text()
                    if text:
                        text_parts.append(text.strip())

                # Join with double newline to preserve page breaks
                full_text = "\n\n".join(text_parts)

                # Quality check
                score, _ = TextQualityChecker.calculate_quality_score(full_text)
                return full_text if score > 0.3 else None

        except Exception as e:
            logger.debug(f"PyPDF2 failed for {pdf_path.name}: {e}")
            return None

    def extract_with_pdfplumber(self, pdf_path: Path) -> Optional[str]:
        """
        Fallback extraction - better for complex layouts

        Returns:
            The extracted text, or None when extraction raised or the
            quality score is 0.3 or lower.
        """
        try:
            with pdfplumber.open(pdf_path) as pdf:
                text_parts = []

                for page in pdf.pages:
                    text = page.extract_text()
                    if text:
                        text_parts.append(text.strip())

                full_text = "\n\n".join(text_parts)

                score, _ = TextQualityChecker.calculate_quality_score(full_text)
                return full_text if score > 0.3 else None

        except Exception as e:
            logger.debug(f"pdfplumber failed for {pdf_path.name}: {e}")
            return None

    def extract_with_ocr(self, pdf_path: Path, num_pages: int) -> Optional[str]:
        """
        OCR extraction with proper page limiting and text normalization

        Args:
            pdf_path: Path to PDF
            num_pages: Total pages in PDF (for proper limiting)

        Returns:
            Normalized OCR text, or None when OCR is disabled, raised, or
            produced text with a quality score of 0.3 or lower.
        """
        if not self.enable_ocr:
            return None

        try:
            logger.info(f"OCR extraction: {pdf_path.name}")

            # Proper page limiting
            last_page = min(self.ocr_max_pages, num_pages)

            if num_pages > self.ocr_max_pages:
                logger.warning(f"PDF has {num_pages} pages, OCR limited to first {self.ocr_max_pages}")

            # Convert to images
            images = convert_from_path(
                pdf_path,
                dpi=300,
                first_page=1,
                last_page=last_page
            )

            text_parts = []
            for i, image in enumerate(images, 1):
                logger.debug(f"OCR page {i}/{len(images)}")

                text = pytesseract.image_to_string(image, lang='eng')
                if text.strip():
                    text_parts.append(text)

            full_text = "\n\n".join(text_parts)

            # Normalize OCR text
            full_text = TextQualityChecker.clean_ocr_text(full_text)

            # Check quality
            score, issues = TextQualityChecker.calculate_quality_score(full_text)

            if score > 0.3:
                # Log successful OCR to JSONL
                self._log_ocr_case(pdf_path, num_pages, last_page, score)
                logger.info(f"✓ OCR successful (quality: {score:.2f})")
                return full_text
            else:
                logger.warning(f"OCR quality too low ({score:.2f}): {issues}")
                return None

        except Exception as e:
            logger.warning(f"OCR failed for {pdf_path.name}: {e}")
            return None

    def _log_ocr_case(self, pdf_path: Path, total_pages: int, pages_processed: int, quality: float):
        """Log OCR usage to JSONL file"""
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'filename': pdf_path.name,
            'year': pdf_path.parent.name,
            'total_pages': total_pages,
            'pages_processed': pages_processed,
            'quality_score': quality
        }

        with open(self.ocr_log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry) + '\n')

    def extract_pdf(self, pdf_path: Path) -> Dict:
        """
        Main extraction with fallback chain and quality assurance

        Returns:
            Dict with 'text' (str, or None when every method failed) and
            'metadata' (ExtractionMetadata).
        """
        errors = []
        warnings = []
        text = None
        method = None
        ocr_used = False
        quality_score = 0.0

        # Get metadata
        file_size = pdf_path.stat().st_size

        # Robust year extraction
        year, year_warnings = self.extract_year_from_path(pdf_path)
        warnings.extend(year_warnings)

        # Count pages first (needed for OCR)
        try:
            with open(pdf_path, 'rb') as f:
                reader = PyPDF2.PdfReader(f)
                num_pages = len(reader.pages)
        except Exception as e:
            num_pages = 0
            errors.append(f"Could not count pages: {e}")

        # Extraction chain: PyPDF2 → pdfplumber → OCR
        text = self.extract_with_pypdf2(pdf_path)
        if text:
            method = "pypdf2"
        else:
            errors.append("PyPDF2 insufficient")

            text = self.extract_with_pdfplumber(pdf_path)
            if text:
                method = "pdfplumber"
            else:
                errors.append("pdfplumber failed")

                if self.enable_ocr and num_pages > 0:
                    text = self.extract_with_ocr(pdf_path, num_pages)
                    if text:
                        method = "ocr"
                        ocr_used = True
                        warnings.append("OCR used - verify quality")
                    else:
                        errors.append("OCR failed")
                elif self.enable_ocr:
                    # Record why the last fallback never ran, so the failure
                    # log explains the missing OCR attempt.
                    errors.append("OCR skipped: page count unavailable")

        # Calculate quality
        paragraph_count = 0
        if text:
            quality_score, quality_issues = TextQualityChecker.calculate_quality_score(text)
            paragraph_count = self.count_paragraphs(text)

            if quality_score < 0.7:
                warnings.extend(quality_issues)

        # Create metadata
        metadata = ExtractionMetadata(
            filename=pdf_path.name,
            year=year,
            num_pages=num_pages,
            text_length=len(text) if text else 0,
            extraction_method=method if method else "failed",
            has_text=text is not None,
            extraction_timestamp=datetime.now().isoformat(),
            file_size_bytes=file_size,
            ocr_used=ocr_used,
            quality_score=quality_score,
            paragraph_count=paragraph_count,
            errors=errors,
            warnings=warnings
        )

        return {
            'text': text,
            'metadata': metadata
        }

    def save_extraction(self, pdf_path: Path, extraction_result: Dict) -> bool:
        """
        Save with quality indicators

        Writes the text file (when text exists) and the metadata JSON, and
        appends an entry to the failure log when there is no text.

        Returns:
            True when everything was written, including the case where only
            a failure record was written; False when writing the text or
            the metadata file failed.
        """

        metadata = extraction_result['metadata']
        text = extraction_result['text']

        base_name = pdf_path.stem
        year = metadata.year

        # Save text
        if text:
            text_file = self.text_dir / f"{year}_{base_name}.txt"
            try:
                with open(text_file, 'w', encoding='utf-8') as f:
                    # Add quality header
                    f.write(f"{'='*70}\n")
                    f.write(f"File: {metadata.filename}\n")
                    f.write(f"Extraction: {metadata.extraction_method}\n")
                    f.write(f"Quality: {metadata.quality_score:.2f}\n")
                    f.write(f"Paragraphs: {metadata.paragraph_count}\n")

                    if metadata.ocr_used:
                        f.write("⚠️ OCR USED - Verify important details\n")

                    if metadata.warnings:
                        f.write(f"Warnings: {', '.join(metadata.warnings[:3])}\n")

                    f.write(f"{'='*70}\n\n")
                    f.write(text)

            except Exception as e:
                logger.error(f"Failed to save text: {e}")
                return False

        # Save metadata
        metadata_file = self.metadata_dir / f"{year}_{base_name}.json"
        try:
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(asdict(metadata), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")
            return False

        # Log failures
        if not text:
            failed_log = self.failed_dir / "failed_extractions.jsonl"
            with open(failed_log, 'a', encoding='utf-8') as f:
                log_entry = {
                    'timestamp': datetime.now().isoformat(),
                    'file': str(pdf_path),
                    'errors': metadata.errors
                }
                f.write(json.dumps(log_entry) + '\n')

        return True

    def process_pdf(self, pdf_path: Path) -> bool:
        """
        Process single PDF

        Returns:
            True only when text was extracted and all output was saved.
            False when no extraction method produced text, when saving
            failed, or when an unexpected error occurred.
        """
        try:
            result = self.extract_pdf(pdf_path)
            saved = self.save_extraction(pdf_path, result)
        except Exception as e:
            logger.error(f"Unexpected error: {pdf_path.name}: {e}")
            return False

        # save_extraction also returns True after merely recording a failed
        # extraction; that must not be reported as a successful extraction.
        return saved and bool(result['text'])
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
if __name__ == "__main__":
    # Manual smoke test: run the extractor on one known sample PDF and
    # print the metadata it produced.
    banner = "="*70
    print(banner)
    print("Testing Enhanced PDF Extractor")
    print(banner)

    sample_output = Path("data/processed/extracted")
    extractor = LegalJudgmentExtractor(output_dir=sample_output, enable_ocr=False)

    test_pdf = Path("data/raw/2025/A_John_Kennedy_vs_The_State_Of_Tamil_Nadu_on_24_March_2025_1.PDF")

    if not test_pdf.exists():
        print("Test PDF not found")
    else:
        print(f"\nTesting: {test_pdf.name}")
        success = extractor.process_pdf(test_pdf)
        print(f"\n{'✓' if success else '✗'} Extraction {'successful' if success else 'failed'}")

        # Show metadata (the sample lives in the 2025 folder, hence the prefix)
        metadata_file = Path("data/processed/extracted/metadata") / f"2025_{test_pdf.stem}.json"
        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)

            print(f"\nMethod: {metadata['extraction_method']}")
            print(f"Quality: {metadata['quality_score']:.2f}")
            print(f"Paragraphs: {metadata['paragraph_count']}")
            print(f"Text length: {metadata['text_length']:,} chars")
|
src/indexing/build_faiss_index.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NyayLens – FAISS Index Builder (Merged Embeddings)
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import faiss
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_faiss_index(embeddings_dir=None, output_dir=None, hnsw_m=32,
                      ef_construction=40, ef_search=64):
    """Build an HNSW FAISS index from the merged paragraph embeddings.

    Reads ``paragraph_embeddings.npy`` and ``paragraph_ids.json`` from
    *embeddings_dir*, L2-normalizes the vectors (so inner product behaves as
    cosine similarity), adds them to an ``IndexHNSWFlat`` and writes
    ``faiss_index.bin`` and ``faiss_metadata.json`` into *output_dir*.

    Args:
        embeddings_dir: Directory holding the embedding artifacts.
            Defaults to ``data/processed/embeddings``.
        output_dir: Directory the index and its metadata are written to.
            Defaults to ``data/processed/faiss``.
        hnsw_m: Number of connections per HNSW layer.
        ef_construction: Search depth used while building the graph.
        ef_search: Search depth stored on the index for query time.

    Raises:
        ValueError: If the embedding array is not 2-D or its row count does
            not match the number of paragraph IDs.
    """
    print("=" * 70)
    print("NyayLens – Building FAISS Index")
    print("=" * 70)

    embeddings_dir = Path(
        embeddings_dir if embeddings_dir is not None else "data/processed/embeddings"
    )
    output_dir = Path(
        output_dir if output_dir is not None else "data/processed/faiss"
    )

    embeddings_file = embeddings_dir / "paragraph_embeddings.npy"
    ids_file = embeddings_dir / "paragraph_ids.json"
    meta_file = embeddings_dir / "embedding_metadata.json"

    # --- Load data ---
    print("\nLoading embeddings...")
    embeddings = np.load(embeddings_file)

    print("Loading paragraph IDs...")
    with open(ids_file, "r", encoding="utf-8") as f:
        paragraph_ids = json.load(f)

    # --- Safety checks ---
    # Explicit exceptions instead of `assert`: assertions are stripped when
    # Python runs with -O, which would let a misaligned index be written.
    if embeddings.ndim != 2:
        raise ValueError(
            f"Expected a 2-D embedding matrix, got shape {embeddings.shape}"
        )
    if embeddings.shape[0] != len(paragraph_ids):
        raise ValueError(
            f"Mismatch: {embeddings.shape[0]} embeddings vs "
            f"{len(paragraph_ids)} paragraph IDs"
        )

    print(f"✓ Loaded {embeddings.shape[0]:,} vectors")
    print(f"✓ Embedding dimension: {embeddings.shape[1]}")

    # FAISS operates on C-contiguous float32 matrices; convert up front so
    # the in-place normalization below cannot fail on e.g. float64 input.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Normalize (cosine similarity) ---
    print("\nNormalizing embeddings...")
    faiss.normalize_L2(embeddings)

    # --- Build FAISS index ---
    dim = int(embeddings.shape[1])

    # HNSW gives approximate nearest-neighbour search; with unit-length
    # vectors the inner-product metric ranks by cosine similarity.
    index = faiss.IndexHNSWFlat(dim, hnsw_m, faiss.METRIC_INNER_PRODUCT)
    index.hnsw.efConstruction = ef_construction  # Depth of search during index build
    index.hnsw.efSearch = ef_search              # Depth of search during inference

    print(f"Adding {embeddings.shape[0]:,} vectors to HNSW index...")
    index.add(embeddings)

    print(f"✓ FAISS index contains {index.ntotal:,} vectors")

    # --- Save index ---
    index_file = output_dir / "faiss_index.bin"
    faiss.write_index(index, str(index_file))

    # --- Save FAISS metadata ---
    # The embedding metadata only contributes the model name, so a missing
    # file should not fail the run after the index has already been written.
    emb_meta = {}
    if meta_file.exists():
        with open(meta_file, "r", encoding="utf-8") as f:
            emb_meta = json.load(f)

    faiss_meta = {
        "index_type": "IndexHNSWFlat",
        "metric": "cosine_similarity",
        "dimension": dim,
        "total_vectors": int(index.ntotal),
        "embedding_model": emb_meta.get("model_name", "unknown"),
    }

    with open(output_dir / "faiss_metadata.json", "w", encoding="utf-8") as f:
        json.dump(faiss_meta, f, indent=2)

    print(f"\n✓ FAISS index saved to: {index_file}")
    print("=" * 70)
    print("✓ FAISS Index Build Complete")
    print("=" * 70)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
build_faiss_index()
|
src/indexing/create_embeddings.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate embeddings in CHUNKS - GPU-safe with resume capability
|
| 3 |
+
Processes 20K paragraphs at a time with cooling breaks
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
import json
|
| 8 |
+
import numpy as np
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import torch
|
| 12 |
+
import time
|
| 13 |
+
|
| 14 |
+
class ChunkedEmbeddingGenerator:
    """Generate paragraph embeddings in fixed-size chunks with resume support.

    Each chunk's embeddings and paragraph IDs are written to
    ``<output_dir>/chunks`` as soon as the chunk finishes, so an interrupted
    run can continue from the first chunk that is not fully on disk.
    """

    def __init__(self, model_name: str = "BAAI/bge-base-en-v1.5", chunk_size: int = 20000):
        """Load the sentence-transformer model onto the GPU if one is available.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            chunk_size: Number of paragraphs processed (and saved) per chunk.

        Raises:
            ValueError: If *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        print(f"Loading model: {model_name}")

        # Check CUDA availability
        if torch.cuda.is_available():
            print(f"✓ CUDA available! GPU: {torch.cuda.get_device_name(0)}")
            print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
            self.device = 'cuda'
        else:
            print("⚠️  CUDA not available, using CPU")
            self.device = 'cpu'

        # Remembered so the saved metadata names the model actually used.
        self.model_name = model_name
        self.model = SentenceTransformer(model_name)
        self.model.to(self.device)
        self.chunk_size = chunk_size
        print(f"✓ Model loaded on: {self.device}")
        print(f"✓ Chunk size: {chunk_size:,} paragraphs")

    def load_paragraphs(self, index_file: Path):
        """Load all paragraph records from a JSONL index file.

        Blank lines are skipped. Returns a list with one parsed JSON object
        per non-blank line.
        """
        print("\nLoading paragraphs from index...")
        paragraphs = []

        # First pass only counts lines so the progress bar has a total.
        with open(index_file, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for _ in f)

        with open(index_file, 'r', encoding='utf-8') as f:
            for line in tqdm(f, total=total_lines, desc="Loading index"):
                if not line.strip():
                    continue
                paragraphs.append(json.loads(line))

        print(f"✓ Loaded {len(paragraphs):,} paragraphs")
        return paragraphs

    def process_chunk(self, texts, batch_size=64):
        """Encode one chunk of texts and return the stacked embeddings.

        Args:
            texts: Non-empty list of strings to embed.
            batch_size: Number of texts passed to the model per call.

        Returns:
            A NumPy array with one normalized embedding row per text.

        Raises:
            ValueError: If *texts* is empty.
        """
        if not texts:
            raise ValueError("texts must not be empty")

        embeddings_list = []

        with tqdm(total=len(texts), desc="Processing chunk", unit="para") as pbar:
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]

                # Generate embeddings
                batch_embeddings = self.model.encode(
                    batch,
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                    show_progress_bar=False,
                    device=self.device
                )

                embeddings_list.append(batch_embeddings)
                pbar.update(len(batch))

                # Clear GPU cache every 10 batches
                if self.device == 'cuda' and i % (batch_size * 10) == 0:
                    torch.cuda.empty_cache()

        return np.vstack(embeddings_list)

    @staticmethod
    def _chunk_paths(chunks_dir: Path, chunk_idx: int):
        """Return the (embeddings, ids) file paths for one chunk index."""
        stem = f"chunk_{chunk_idx:03d}"
        return chunks_dir / f"{stem}.npy", chunks_dir / f"{stem}_ids.json"

    @classmethod
    def _count_completed_chunks(cls, chunks_dir: Path, num_chunks: int) -> int:
        """Count leading chunks that have both their embeddings and ids files."""
        completed = 0
        while completed < num_chunks:
            chunk_file, ids_file = cls._chunk_paths(chunks_dir, completed)
            if not (chunk_file.exists() and ids_file.exists()):
                break
            completed += 1
        return completed

    def generate_embeddings_chunked(self, index_file: Path, output_dir: Path, batch_size: int = 64):
        """Embed every paragraph of *index_file*, chunk by chunk, with resume.

        Writes per-chunk files under ``<output_dir>/chunks`` and, once all
        chunks exist, the merged ``paragraph_embeddings.npy``,
        ``paragraph_ids.json`` and ``embedding_metadata.json`` in *output_dir*.
        If finished chunks are already on disk the user is asked (via
        ``input``) whether to resume or start over.

        Returns:
            Tuple ``(final_embeddings, all_ids)``.

        Raises:
            ValueError: If the index is empty, or the merged chunks do not
                line up with the loaded paragraphs (e.g. stale chunks left
                over from a run with a different chunk size or index).
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        chunks_dir = output_dir / "chunks"
        chunks_dir.mkdir(exist_ok=True)

        # Load paragraphs
        paragraphs = self.load_paragraphs(index_file)
        total_paras = len(paragraphs)
        if total_paras == 0:
            raise ValueError(f"No paragraphs found in {index_file}")

        # Calculate chunks
        num_chunks = (total_paras + self.chunk_size - 1) // self.chunk_size
        print(f"\n✓ Will process in {num_chunks} chunks of {self.chunk_size:,} paragraphs each")

        # Check for existing chunks. Only chunks with BOTH files count, so a
        # run interrupted between the two writes is redone, not trusted.
        completed_chunks = self._count_completed_chunks(chunks_dir, num_chunks)

        if completed_chunks > 0:
            print(f"✓ Found {completed_chunks} existing chunks")
            response = input(f"Resume from chunk {completed_chunks + 1}? (yes/no): ")
            if response.strip().lower() not in ('yes', 'y'):
                print("Starting fresh...")
                # Remove ids files and temp files too, not just the .npy files.
                for stale in chunks_dir.glob("chunk_*"):
                    stale.unlink()
                completed_chunks = 0

        # Process chunks
        start_time = time.time()
        processed_this_run = 0

        for chunk_idx in range(completed_chunks, num_chunks):
            print(f"\n{'='*70}")
            print(f"CHUNK {chunk_idx + 1}/{num_chunks}")
            print(f"{'='*70}")

            # Get chunk data
            start_idx = chunk_idx * self.chunk_size
            end_idx = min(start_idx + self.chunk_size, total_paras)
            chunk_paras = paragraphs[start_idx:end_idx]

            print(f"Processing paragraphs {start_idx:,} to {end_idx:,}")

            # Extract texts
            texts = [p['text'] for p in chunk_paras]
            ids = [p['id'] for p in chunk_paras]

            # Process chunk
            chunk_start = time.time()
            embeddings = self.process_chunk(texts, batch_size)
            chunk_time = time.time() - chunk_start
            processed_this_run += len(texts)

            print(f"✓ Chunk completed in {chunk_time/60:.1f} minutes")
            print(f"  Speed: {len(texts)/chunk_time:.1f} para/s")

            # Save chunk. The ids go first and the embeddings are written to
            # a temp file then renamed, so a visible .npy is never half-written.
            chunk_file, ids_file = self._chunk_paths(chunks_dir, chunk_idx)

            with open(ids_file, 'w') as f:
                json.dump(ids, f)

            tmp_file = chunk_file.with_name(chunk_file.name + ".tmp")
            with open(tmp_file, 'wb') as f:
                np.save(f, embeddings)
            tmp_file.replace(chunk_file)

            print(f"✓ Saved to: {chunk_file.name}")

            # Clear GPU and add cooling break
            if self.device == 'cuda':
                torch.cuda.empty_cache()
                if chunk_idx < num_chunks - 1:  # Not last chunk
                    print("\n⏸️  Cooling break: 10 seconds...")
                    time.sleep(10)

        # Combine all chunks
        print(f"\n{'='*70}")
        print("COMBINING CHUNKS...")
        print(f"{'='*70}")

        all_embeddings = []
        all_ids = []

        for chunk_idx in tqdm(range(num_chunks), desc="Loading chunks"):
            chunk_file, ids_file = self._chunk_paths(chunks_dir, chunk_idx)

            embeddings = np.load(chunk_file)
            with open(ids_file, 'r') as f:
                ids = json.load(f)

            all_embeddings.append(embeddings)
            all_ids.extend(ids)

        final_embeddings = np.vstack(all_embeddings)

        # Resumed chunks may come from an earlier run with another chunk size
        # or index; refuse to write a merged file that does not line up.
        if final_embeddings.shape[0] != total_paras or len(all_ids) != total_paras:
            raise ValueError(
                f"Combined chunks hold {final_embeddings.shape[0]} embeddings and "
                f"{len(all_ids)} ids but the index has {total_paras} paragraphs; "
                f"delete {chunks_dir} and start fresh"
            )

        print(f"✓ Combined shape: {final_embeddings.shape}")

        # Save final files
        print("\nSaving final files...")

        embeddings_file = output_dir / "paragraph_embeddings.npy"
        np.save(embeddings_file, final_embeddings)
        print(f"✓ Saved: {embeddings_file}")

        ids_file = output_dir / "paragraph_ids.json"
        with open(ids_file, 'w') as f:
            json.dump(all_ids, f)
        print(f"✓ Saved: {ids_file}")

        # Save metadata
        total_time = time.time() - start_time
        metadata = {
            'model_name': self.model_name,
            'embedding_dim': int(final_embeddings.shape[1]),
            'total_paragraphs': len(all_ids),
            'device_used': self.device,
            'batch_size': batch_size,
            'chunk_size': self.chunk_size,
            'num_chunks': num_chunks,
            'total_time_minutes': total_time / 60
        }

        with open(output_dir / "embedding_metadata.json", 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"\n✓ Total time: {total_time/60:.1f} minutes")
        # Speed covers only paragraphs embedded in this run; chunks reused
        # from disk would otherwise inflate the figure.
        if processed_this_run and total_time > 0:
            print(f"  Average speed: {processed_this_run/total_time:.1f} para/s")

        return final_embeddings, all_ids
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
print("="*70)
|
| 214 |
+
print("NyayLens - Chunked Embedding Generation (GPU-Safe)")
|
| 215 |
+
print("="*70)
|
| 216 |
+
print()
|
| 217 |
+
|
| 218 |
+
generator = ChunkedEmbeddingGenerator(
|
| 219 |
+
model_name="BAAI/bge-base-en-v1.5",
|
| 220 |
+
chunk_size=20000 # Process 20K at a time
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
embeddings, ids = generator.generate_embeddings_chunked(
|
| 224 |
+
index_file=Path("data/processed/indexed/paragraph_index.jsonl"),
|
| 225 |
+
output_dir=Path("data/processed/embeddings"),
|
| 226 |
+
batch_size=64 # Reduced for safety
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
print()
|
| 230 |
+
print("="*70)
|
| 231 |
+
print(f"✓ COMPLETE! Generated {len(embeddings):,} embeddings")
|
| 232 |
+
print("="*70)
|
src/indexing/create_sqlite_index.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """Create SQLite index for fast paragraph lookup"""
|
| 2 |
+
|
| 3 |
+
# import sqlite3
|
| 4 |
+
# import json
|
| 5 |
+
# from pathlib import Path
|
| 6 |
+
# from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# def create_sqlite_index():
|
| 9 |
+
# print("Creating SQLite index...")
|
| 10 |
+
|
| 11 |
+
# db_path = Path("data/processed/indexed/paragraphs.db")
|
| 12 |
+
# db_path.parent.mkdir(parents=True, exist_ok=True)
|
| 13 |
+
|
| 14 |
+
# # Create database
|
| 15 |
+
# conn = sqlite3.connect(db_path)
|
| 16 |
+
# cursor = conn.cursor()
|
| 17 |
+
|
| 18 |
+
# # Create table
|
| 19 |
+
# cursor.execute("""
|
| 20 |
+
# CREATE TABLE IF NOT EXISTS paragraphs (
|
| 21 |
+
# id TEXT PRIMARY KEY,
|
| 22 |
+
# judgment_id TEXT,
|
| 23 |
+
# page_no INTEGER,
|
| 24 |
+
# text TEXT,
|
| 25 |
+
# char_count INTEGER,
|
| 26 |
+
# word_count INTEGER
|
| 27 |
+
# )
|
| 28 |
+
# """)
|
| 29 |
+
|
| 30 |
+
# cursor.execute("CREATE INDEX IF NOT EXISTS idx_judgment ON paragraphs(judgment_id)")
|
| 31 |
+
|
| 32 |
+
# # Load data
|
| 33 |
+
# index_file = Path("data/processed/indexed/paragraph_index.jsonl")
|
| 34 |
+
|
| 35 |
+
# with open(index_file, 'r', encoding='utf-8') as f:
|
| 36 |
+
# total = sum(1 for _ in f)
|
| 37 |
+
|
| 38 |
+
# with open(index_file, 'r', encoding='utf-8') as f:
|
| 39 |
+
# batch = []
|
| 40 |
+
# for line in tqdm(f, total=total, desc="Inserting"):
|
| 41 |
+
# p = json.loads(line)
|
| 42 |
+
# batch.append((
|
| 43 |
+
# p['id'], p['judgment_id'], p['page_no'],
|
| 44 |
+
# p['text'], p['char_count'], p['word_count']
|
| 45 |
+
# ))
|
| 46 |
+
|
| 47 |
+
# if len(batch) >= 1000:
|
| 48 |
+
# cursor.executemany(
|
| 49 |
+
# "INSERT OR REPLACE INTO paragraphs VALUES (?,?,?,?,?,?)",
|
| 50 |
+
# batch
|
| 51 |
+
# )
|
| 52 |
+
# batch = []
|
| 53 |
+
|
| 54 |
+
# if batch:
|
| 55 |
+
# cursor.executemany(
|
| 56 |
+
# "INSERT OR REPLACE INTO paragraphs VALUES (?,?,?,?,?,?)",
|
| 57 |
+
# batch
|
| 58 |
+
# )
|
| 59 |
+
|
| 60 |
+
# conn.commit()
|
| 61 |
+
# conn.close()
|
| 62 |
+
|
| 63 |
+
# print(f"✓ SQLite index created: {db_path}")
|
| 64 |
+
|
| 65 |
+
# if __name__ == "__main__":
|
| 66 |
+
# create_sqlite_index()
|
| 67 |
+
"""
|
| 68 |
+
Create SQLite index with section annotations
|
| 69 |
+
Source: paragraph_index_with_sections.jsonl
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
import sqlite3
|
| 73 |
+
import json
|
| 74 |
+
from pathlib import Path
|
| 75 |
+
from tqdm import tqdm
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
INPUT_INDEX = Path("data/processed/indexed/paragraph_index_with_sections.jsonl")
|
| 79 |
+
DB_PATH = Path("data/processed/indexed/paragraphs.db")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_sqlite_index(input_index=None, db_path=None, batch_size=1000):
    """Rebuild the SQLite paragraph store and its FTS5 full-text index.

    Drops and recreates the ``paragraphs`` table and the ``paragraphs_fts``
    virtual table, then loads every record of the section-annotated JSONL
    index into them.

    Args:
        input_index: JSONL file to load. Defaults to the module-level
            ``INPUT_INDEX``.
        db_path: SQLite database file to (re)build. Defaults to the
            module-level ``DB_PATH``.
        batch_size: Number of rows passed to each ``executemany`` call.

    Records whose ``id`` already exists are skipped rather than aborting the
    whole build. The connection is always closed; if loading fails, the
    uncommitted inserts are discarded when it closes.
    """
    input_index = Path(input_index) if input_index is not None else INPUT_INDEX
    db_path = Path(db_path) if db_path is not None else DB_PATH

    print("=" * 70)
    print("NyayLens – Creating SQLite Index (with Sections)")
    print("=" * 70)

    db_path.parent.mkdir(parents=True, exist_ok=True)

    # OR IGNORE: `id` is the PRIMARY KEY, so a repeated id must not raise
    # IntegrityError and abort the whole build.
    insert_sql = """
        INSERT OR IGNORE INTO paragraphs
        (id, judgment_id, page_no, text, char_count, word_count, section, section_conf)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """

    # Connect to SQLite
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Drop existing table (derived data → safe to rebuild)
        cursor.execute("DROP TABLE IF EXISTS paragraphs")

        # Create table
        cursor.execute("""
            CREATE TABLE paragraphs (
                id TEXT PRIMARY KEY,
                judgment_id TEXT,
                page_no INTEGER,
                text TEXT,
                char_count INTEGER,
                word_count INTEGER,
                section TEXT,
                section_conf REAL
            )
        """)

        # Create FTS5 virtual table for fast full-text search (BM25)
        cursor.execute("DROP TABLE IF EXISTS paragraphs_fts")
        cursor.execute("""
            CREATE VIRTUAL TABLE paragraphs_fts USING fts5(
                id UNINDEXED,
                text,
                tokenize='porter unicode61'
            )
        """)

        conn.commit()

        # Count total records (only used as the progress-bar total)
        with open(input_index, "r", encoding="utf-8") as f:
            total = sum(1 for _ in f)

        print(f"✓ Inserting {total:,} paragraphs")

        # Insert data in batches
        batch = []
        with open(input_index, "r", encoding="utf-8") as f:
            for line in tqdm(f, total=total, desc="Inserting"):
                if not line.strip():
                    continue  # tolerate blank lines in the JSONL file
                p = json.loads(line)

                batch.append((
                    p["id"],
                    p["judgment_id"],
                    p.get("page_no", -1),
                    p["text"],
                    p.get("char_count", len(p["text"])),
                    p.get("word_count", len(p["text"].split())),
                    p.get("section", "unknown"),
                    p.get("section_conf", 0.0),
                ))

                if len(batch) >= batch_size:
                    cursor.executemany(insert_sql, batch)
                    batch.clear()

        if batch:
            cursor.executemany(insert_sql, batch)

        # Fill the FTS table from the rows actually stored, reusing their
        # rowids, so both tables hold the same set of paragraphs.
        cursor.execute(
            "INSERT INTO paragraphs_fts (rowid, id, text) "
            "SELECT rowid, id, text FROM paragraphs"
        )

        # Indexes for fast lookup (built after the bulk load so they are not
        # updated once per inserted row)
        cursor.execute("CREATE INDEX idx_judgment_id ON paragraphs(judgment_id)")
        cursor.execute("CREATE INDEX idx_section ON paragraphs(section)")
        cursor.execute("CREATE INDEX idx_judgment_section ON paragraphs(judgment_id, section)")

        stored = cursor.execute("SELECT COUNT(*) FROM paragraphs").fetchone()[0]

        conn.commit()
    finally:
        conn.close()

    print("\n✓ SQLite index created successfully")
    print(f"✓ Paragraphs stored: {stored:,}")
    print(f"✓ Database path: {db_path}")
    print("=" * 70)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
create_sqlite_index()
|
src/indexing/paragraph_indexer.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Paragraph-level indexer for NyayLens RAG
|
| 3 |
+
- Page-aware
|
| 4 |
+
- Content-stable paragraph IDs
|
| 5 |
+
- Legal-aware filtering
|
| 6 |
+
- Streaming JSONL output (memory safe)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import hashlib
|
| 11 |
+
import re
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, Iterable
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ParagraphIndexer:
    """Index legal judgments at paragraph level with stable IDs and metadata"""

    # Legal keywords worth preserving even in short paragraphs
    LEGAL_KEYWORDS = {
        'held', 'order', 'appeal', 'writ', 'judgment', 'decree',
        'petition', 'application', 'allowed', 'dismissed', 'granted',
        'rejected', 'disposed', 'quashed', 'set aside',
        'affirmed', 'reversed', 'remanded', 'bail', 'custody',
        'interim', 'stay', 'injunction', 'no costs'
    }

    # One pre-compiled whole-word alternation over LEGAL_KEYWORDS, built once
    # instead of compiling a pattern per keyword for every paragraph.
    _LEGAL_KEYWORD_PATTERN = re.compile(
        r"\b(?:" + "|".join(re.escape(kw) for kw in sorted(LEGAL_KEYWORDS)) + r")\b"
    )

    PAGE_MARKER_PATTERN = re.compile(r'<<<PAGE:(\d+)>>>')

    def __init__(self, texts_dir: Path, output_dir: Path):
        """Remember the input/output locations and create the output directory."""
        self.texts_dir = Path(texts_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.index_file = self.output_dir / "paragraph_index.jsonl"
        self.stats_file = self.output_dir / "index_stats.json"

    @staticmethod
    def _contains_legal_keyword(text: str) -> bool:
        """Return True if *text* contains any legal keyword as a whole word/phrase."""
        return ParagraphIndexer._LEGAL_KEYWORD_PATTERN.search(text.lower()) is not None

    @staticmethod
    def _stable_paragraph_id(judgment_id: str, page_no: int, text: str) -> str:
        """
        Content-stable ID:
        - same paragraph text => same ID
        - survives re-indexing

        Built from the judgment id, the page number ("unk" when *page_no* is
        None) and the first 16 hex digits of the SHA-1 of the text.
        """
        h = hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
        page_str = page_no if page_no is not None else "unk"
        return f"{judgment_id}_p{page_str}_{h}"

    def _strip_header(self, content: str) -> str:
        """
        Remove extractor quality header safely

        The header is taken to be everything up to the second line of 70
        "=" characters; if two such separators are not present the content
        is returned unchanged (apart from stripping).
        """
        sep = "=" * 70
        if sep in content:
            parts = content.split(sep, 2)
            if len(parts) == 3:
                return parts[2].strip()
        return content.strip()

    def _iter_paragraphs(self, content: str) -> Iterable[tuple]:
        """
        Yield (page_no, paragraph_text) records.

        Paragraphs are runs of non-blank lines; a page marker line both ends
        the current paragraph and sets the page number for what follows.
        page_no is None until the first marker is seen.
        """
        current_page = None
        buffer = []

        for line in content.splitlines():
            page_match = self.PAGE_MARKER_PATTERN.match(line.strip())
            if page_match:
                # Flush buffer before page change
                if buffer:
                    yield current_page, "\n".join(buffer).strip()
                    buffer = []
                current_page = int(page_match.group(1))
                continue

            if not line.strip():
                if buffer:
                    yield current_page, "\n".join(buffer).strip()
                    buffer = []
                continue

            buffer.append(line)

        if buffer:
            yield current_page, "\n".join(buffer).strip()

    def index_judgment(self, text_file: Path, writer) -> int:
        """
        Index a single judgment file.
        Returns number of paragraphs indexed.

        Records are written to *writer* (anything with a ``write`` method)
        in one call after the whole file has been processed, so a failure
        part-way through leaves no partial output for that judgment.
        Paragraphs that would repeat an ID already produced for this judgment
        (identical text on the same page) are skipped so IDs stay unique.
        """
        text_file = Path(text_file)
        with open(text_file, "r", encoding="utf-8") as f:
            content = self._strip_header(f.read())

        judgment_id = text_file.stem
        seen_ids = set()
        lines = []

        for page_no, para in self._iter_paragraphs(content):
            if not para:
                continue

            # Keep substantial OR legally important short paragraphs
            if len(para) < 50 and not self._contains_legal_keyword(para):
                continue

            page = page_no if page_no is not None else -1
            para_id = self._stable_paragraph_id(judgment_id, page, para)
            if para_id in seen_ids:
                continue
            seen_ids.add(para_id)

            record = {
                "id": para_id,
                "judgment_id": judgment_id,
                "page_no": page,
                "text": para,
                "char_count": len(para),
                "word_count": len(para.split())
            }
            lines.append(json.dumps(record, ensure_ascii=False) + "\n")

        if lines:
            writer.write("".join(lines))

        return len(lines)

    def build_full_index(self):
        """
        Index every ``*.txt`` file in ``texts_dir`` into one JSONL file.

        A file that fails is reported and skipped (best effort); failures are
        listed in the returned/saved stats under "failed_judgments".
        """
        text_files = sorted(self.texts_dir.glob("*.txt"))
        print(f"Indexing {len(text_files):,} judgments...")

        total_paragraphs = 0
        failed = []

        with open(self.index_file, "w", encoding="utf-8") as writer:
            for text_file in tqdm(text_files, desc="Indexing"):
                try:
                    total_paragraphs += self.index_judgment(text_file, writer)
                except Exception as e:
                    failed.append(text_file.name)
                    print(f"❌ Failed indexing {text_file.name}: {e}")

        stats = {
            "total_judgments": len(text_files),
            "total_paragraphs": total_paragraphs,
            "avg_paragraphs_per_judgment":
                total_paragraphs / len(text_files) if text_files else 0,
            "failed_judgments": failed
        }

        with open(self.stats_file, "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2)

        print("\n✓ Paragraph indexing complete")
        print(f"  Total paragraphs: {total_paragraphs:,}")
        if failed:
            print(f"  Failed judgments: {len(failed):,}")
        print(f"  Output: {self.index_file}")

        return stats
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
indexer = ParagraphIndexer(
|
| 162 |
+
texts_dir=Path("data/processed/extracted/texts"),
|
| 163 |
+
output_dir=Path("data/processed/indexed")
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
stats = indexer.build_full_index()
|
| 167 |
+
print(f"\nAverage paragraphs per judgment: {stats['avg_paragraphs_per_judgment']:.1f}")
|
src/pipeline.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from extraction.pdf_extractor import LegalJudgmentExtractor
|
| 7 |
+
from segmentation.judgement_segmenter import JudgmentSegmenter
|
| 8 |
+
|
| 9 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class NyayLensPipeline:
    """Unified ingestion pipeline for NyayLens"""

    def __init__(self, raw_dir="data/raw", processed_dir="data/processed"):
        """Set up directories and create the extractor and segmenter."""
        self.raw_dir = Path(raw_dir)
        self.processed_dir = Path(processed_dir)

        logger.info("Initializing NyayLens Pipeline components...")
        self.extractor = LegalJudgmentExtractor(output_dir=self.processed_dir / "extracted")
        self.segmenter = JudgmentSegmenter()

    def ingest_pdf(self, pdf_path: Path):
        """Process a single PDF end-to-end.

        Extracts the text, saves the extraction, and segments the text.

        Returns:
            True on success, False when extraction produced no text.
        """
        pdf_path = Path(pdf_path)  # also accept plain string paths
        logger.info("--- Starting Ingestion: %s ---", pdf_path.name)
        start_time = time.time()

        # 1. Extraction
        logger.info("Step 1: Extracting text...")
        extraction_result = self.extractor.extract_pdf(pdf_path)
        if not extraction_result['text']:
            logger.error("Extraction failed for %s", pdf_path.name)
            return False

        # Optional: Save extraction
        self.extractor.save_extraction(pdf_path, extraction_result)

        # 2. Segmentation
        logger.info("Step 2: Segmenting judgment...")
        # Simple paragraph split for segmentation
        paragraphs = extraction_result['text'].split('\n\n')
        sections = self.segmenter.segment(paragraphs)

        logger.info("Found %d distinct sections.", len(sections))

        # 3. Next steps would hook into `create_embeddings.py` and `create_sqlite_index.py`
        logger.info("Step 3: Ready for Embeddings & Indexing (Batch process recommended)")

        elapsed = time.time() - start_time
        logger.info("--- Successfully processed %s in %.2fs ---", pdf_path.name, elapsed)
        return True

    def process_directory(self, limit: int = None):
        """Process all PDFs found (recursively) under the raw directory.

        Args:
            limit: Maximum number of PDFs to process; None means no limit.

        Returns:
            Number of PDFs ingested successfully.
        """
        # Match the suffix case-insensitively in a single walk: separate
        # "*.pdf" and "*.PDF" globs return every file twice on
        # case-insensitive filesystems and miss mixed-case suffixes.
        # Sorting makes the order (and what `limit` selects) reproducible.
        pdfs = sorted(
            p for p in self.raw_dir.rglob("*")
            if p.is_file() and p.suffix.lower() == ".pdf"
        )

        # `is not None` so that limit=0 means "process nothing", not "no limit".
        if limit is not None:
            pdfs = pdfs[:max(limit, 0)]

        logger.info("Found %d PDFs to process.", len(pdfs))

        success = 0
        for pdf in pdfs:
            try:
                if self.ingest_pdf(pdf):
                    success += 1
            except Exception:
                # One unreadable PDF must not abort the rest of the batch;
                # log the traceback and carry on with the next file.
                logger.exception("Unexpected error while ingesting %s", pdf.name)

        logger.info("Batch completed. %d/%d successful.", success, len(pdfs))
        return success
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
parser = argparse.ArgumentParser(description="NyayLens Unified Ingestion Pipeline")
|
| 71 |
+
parser.add_argument("--pdf", type=str, help="Path to a single PDF to ingest")
|
| 72 |
+
parser.add_argument("--batch", action="store_true", help="Process all PDFs in data/raw")
|
| 73 |
+
parser.add_argument("--limit", type=int, default=None, help="Limit number of PDFs in batch")
|
| 74 |
+
|
| 75 |
+
args = parser.parse_args()
|
| 76 |
+
|
| 77 |
+
pipeline = NyayLensPipeline()
|
| 78 |
+
|
| 79 |
+
if args.pdf:
|
| 80 |
+
pipeline.ingest_pdf(Path(args.pdf))
|
| 81 |
+
elif args.batch:
|
| 82 |
+
pipeline.process_directory(limit=args.limit)
|
| 83 |
+
else:
|
| 84 |
+
logger.warning("Please specify --pdf <path> or --batch. Use --help for options.")
|
src/qa/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
src/qa/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (462 Bytes). View file
|
|
|
src/qa/dataset.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
|
| 3 |
+
MAX_LENGTH = 384
|
| 4 |
+
DOC_STRIDE = 128
|
| 5 |
+
|
| 6 |
+
def _find_answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Map a character span in the context onto token indices of one feature.

    Only tokens whose sequence id is 1 (the context) are inspected. Question
    tokens carry offsets relative to the *question* string, so comparing them
    against context character positions can produce false matches.

    Returns:
        ``(token_start, token_end)`` or ``None`` when the answer is not fully
        contained in this feature's window.
    """
    token_start = token_end = None
    for idx, (seq_id, (start, end)) in enumerate(zip(sequence_ids, offsets)):
        if seq_id != 1:
            continue
        if start <= start_char < end:
            token_start = idx
        if start < end_char <= end:
            token_end = idx
            break

    if token_start is None or token_end is None:
        return None
    return token_start, token_end


def load_and_prepare_dataset(tokenizer):
    """Download SQuAD and tokenize it into extractive-QA training features.

    Long contexts are split into overlapping windows of ``MAX_LENGTH`` tokens
    (overlap ``DOC_STRIDE``). Each window gets ``start_positions`` /
    ``end_positions`` labels; windows that do not contain the whole answer
    are labelled with the CLS token index.

    Args:
        tokenizer: A fast tokenizer (offset mappings and ``sequence_ids`` are
            required) that defines ``cls_token_id``.

    Returns:
        The mapped dataset with the original columns removed.
    """
    dataset = load_dataset("squad")  # auto-download

    def preprocess(examples):
        questions = [q.strip() for q in examples["question"]]
        contexts = examples["context"]

        tokenized = tokenizer(
            questions,
            contexts,
            truncation="only_second",
            max_length=MAX_LENGTH,
            stride=DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # One example can yield several features; this maps each feature
        # back to the example it was cut from.
        sample_mapping = tokenized.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized.pop("offset_mapping")

        start_positions = []
        end_positions = []

        for i, offsets in enumerate(offset_mapping):
            cls_index = tokenized["input_ids"][i].index(tokenizer.cls_token_id)
            answer = examples["answers"][sample_mapping[i]]

            span = None
            if answer["answer_start"]:
                # Only the first reference answer is used for supervision.
                start_char = answer["answer_start"][0]
                end_char = start_char + len(answer["text"][0])
                span = _find_answer_token_span(
                    offsets, tokenized.sequence_ids(i), start_char, end_char
                )

            if span is None:
                # No answer, or answer not inside this window: point at CLS.
                start_positions.append(cls_index)
                end_positions.append(cls_index)
            else:
                start_positions.append(span[0])
                end_positions.append(span[1])

        tokenized["start_positions"] = start_positions
        tokenized["end_positions"] = end_positions
        return tokenized

    tokenized_dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    return tokenized_dataset
|
src/qa/inference.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import faiss
|
| 3 |
+
import json
|
| 4 |
+
import sqlite3
|
| 5 |
+
import re
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LegalQAEngine:
    """Retrieve judgment paragraphs and extract answer spans from them.

    A question is embedded, the nearest paragraphs are fetched through the
    FAISS index and the SQLite paragraph store, and the fine-tuned QA model
    proposes answer spans inside each paragraph. Spans are filtered, scored,
    de-duplicated and the best ones returned.
    """

    # Phrases that earn a candidate answer a fixed score bonus.
    DOCTRINAL_MARKERS = ("the court", "held that", "it is clear that", "the law")
    DOCTRINAL_BOOST = 1.5
    # Span filters applied to every (start, end) candidate.
    N_BEST = 5
    MAX_SPAN_TOKENS = 80
    MIN_ANSWER_WORDS = 8

    # Patterns like:
    # "it is not correct to say, ..., that X"
    # "it cannot be said, ..., that X"
    # Compiled once instead of on every candidate answer.
    _REFUTATION_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"not correct to say.*?that\s+(.*?)(?:\.|,)",
            r"cannot be said.*?that\s+(.*?)(?:\.|,)",
        )
    )

    def __init__(self):
        """Load the QA model, the embedder, the FAISS index and the paragraph DB."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # ---- Load QA model ----
        self.tokenizer = AutoTokenizer.from_pretrained("outputs/qa_model/final")
        self.qa_model = AutoModelForQuestionAnswering.from_pretrained(
            "outputs/qa_model/final"
        ).to(self.device)
        self.qa_model.eval()

        # ---- Load retriever ----
        self.embedder = SentenceTransformer("BAAI/bge-base-en-v1.5", device=self.device)
        self.index = faiss.read_index("data/processed/faiss/faiss_index.bin")

        with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
            self.para_ids = json.load(f)

        self.db = sqlite3.connect("data/processed/indexed/paragraphs.db")
        self.cursor = self.db.cursor()

        print("✓ Enhanced QA inference system ready")

    def close(self):
        """Close the SQLite connection opened by ``__init__``."""
        self.db.close()

    # ------------------------------------------------------------------
    # TEXT NORMALIZATION (critical for PDF artifacts)
    # ------------------------------------------------------------------
    def _normalize(self, text: str) -> str:
        """Lower-case ``text`` and collapse every whitespace run to one space."""
        text = text.lower()
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    # ------------------------------------------------------------------
    # REFUTED CLAUSE DETECTION (Article 21 FIX)
    # ------------------------------------------------------------------
    def _is_refuted_clause(self, answer_text, paragraph_text):
        """Return True when the answer lies inside a proposition the paragraph rejects.

        Both texts are normalized first; the answer is blocked if it is a
        substring of any clause captured by the refutation patterns.
        """
        para = self._normalize(paragraph_text)
        ans = self._normalize(answer_text)

        for pattern in self._REFUTATION_PATTERNS:
            for refuted_prop in pattern.findall(para):
                # If answer is part of the refuted proposition → block
                if ans in refuted_prop:
                    return True

        return False

    # ------------------------------------------------------------------
    # RETRIEVAL
    # ------------------------------------------------------------------
    def retrieve_paragraphs(self, question, top_k=8):
        """Return up to ``top_k`` paragraphs nearest to ``question``.

        Each result is a dict with ``judgment_id``, ``page_no``, ``text`` and
        ``retrieval_score``. Index hits with no matching database row are
        dropped.
        """
        q_emb = self.embedder.encode(
            [question], normalize_embeddings=True, convert_to_numpy=True
        )
        scores, indices = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # FAISS pads missing neighbours with -1; without this guard
            # para_ids[-1] would silently return the LAST paragraph.
            if idx < 0 or idx >= len(self.para_ids):
                continue

            self.cursor.execute(
                "SELECT judgment_id, page_no, text FROM paragraphs WHERE id = ?",
                (self.para_ids[idx],),
            )
            row = self.cursor.fetchone()
            if row:
                judgment_id, page_no, text = row
                results.append(
                    {
                        "judgment_id": judgment_id,
                        "page_no": page_no,
                        "text": text,
                        "retrieval_score": float(score),
                    }
                )
        return results

    # ------------------------------------------------------------------
    # ANSWERING
    # ------------------------------------------------------------------
    def answer_question(self, question, top_k=8, max_answers=2):
        """Extract the best answer spans for ``question``.

        Args:
            question: Natural-language question.
            top_k: Number of paragraphs to retrieve.
            max_answers: Maximum number of distinct answers to return.

        Returns:
            A list of dicts (``answer``, ``confidence``, ``judgment_id``,
            ``page_no``, ``paragraph``, ``retrieval_score``) sorted by
            descending confidence.
        """
        paragraphs = self.retrieve_paragraphs(question, top_k)
        candidates = []

        for para in paragraphs:
            inputs = self.tokenizer(
                question,
                para["text"],
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.qa_model(**inputs)

            start_logits = outputs.start_logits[0]
            end_logits = outputs.end_logits[0]

            token_type_ids = inputs["token_type_ids"][0].tolist()
            if 1 not in token_type_ids:
                # No context tokens at all (e.g. empty paragraph text).
                continue
            question_end = token_type_ids.index(1)

            # topk raises if k exceeds the sequence length.
            n_best = min(self.N_BEST, start_logits.shape[0])
            top_starts = torch.topk(start_logits, k=n_best).indices.tolist()
            top_ends = torch.topk(end_logits, k=n_best).indices.tolist()

            for s in top_starts:
                # ❌ Block question echo
                if s < question_end:
                    continue

                for e in top_ends:
                    if e < s or (e - s) > self.MAX_SPAN_TOKENS:
                        continue

                    answer_tokens = inputs["input_ids"][0][s : e + 1]
                    answer_text = self.tokenizer.decode(
                        answer_tokens, skip_special_tokens=True
                    ).strip()

                    if len(answer_text.split()) < self.MIN_ANSWER_WORDS:
                        continue

                    # ❌ Block refuted propositions
                    if self._is_refuted_clause(answer_text, para["text"]):
                        continue

                    score = start_logits[s].item() + end_logits[e].item()

                    # Doctrinal boost
                    lowered = answer_text.lower()
                    if any(marker in lowered for marker in self.DOCTRINAL_MARKERS):
                        score += self.DOCTRINAL_BOOST

                    candidates.append(
                        {
                            "answer": answer_text,
                            "confidence": score,
                            "judgment_id": para["judgment_id"],
                            "page_no": para["page_no"],
                            "paragraph": para["text"],
                            "retrieval_score": para["retrieval_score"],
                        }
                    )

        # ---- Deduplicate answers ----
        seen = set()
        final = []
        for c in sorted(candidates, key=lambda x: x["confidence"], reverse=True):
            key = self._normalize(c["answer"])
            if key not in seen:
                seen.add(key)
                final.append(c)

        return final[:max_answers]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ----------------------------------------------------------------------
|
| 179 |
+
# DEMO
|
| 180 |
+
# ----------------------------------------------------------------------
|
| 181 |
+
if __name__ == "__main__":
    # Demo run: answer a few fixed questions and print each answer with its
    # source, scores and the start of the supporting paragraph.
    engine = LegalQAEngine()

    demo_questions = (
        "What is the scope of Article 21?",
        "What are the conditions for granting anticipatory bail?",
        "What is the burden of proof in criminal law?",
    )
    rule = "=" * 90

    for question in demo_questions:
        print("\n" + rule)
        print(f"QUESTION: {question}")
        print(rule)

        for rank, hit in enumerate(engine.answer_question(question), start=1):
            print(f"\nANSWER {rank}:")
            print(hit["answer"])
            print(f"\nSOURCE: {hit['judgment_id']} | Page: {hit['page_no']}")
            print(f"Retrieval score: {hit['retrieval_score']:.3f}")
            print(f"Confidence score: {hit['confidence']:.2f}")
            print("\nPARAGRAPH:")
            # Only the first 700 characters of the paragraph are shown.
            print(hit["paragraph"][:700] + "...")
|
src/qa/model.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
| 2 |
+
|
| 3 |
+
MODEL_NAME = "nlpaueb/legal-bert-base-uncased"


def load_model(model_name: str = MODEL_NAME):
    """Load the tokenizer and question-answering model for ``model_name``.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
            Defaults to ``MODEL_NAME``, so existing ``load_model()`` calls
            behave exactly as before.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model
|
src/qa/monitor_training.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Monitor training progress"""
|
| 2 |
+
import json
import time
from pathlib import Path

log_file = Path("outputs/qa_model/trainer_state.json")

# Seconds to wait between two reads of the state file.
POLL_SECONDS = 10


def read_state(path):
    """Return the parsed JSON state at ``path``, or ``None`` if unavailable.

    The file may not exist yet, or may be caught half-written by the training
    process; both cases are treated as "nothing to report this round" rather
    than crashing the monitor.
    """
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def format_progress(state):
    """Format the newest ``log_history`` entry of ``state`` as one line.

    Returns ``None`` when ``state`` is not a dict or has no log entries yet
    (an empty ``log_history`` used to raise ``IndexError``).
    """
    if not isinstance(state, dict):
        return None
    history = state.get('log_history')
    if not history:
        return None

    latest = history[-1]
    return (f"Step {latest.get('step', 0):5d} | "
            f"Loss: {latest.get('loss', 0):.4f} | "
            f"Eval Loss: {latest.get('eval_loss', 'N/A')}")


def main():
    """Poll the trainer state file forever and print the latest log entry."""
    print("Monitoring training progress...\n")

    try:
        while True:
            line = format_progress(read_state(log_file))
            if line is not None:
                print(line)

            time.sleep(POLL_SECONDS)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the monitor; exit without a traceback.
        print("\nStopped monitoring.")


if __name__ == "__main__":
    main()
|
src/qa/train.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import TrainingArguments, Trainer
|
| 2 |
+
from model import load_model
|
| 3 |
+
from dataset import load_and_prepare_dataset
|
| 4 |
+
|
| 5 |
+
def main():
    """Fine-tune the QA model on the prepared dataset and save the result.

    Checkpoints go to ``outputs/qa_model``; the final model and tokenizer are
    written to ``outputs/qa_model/final``.
    """
    # Local import: torch is only needed here to probe for a CUDA device.
    import torch

    tokenizer, model = load_model()
    dataset = load_and_prepare_dataset(tokenizer)

    training_args = TrainingArguments(
        output_dir="outputs/qa_model",
        eval_strategy="steps",  # ✅ correct API
        eval_steps=1000,
        learning_rate=3e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=2,
        weight_decay=0.01,
        # Mixed precision is only valid on a GPU; a hard-coded True makes
        # TrainingArguments raise on CPU-only machines.
        fp16=torch.cuda.is_available(),
        logging_steps=500,
        # save_steps stays a multiple of eval_steps, which
        # load_best_model_at_end requires.
        save_steps=2000,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
    )

    trainer.train()

    # Save final model + tokenizer
    trainer.save_model("outputs/qa_model/final")
    tokenizer.save_pretrained("outputs/qa_model/final")


if __name__ == "__main__":
    main()
|
src/rag/__pycache__/query_engine.cpython-310.pyc
ADDED
|
Binary file (8.49 kB). View file
|
|
|
src/rag/query_engine.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAG Query Engine with LLM"""
|
| 2 |
+
|
| 3 |
+
# pyrefly: ignore [missing-import]
|
| 4 |
+
import faiss
|
| 5 |
+
import json
|
| 6 |
+
import sqlite3
|
| 7 |
+
import re
|
| 8 |
+
# pyrefly: ignore [missing-import]
|
| 9 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 10 |
+
import os
|
| 11 |
+
from groq import Groq
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
from src.summarization.ranker import ImportanceRanker
|
| 15 |
+
from src.summarization.utils import split_sentences
|
| 16 |
+
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
class QueryEngine:
    """Hybrid retrieval (FAISS + SQLite FTS5) with LLM answer generation.

    ``search`` fuses dense and keyword rankings and re-ranks them with a
    cross-encoder; ``query`` feeds the retrieved paragraphs to the Groq LLM;
    ``query_with_document`` answers from a single uploaded document.
    """

    # Number of trailing chat messages forwarded to the LLM (3 turns).
    MAX_HISTORY_MESSAGES = 6
    # Character budget for a document placed in the prompt.
    MAX_DOC_CHARS = 30000
    # Constant k of Reciprocal Rank Fusion: score = 1 / (k + rank).
    RRF_K = 60

    def __init__(self):
        """Load the index, the models and the LLM client.

        Raises:
            ValueError: if the ``GROQ_API_KEY`` environment variable is unset.
        """
        print("Loading RAG components...")

        # FAISS + SQLite + Embeddings
        self.index = faiss.read_index("data/processed/faiss/faiss_index.bin")

        with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
            self.para_ids = json.load(f)

        self.model = SentenceTransformer("BAAI/bge-base-en-v1.5")
        self.reranker = CrossEncoder("BAAI/bge-reranker-base")
        self.importance_ranker = ImportanceRanker("outputs/summarization/final")

        # LLM Setup. This belongs to __init__: placed after the `return` of
        # _get_db it could never initialise self.llm.
        api_key = os.getenv('GROQ_API_KEY')
        if not api_key:
            raise ValueError("GROQ_API_KEY not set")

        self.llm = Groq(api_key=api_key)
        self.llm_model = 'llama-3.1-8b-instant'

        print(f"✓ Ready with {self.index.ntotal:,} vectors")
        print("✓ LLM: Groq (Llama 3.1 8B)")

    def _get_db(self):
        """Open a new connection to the paragraph database (caller closes it)."""
        return sqlite3.connect("data/processed/indexed/paragraphs.db")

    def _keyword_ids(self, cursor, query: str, limit: int):
        """Return paragraph ids ranked by BM25 for ``query`` (best first).

        Returns an empty list when the query has no usable tokens or the FTS
        lookup fails (e.g. the FTS table is missing).
        """
        # FTS5 requires a specific syntax. A raw string means AND.
        # We want BM25 behavior (OR), so we clean punctuation and join words with OR.
        tokens = re.sub(r'[^\w\s]', '', query).split()
        if not tokens:
            return []
        # Quote each token so words such as AND / OR / NOT / NEAR in the
        # user's question are searched for instead of parsed as operators.
        fts_query = " OR ".join(f'"{token}"' for token in tokens)

        try:
            cursor.execute("""
                SELECT id, bm25(paragraphs_fts) as bm25_score
                FROM paragraphs_fts
                WHERE paragraphs_fts MATCH ?
                ORDER BY bm25_score LIMIT ?
            """, (fts_query, limit))
            return [row[0] for row in cursor.fetchall()]
        except sqlite3.OperationalError:
            # Fallback if query syntax is too complex for basic MATCH or FTS table missing
            return []

    def search(self, query: str, top_k: int = 5):
        """Hybrid Search: FAISS (Dense) + SQLite FTS5 (BM25) with RRF.

        Args:
            query: Search text.
            top_k: Number of paragraphs to return.

        Returns:
            Up to ``top_k`` dicts with ``id``, ``judgment_id``, ``text``,
            ``page_no``, ``rrf_score`` and ``score`` (cross-encoder score),
            sorted by descending ``score``.
        """
        if top_k <= 0:
            return []

        # --- 1. Dense Search (FAISS) ---
        query_vec = self.model.encode([query], normalize_embeddings=True)
        _, dense_indices = self.index.search(query_vec, k=top_k * 2)  # Fetch extra for fusion

        rrf_scores = {}

        for rank, idx in enumerate(dense_indices[0], start=1):
            idx = int(idx)
            # FAISS pads missing neighbours with -1, which would otherwise
            # index the LAST paragraph id.
            if idx < 0 or idx >= len(self.para_ids):
                continue
            pid = self.para_ids[idx]
            rrf_scores[pid] = rrf_scores.get(pid, 0.0) + 1.0 / (self.RRF_K + rank)

        db = self._get_db()
        try:
            cursor = db.cursor()

            # --- 2. Keyword Search (SQLite FTS5 BM25) ---
            keyword_ids = self._keyword_ids(cursor, query, top_k * 2)

            # --- 3. Reciprocal Rank Fusion (RRF) ---
            for rank, pid in enumerate(keyword_ids, start=1):
                rrf_scores[pid] = rrf_scores.get(pid, 0.0) + 1.0 / (self.RRF_K + rank)

            # Fetch a larger pool of candidates for reranking
            candidate_pool_size = top_k * 3
            sorted_rrf = sorted(
                rrf_scores.items(), key=lambda item: item[1], reverse=True
            )[:candidate_pool_size]

            # --- 4. Fetch Details ---
            candidates = []
            for pid, rrf_score in sorted_rrf:
                cursor.execute(
                    "SELECT judgment_id, text, page_no FROM paragraphs WHERE id = ?",
                    (pid,)
                )
                row = cursor.fetchone()
                if row:
                    candidates.append({
                        'rrf_score': rrf_score,
                        'judgment_id': row[0],
                        'text': row[1],
                        'page_no': row[2],
                        'id': pid
                    })
        finally:
            # Close the per-call connection even if a query raised.
            db.close()

        if not candidates:
            return []

        # --- 5. Final Rerank (Cross-Encoder) ---
        # Prepare inputs for cross-encoder: list of [query, document_text]
        cross_inp = [[query, doc['text']] for doc in candidates]
        rerank_scores = self.reranker.predict(cross_inp)

        for doc, score in zip(candidates, rerank_scores):
            doc['score'] = float(score)  # Use cross-encoder score as final score

        candidates.sort(key=lambda doc: doc['score'], reverse=True)
        return candidates[:top_k]

    def generate_answer(self, question: str, context: str, sources: list = None, chat_history: list = None):
        """Generate answer using Groq LLM with strict Legal Guardrails.

        Sources list is injected into the prompt so the LLM can ONLY cite
        what was actually retrieved — no hallucinated references.

        Args:
            question: The user's question.
            context: Retrieved paragraphs, already formatted for the prompt.
            sources: Retrieved source dicts (``judgment_id`` is read).
            chat_history: Earlier chat messages; only the last
                ``MAX_HISTORY_MESSAGES`` are sent.

        Returns:
            The text content of the LLM's reply.
        """
        sources = sources or []
        chat_history = (chat_history or [])[-self.MAX_HISTORY_MESSAGES:]  # Cap to last 3 turns

        # Build a numbered source registry for the LLM
        source_registry = "".join(
            f"[{i}] {s.get('judgment_id', 'Unknown')}\n"
            for i, s in enumerate(sources, 1)
        )

        prompt = f"""You are a strict, brilliant legal research assistant specializing in Indian Supreme Court judgments.

GUARDRAIL: You MUST ONLY answer questions related to law, legal processes, or the provided context.
If the question is entirely unrelated to law (e.g., "how to bake a cake"), reply EXACTLY with:
"I am a legal AI assistant. I can only answer questions related to law."

CITATION RULES — THIS IS CRITICAL:
- You may ONLY cite sources from the APPROVED SOURCE LIST below.
- Do NOT cite any case from your training memory that is not in the APPROVED SOURCE LIST.
- If you cite a case not in this list, you are hallucinating and failing your task.
- Use [1], [2], [3] etc. to refer to sources from the list below.

APPROVED SOURCE LIST (cite ONLY these):
{source_registry}
CONTEXT (retrieved paragraphs):
{context}

QUESTION: {question}

INSTRUCTIONS:
- Provide a detailed, comprehensive legal answer in a professional conversational tone.
- Explain concepts clearly so a lawyer finds it extremely useful.
- Cite ONLY from the APPROVED SOURCE LIST above using [1], [2], [3] format.
- Use proper legal terminology.
- Do NOT invent case names, citations, or dates.
- TEMPORAL AWARENESS: Look at the years in the judgment titles (e.g. 2023_CaseName). Newer judgments (e.g. 2023) supersede older judgments (e.g. 2010). If the retrieved context contains conflicting rulings, you MUST prioritize the newer judgment and explicitly warn the user that the older precedent may have been superseded.

ANSWER:"""

        messages = list(chat_history)
        messages.append({"role": "user", "content": prompt})

        response = self.llm.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.2,
            max_tokens=1024
        )

        return response.choices[0].message.content

    def query(self, question: str, top_k: int = 5, chat_history: list = None):
        """Main query method: retrieve paragraphs, then answer from them.

        Returns:
            A dict with ``question``, ``answer`` and ``sources`` (the
            retrieved paragraph dicts).
        """
        print(f"\n{'='*70}")
        print(f"QUERY: {question}")
        print('='*70)

        # Search
        print("\nSearching FAISS index...")
        results = self.search(question, top_k)

        print(f"Found {len(results)} relevant paragraphs")

        # Format context; the [i] numbers match the source registry built
        # in generate_answer.
        context = "\n\n".join(
            f"[{i}] {r['judgment_id']}\n{r['text']}"
            for i, r in enumerate(results, 1)
        )

        # Generate answer — pass sources so LLM can only cite what was retrieved
        print("Generating answer with LLM...")
        answer = self.generate_answer(question, context, sources=results, chat_history=chat_history)

        return {
            'question': question,
            'answer': answer,
            'sources': results
        }

    def _truncate_document(self, doc_text: str) -> str:
        """Shrink ``doc_text`` to roughly ``MAX_DOC_CHARS`` characters.

        The highest-scoring sentences (per ``self.importance_ranker``) are
        kept and put back in their original order. If ranking fails for any
        reason, the text is simply cut at the character budget.
        """
        budget = self.MAX_DOC_CHARS
        try:
            sentences = [s for s in split_sentences(doc_text) if len(s.strip()) > 20]
            scores = self.importance_ranker.score(sentences)
            indexed = list(enumerate(zip(sentences, scores)))
            sorted_by_score = sorted(indexed, key=lambda x: x[1][1], reverse=True)

            selected_indices = []
            current_chars = 0
            for idx, (sentence, _score) in sorted_by_score:
                if current_chars + len(sentence) > budget:
                    continue
                selected_indices.append(idx)
                current_chars += len(sentence)
                # Stop once the budget is nearly full.
                if current_chars > budget - 1000:
                    break

            # Restore original chronological order
            kept = [indexed[i][1][0] for i in sorted(selected_indices)]
            return " ".join(kept) + "\n\n... [TRUNCATED SEMANTICALLY FOR LLM] ..."
        except Exception as e:
            # Deliberate best-effort: any ranking failure degrades to a plain cut.
            print(f"Semantic Truncation failed: {e}. Falling back to naive truncation.")
            return doc_text[:budget] + "\n\n... [TRUNCATED DUE TO SIZE] ..."

    def query_with_document(self, question: str, filepath: str, chat_history: list = None):
        """Answer ``question`` using only the text of the file at ``filepath``.

        Args:
            question: The user's question.
            filepath: Path of a text file; undecodable bytes are ignored.
            chat_history: Earlier chat messages; only the last
                ``MAX_HISTORY_MESSAGES`` are sent, as in ``generate_answer``.

        Returns:
            A dict with ``question``, ``answer`` and ``sources``. If the file
            cannot be read, ``answer`` holds the error text and ``sources``
            is empty.
        """
        chat_history = (chat_history or [])[-self.MAX_HISTORY_MESSAGES:]
        try:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                doc_text = f.read()
        except Exception as e:
            return {"question": question, "answer": f"Error reading document: {e}", "sources": []}

        # JUDGMENTS are usually under 30k chars. Let's take as much as possible.
        if len(doc_text) > self.MAX_DOC_CHARS:
            print("Document exceeds 30k chars. Applying Semantic Truncation...")
            doc_text = self._truncate_document(doc_text)

        print("--- DOCUMENT QA START ---")
        print(f"File: {os.path.basename(filepath)}")
        print(f"Size: {len(doc_text)} chars")
        print(f"Question: {question}")

        prompt = f"""You are a strict Legal Document Auditor.
Your ONLY source of information is the text provided below.

STRICT RULES:
1. Answer the QUESTION using ONLY the DOCUMENT text.
2. If the answer is not in the text, say "I cannot find this in the uploaded document."
3. DO NOT cite external cases (like Venkata Reddy or V.C. Shukla) unless they are explicitly mentioned in the text below.
4. If you use your own internal knowledge instead of the document, you are failing your task.

DOCUMENT TEXT:
{doc_text}

QUESTION: {question}

DETAILED ANSWER (citing specific paragraphs if possible):"""

        messages = list(chat_history)
        messages.append({"role": "user", "content": prompt})

        response = self.llm.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.1,
            max_tokens=1024
        )

        # The API may return no content; treat that as an empty answer.
        answer = (response.choices[0].message.content or "").strip()

        return {
            'question': question,
            'answer': answer,
            'sources': [{'judgment_id': os.path.basename(filepath), 'score': 1.0}]
        }

    def close(self):
        """Release resources held by the engine.

        Database connections are opened and closed per ``search`` call, so
        there is normally nothing to do; a ``db`` attribute is closed only if
        one was attached to the instance.
        """
        db = getattr(self, "db", None)
        if db is not None:
            db.close()
|
| 303 |
+
|
| 304 |
+
# Test
|
| 305 |
+
if __name__ == "__main__":
    # Smoke test: run a few fixed questions through the full RAG pipeline
    # and print each answer followed by its scored sources.
    rag = QueryEngine()

    # Test queries
    sample_questions = (
        "What are the conditions for granting anticipatory bail?",
        "Explain the doctrine of legitimate expectation",
        "What is the burden of proof in criminal cases?",
    )

    for question in sample_questions:
        result = rag.query(question, top_k=3)

        print(f"\nANSWER:\n{result['answer']}\n")

        print("SOURCES:")
        for position, source in enumerate(result['sources'], start=1):
            print(f" [{position}] {source['judgment_id']} (score: {source['score']:.3f})")

        print("\n" + "="*70 + "\n")

    rag.close()
|
src/rag/test_retriever.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test RAG retrieval system with SQLite
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import faiss
|
| 6 |
+
import json
|
| 7 |
+
import sqlite3
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
def test_retrieval():
    """Run sample queries against the FAISS index and print the top-3 paragraphs.

    For every query the text is embedded, the three nearest vectors are looked
    up in the FAISS index, each hit is mapped to a paragraph id through
    ``paragraph_ids.json`` and the paragraph row is read from SQLite.
    Results are printed; nothing is returned.
    """
    # Local import keeps this block self-contained.
    from contextlib import closing

    print("="*70)
    print("Testing RAG Retrieval")
    print("="*70)

    # Load FAISS index
    index = faiss.read_index("data/processed/faiss/faiss_index.bin")

    # Load paragraph IDs
    with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
        para_ids = json.load(f)

    # Connect to SQLite; closing() guarantees the connection is released even
    # when model loading, search or a query raises.
    with closing(sqlite3.connect("data/processed/indexed/paragraphs.db")) as conn:
        cursor = conn.cursor()

        # Load embedding model (MUST match indexing)
        model = SentenceTransformer("BAAI/bge-base-en-v1.5")

        print(f"✓ Ready with {index.ntotal:,} vectors\n")

        queries = [
            "right to privacy under Article 21",
            "anticipatory bail conditions",
            "burden of proof in criminal cases",
            "doctrine of legitimate expectation"
        ]

        for query in queries:
            print(f"\n{'='*70}")
            print(f"QUERY: {query}")
            print('='*70)

            query_vec = model.encode(
                [query],
                normalize_embeddings=True,
                convert_to_numpy=True
            )

            scores, indices = index.search(query_vec, k=3)

            for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), 1):
                # Skip indices that do not map to a known paragraph id.
                if idx < 0 or idx >= len(para_ids):
                    continue

                para_id = para_ids[idx]

                cursor.execute(
                    "SELECT judgment_id, page_no, text FROM paragraphs WHERE id = ?",
                    (para_id,)
                )
                row = cursor.fetchone()

                if not row:
                    continue

                judgment_id, page_no, text = row

                print(f"\n[{rank}] Score: {score:.4f}")
                print(f"Source: {judgment_id} | Page: {page_no}")
                print(f"Text: {text[:300]}...")


if __name__ == "__main__":
    test_retrieval()
|
src/segmentation/__pycache__/judgement_segmenter.cpython-310.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
src/segmentation/annotate_paragraphs.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# Annotate paragraphs with legal sections using JudgmentSegmenter
|
| 3 |
+
# Creates paragraph_index_with_sections.jsonl
|
| 4 |
+
# """
|
| 5 |
+
|
| 6 |
+
# import json
|
| 7 |
+
# from pathlib import Path
|
| 8 |
+
# from collections import defaultdict
|
| 9 |
+
# from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
# from judgement_segmenter import JudgmentSegmenter
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# INPUT_INDEX = Path("data/processed/indexed/paragraph_index.jsonl")
|
| 15 |
+
# OUTPUT_INDEX = Path("data/processed/indexed/paragraph_index_with_sections.jsonl")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# def annotate_paragraphs():
|
| 19 |
+
# print("=" * 70)
|
| 20 |
+
# print("NyayLens – Annotating Paragraphs with Sections")
|
| 21 |
+
# print("=" * 70)
|
| 22 |
+
|
| 23 |
+
# # Load paragraphs grouped by judgment
|
| 24 |
+
# judgments = defaultdict(list)
|
| 25 |
+
|
| 26 |
+
# with open(INPUT_INDEX, "r", encoding="utf-8") as f:
|
| 27 |
+
# for line in f:
|
| 28 |
+
# p = json.loads(line)
|
| 29 |
+
# judgments[p["judgment_id"]].append(p)
|
| 30 |
+
|
| 31 |
+
# print(f"✓ Loaded {len(judgments):,} judgments")
|
| 32 |
+
|
| 33 |
+
# segmenter = JudgmentSegmenter()
|
| 34 |
+
|
| 35 |
+
# with open(OUTPUT_INDEX, "w", encoding="utf-8") as writer:
|
| 36 |
+
# for judgment_id, paras in tqdm(judgments.items(), desc="Annotating"):
|
| 37 |
+
# # Preserve original order
|
| 38 |
+
# paras = sorted(paras, key=lambda x: (x["page_no"], x["id"]))
|
| 39 |
+
|
| 40 |
+
# texts = [p["text"] for p in paras]
|
| 41 |
+
|
| 42 |
+
# sections = segmenter.segment(texts)
|
| 43 |
+
|
| 44 |
+
# # Default all to unknown
|
| 45 |
+
# section_labels = [
|
| 46 |
+
# ("unknown", 0.0) for _ in paras
|
| 47 |
+
# ]
|
| 48 |
+
|
| 49 |
+
# # Apply section labels
|
| 50 |
+
# for sec in sections:
|
| 51 |
+
# for i in range(sec.start_para_idx, sec.end_para_idx + 1):
|
| 52 |
+
# section_labels[i] = (sec.type, sec.confidence)
|
| 53 |
+
|
| 54 |
+
# # Write annotated paragraphs
|
| 55 |
+
# for p, (sec_type, sec_conf) in zip(paras, section_labels):
|
| 56 |
+
# p_out = dict(p)
|
| 57 |
+
# p_out["section"] = sec_type
|
| 58 |
+
# p_out["section_conf"] = sec_conf
|
| 59 |
+
|
| 60 |
+
# writer.write(json.dumps(p_out, ensure_ascii=False) + "\n")
|
| 61 |
+
|
| 62 |
+
# print("\n✓ Annotation complete")
|
| 63 |
+
# print(f"✓ Output written to: {OUTPUT_INDEX}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# if __name__ == "__main__":
|
| 67 |
+
# annotate_paragraphs()
|
| 68 |
+
"""
|
| 69 |
+
Annotate paragraphs with legal sections using JudgmentSegmenter
|
| 70 |
+
PRESERVES ORIGINAL IDs AND ORDER
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
import json
|
| 74 |
+
from pathlib import Path
|
| 75 |
+
from collections import defaultdict
|
| 76 |
+
from tqdm import tqdm
|
| 77 |
+
|
| 78 |
+
from judgement_segmenter import JudgmentSegmenter
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
INPUT_INDEX = Path("data/processed/indexed/paragraph_index.jsonl")
|
| 82 |
+
OUTPUT_INDEX = Path("data/processed/indexed/paragraph_index_with_sections.jsonl")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def annotate_paragraphs():
    """Label every indexed paragraph with a section type and confidence.

    Reads ``INPUT_INDEX`` (one JSON object per line), groups the paragraphs by
    ``judgment_id``, runs ``JudgmentSegmenter.segment`` on each judgment's
    texts and writes every paragraph back to ``OUTPUT_INDEX`` with two extra
    keys, ``section`` and ``section_conf``. Output lines appear in the same
    order as the input lines.
    """
    print("=" * 70)
    print("NyayLens – Annotating Paragraphs with Sections")
    print("=" * 70)

    # Load paragraphs IN ORIGINAL ORDER; blank lines (e.g. a trailing newline)
    # are skipped instead of crashing json.loads.
    with open(INPUT_INDEX, "r", encoding="utf-8") as f:
        all_paragraphs = [json.loads(line) for line in f if line.strip()]

    print(f"✓ Loaded {len(all_paragraphs):,} paragraphs")

    # Group by judgment, remembering each paragraph's position in the input.
    judgments = defaultdict(list)
    for idx, p in enumerate(all_paragraphs):
        judgments[p["judgment_id"]].append((idx, p))

    segmenter = JudgmentSegmenter()

    # Slot i receives the annotated copy of input paragraph i.
    annotations = [None] * len(all_paragraphs)

    for indexed_paras in tqdm(judgments.values(), desc="Annotating"):
        texts = [p["text"] for _, p in indexed_paras]

        # Every paragraph starts as "unknown" until a section claims it.
        section_labels = [("unknown", 0.0)] * len(indexed_paras)

        for sec in segmenter.segment(texts):
            # Clamp the inclusive range so an out-of-range section index can
            # neither raise nor wrap around to the end of the list.
            first = max(sec.start_para_idx, 0)
            last = min(sec.end_para_idx, len(section_labels) - 1)
            for i in range(first, last + 1):
                section_labels[i] = (sec.type, sec.confidence)

        # Store annotations in ORIGINAL positions
        for (orig_idx, p), (sec_type, sec_conf) in zip(indexed_paras, section_labels):
            annotations[orig_idx] = {**p, "section": sec_type, "section_conf": sec_conf}

    # Write in ORIGINAL order
    print("\nWriting annotated paragraphs...")
    with open(OUTPUT_INDEX, "w", encoding="utf-8") as writer:
        writer.writelines(
            json.dumps(p_out, ensure_ascii=False) + "\n" for p_out in annotations
        )

    print(f"✓ Output written to: {OUTPUT_INDEX}")
    print("=" * 70)


if __name__ == "__main__":
    annotate_paragraphs()
|
src/segmentation/check.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test section-aware retrieval"""
|
| 2 |
+
import faiss
|
| 3 |
+
import json
|
| 4 |
+
import sqlite3
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
|
| 8 |
+
# Load the FAISS index and the list mapping FAISS row numbers to paragraph ids
index = faiss.read_index("data/processed/faiss/faiss_index.bin")
with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
    para_ids = json.load(f)

db = sqlite3.connect("data/processed/indexed/paragraphs.db")
try:
    cursor = db.cursor()

    model = SentenceTransformer("BAAI/bge-base-en-v1.5")

    # Test query
    query = "What were the facts of the case?"
    query_vec = model.encode([query], normalize_embeddings=True)

    # Search
    scores, indices = index.search(query_vec, k=10)

    print(f"Query: {query}\n")
    print("Top results with sections:")

    for i, (score, idx) in enumerate(zip(scores[0], indices[0]), 1):
        # FAISS reports -1 for missing neighbours; a negative index would
        # silently wrap around to the end of para_ids, so skip it.
        if idx < 0 or idx >= len(para_ids):
            continue

        para_id = para_ids[idx]

        cursor.execute(
            "SELECT judgment_id, section, section_conf, text FROM paragraphs WHERE id = ?",
            (para_id,)
        )
        row = cursor.fetchone()

        # The id may be absent from the table; skip rather than crash.
        if row is None:
            continue

        judgment_id, section, section_conf, text = row
        # section_conf may be NULL for unannotated rows (assumption) — print
        # a placeholder instead of failing inside the float format spec.
        conf_label = "n/a" if section_conf is None else f"{section_conf:.2f}"

        print(f"\n[{i}] Score: {score:.3f} | Section: {section} (conf={conf_label})")
        print(f" Case: {judgment_id}")
        print(f" Text: {text[:100]}...")
finally:
    # Release the connection even when loading, search or a query fails.
    db.close()
|
src/segmentation/judgement_segmenter.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced Judgment Segmenter (FIXED)
|
| 3 |
+
Segments judgments into: Facts, Issues, Arguments, Analysis, Decision
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Tuple
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from transformers import pipeline
|
| 14 |
+
TRANSFORMERS_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
TRANSFORMERS_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class Section:
    """A contiguous run of paragraphs that share one section label."""

    type: str  # facts/issues/arguments/analysis/decision/unknown
    text: str  # member paragraphs joined with blank lines
    start_para_idx: int  # index of the first paragraph in the run (inclusive)
    end_para_idx: int  # index of the last paragraph in the run (inclusive)
    confidence: float  # highest paragraph confidence seen in the run, rounded to 2 places


class JudgmentSegmenter:
    """Split a judgment's paragraphs into labelled sections.

    Each paragraph is classified either by a HuggingFace text-classification
    pipeline (when one can be loaded from ``model_path``) or by the regex
    ``MARKERS`` below; consecutive paragraphs are then merged into
    ``Section`` objects.
    """

    # Regex cues per section type; matched against the lower-cased paragraph.
    MARKERS = {
        'facts': [
            r'\bbrief\s+facts?\b',
            r'\bfactual\s+(matrix|background)\b',
            r'\bcircumstances\s+of\s+the\s+case\b',
            r'\bbackground\b',
        ],
        'issues': [
            r'\bissues?\s+(for|of)\s+(consideration|determination)\b',
            r'\bsubstantial\s+questions?\b',
            r'\bpoints?\s+for\s+consideration\b',
            r'\bquestions?\s+framed\b',
        ],
        'arguments': [
            r'\blearned\s+counsel\b',
            r'\bsubmissions?\b',
            r'\b(argued|submitted|contended)\b',
            r'\bon\s+behalf\s+of\b',
        ],
        'analysis': [
            r'\bwe\s+have\s+(considered|examined|analysed)\b',
            r'\bthe\s+court\s+(finds|observes|notes|holds)\b',
            r'\bin\s+our\s+(view|opinion)\b',
            r'\bit\s+is\s+clear\s+that\b',
        ],
        'decision': [
            r'\b(appeal|petition|writ)\s+is\s+(allowed|dismissed)\b',
            r'\baccordingly\b',
            r'\bwe\s+direct\b',
            # No trailing \b: after ':' it would demand a word character
            # immediately after the colon, so "held: the ..." never matched.
            r'\bheld\s*:',
            r'\border\b',
        ]
    }

    def __init__(self, model_path: str = "models/segmentation_model"):
        """Initialize segmenter, preferring ML model if available, else Regex fallback"""
        self.use_ml = False
        self.classifier = None

        if TRANSFORMERS_AVAILABLE and os.path.exists(model_path):
            try:
                logger.info(f"Loading ML Segmentation model from {model_path}...")
                self.classifier = pipeline("text-classification", model=model_path, device=-1)
                self.use_ml = True
                logger.info("✓ ML Segmenter loaded successfully.")
            except Exception as e:
                # Deliberate best-effort: any load failure leaves regex mode active.
                logger.warning(f"Failed to load ML model, falling back to Regex: {e}")
        else:
            logger.info("ML model not found or transformers not installed. Using Regex fallback.")

    def detect_section(self, para: str, position_ratio: float) -> Tuple[str, float]:
        """Classify one paragraph with the regex markers.

        Args:
            para: Paragraph text.
            position_ratio: Paragraph index divided by the paragraph count.

        Returns:
            ``(section_type, confidence)``; ``('unknown', 0.0)`` when no
            marker matches. Confidence starts at 0.6 per matching marker, is
            raised for early 'facts' / late 'decision' paragraphs and for a
            match within the first 120 characters, and is capped at 1.0.
        """
        para_lower = para.lower()
        para_head = para_lower[:120]
        best_type, best_conf = 'unknown', 0.0

        for sec_type, patterns in self.MARKERS.items():
            for pattern in patterns:
                if not re.search(pattern, para_lower):
                    continue

                conf = 0.6

                # Position-based bias
                if sec_type == 'facts' and position_ratio < 0.30:
                    conf += 0.2
                elif sec_type == 'decision' and position_ratio > 0.70:
                    conf += 0.3

                # Strong anchor near paragraph start
                if re.search(pattern, para_head):
                    conf += 0.2

                conf = min(conf, 1.0)

                # Strict '>' keeps the first-seen type on ties.
                if conf > best_conf:
                    best_type, best_conf = sec_type, conf

        return best_type, best_conf

    def detect_section_ml(self, para: str) -> Tuple[str, float]:
        """Classify one paragraph with the loaded pipeline.

        Returns ``("unknown", 0.0)`` for blank text or when no classifier is
        loaded, and ``("unknown", score)`` when the top score is below 0.5.
        """
        if not para.strip() or not self.classifier:
            return "unknown", 0.0

        # Character-level truncation to keep the classifier input short.
        truncated = para[:512]
        result = self.classifier(truncated)[0]

        # Assume labels are like LABEL_FACTS, LABEL_ISSUES or directly facts, issues
        label = result['label'].lower().replace('label_', '')
        score = result['score']

        # Enforce confidence threshold
        if score < 0.5:
            return "unknown", score

        return label, score

    def segment(self, paragraph_texts: List[str]) -> List[Section]:
        """Merge consecutive paragraphs into sections (INDEX-ALIGNED).

        Args:
            paragraph_texts: Paragraph strings in document order.

        Returns:
            Sections covering every paragraph exactly once, in order;
            ``start_para_idx`` / ``end_para_idx`` index into
            ``paragraph_texts``. An empty input yields an empty list.
        """
        if not paragraph_texts:
            return []

        sections: List[Section] = []

        current_type = 'unknown'
        current_paras = []
        current_conf = 0.0
        start_idx = 0

        total = len(paragraph_texts)

        for i, para in enumerate(paragraph_texts):
            position_ratio = i / max(total, 1)

            if self.use_ml:
                sec_type, conf = self.detect_section_ml(para)
            else:
                sec_type, conf = self.detect_section(para, position_ratio)

            # Fallback: early unknown paragraphs are likely facts
            # NOTE(review): conf is set to exactly 0.4, which does not satisfy
            # the `conf > 0.4` boundary test below, so this fallback never
            # opens a 'facts' section by itself — confirm whether `>=` was meant.
            if sec_type == 'unknown' and position_ratio < 0.30 and i > 0:
                sec_type = 'facts'
                conf = 0.4

            # Section boundary: a confident paragraph of a different type
            # closes the running section and opens a new one.
            if conf > 0.4 and sec_type != current_type:
                if current_paras:
                    sections.append(
                        Section(
                            type=current_type,
                            text="\n\n".join(current_paras),
                            start_para_idx=start_idx,
                            end_para_idx=i - 1,
                            confidence=round(current_conf, 2)
                        )
                    )

                current_type = sec_type
                current_paras = [para]
                current_conf = conf
                start_idx = i
            else:
                current_paras.append(para)
                current_conf = max(current_conf, conf)

        # Final section
        if current_paras:
            sections.append(
                Section(
                    type=current_type,
                    text="\n\n".join(current_paras),
                    start_para_idx=start_idx,
                    end_para_idx=total - 1,
                    confidence=round(current_conf, 2)
                )
            )

        return sections
|
src/summarization/__pycache__/composer.cpython-310.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
src/summarization/__pycache__/inference.cpython-310.pyc
ADDED
|
Binary file (4.67 kB). View file
|
|
|
src/summarization/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (1.14 kB). View file
|
|
|
src/summarization/__pycache__/ranker.cpython-310.pyc
ADDED
|
Binary file (1.45 kB). View file
|
|
|
src/summarization/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (314 Bytes). View file
|
|
|
src/summarization/composer.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/composer.py
|
| 2 |
+
def compose(sentences, scores, top_k=5):
    """
    Keep the ``top_k`` highest-scoring sentences and return them in the order
    they appear in the document, so the downstream summarizer receives a
    coherent narrative rather than a score-ordered list.
    """
    # Pair each sentence with its score; positions in this list are the
    # original document positions.
    pairs = list(zip(sentences, scores))

    # Positions ranked by descending score, cut to the best top_k
    by_score = sorted(range(len(pairs)), key=lambda pos: pairs[pos][1], reverse=True)
    keep = by_score[:top_k]

    # Emit the survivors in ascending position, i.e. document order
    return [pairs[pos][0] for pos in sorted(keep)]
|
src/summarization/dataset.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/dataset.py
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from datasets import Dataset
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
IMPORTANT_PATTERNS = [
|
| 9 |
+
r"\bheld\b",
|
| 10 |
+
r"\bwe conclude\b",
|
| 11 |
+
r"\btherefore\b",
|
| 12 |
+
r"\bappeal is (allowed|dismissed)\b",
|
| 13 |
+
r"\bsubstantial question\b",
|
| 14 |
+
r"\baccordingly\b",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
def sentence_split(text):
    """Split text at whitespace that follows '.', '!' or '?'."""
    boundary = r'(?<=[.!?])\s+'
    pieces = re.split(boundary, text)
    return pieces
|
| 19 |
+
|
| 20 |
+
def is_important(sentence):
    """Return True when the lower-cased sentence matches any IMPORTANT_PATTERNS regex."""
    lowered = sentence.lower()
    for pattern in IMPORTANT_PATTERNS:
        if re.search(pattern, lowered):
            return True
    return False
|
| 23 |
+
|
| 24 |
+
def build_dataset(text_dir, tokenizer_name, limit=None):
    """Build a tokenized sentence-level dataset from plain-text judgments.

    Args:
        text_dir: Directory containing ``*.txt`` judgment files.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        limit: When truthy, only the first ``limit`` files (in sorted path
            order) are processed.

    Returns:
        A ``datasets.Dataset`` with ``text``, ``label`` (1 when
        ``is_important`` matches, else 0) and ``judgment_id`` columns plus
        the tokenizer outputs, padded/truncated to 256 tokens.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Sort so that `limit` always selects the same files; glob yields paths
    # in no guaranteed order.
    files = sorted(Path(text_dir).glob("*.txt"))
    if limit:
        files = files[:limit]

    samples = []
    for file in tqdm(files, desc="Processing judgments"):
        judgment_id = file.stem
        text = file.read_text(encoding="utf-8", errors="ignore")

        for sent in sentence_split(text):
            sent = sent.strip()
            # Drop fragments shorter than 40 characters.
            if len(sent) < 40:
                continue

            samples.append({
                "text": sent,
                "label": int(is_important(sent)),
                "judgment_id": judgment_id
            })

    dataset = Dataset.from_list(samples)

    def tokenize(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=256
        )

    return dataset.map(tokenize, batched=True)
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
    print("Started dataset building...")
    # Using a limit of 1000 for training, can be increased later
    # 1000 judgments will yield ~50k-100k sentences, good for fine-tuning
    # Announce tokenization before calling build_dataset, because that call
    # performs it; previously the message appeared only after it had finished.
    print("Tokenizing dataset... this may take a moment.")
    ds = build_dataset(
        "data/processed/extracted/texts",
        "nlpaueb/legal-bert-base-uncased",
        limit=1000
    )
    print(f"Total sentences extracted: {len(ds)}")

    print("Saving to Disk...")
    ds.save_to_disk("data/processed/summarization_dataset")
    print("✓ Dataset ready at data/processed/summarization_dataset")
|
src/summarization/inference.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/inference.py
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
| 8 |
+
|
| 9 |
+
from src.summarization.ranker import ImportanceRanker
|
| 10 |
+
from src.summarization.utils import split_sentences
|
| 11 |
+
from src.segmentation.judgement_segmenter import JudgmentSegmenter
|
| 12 |
+
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# ── Model ──────────────────────────────────────────────────────────────────────
|
| 16 |
+
MODEL_NAME = "nsi319/legal-pegasus"
|
| 17 |
+
print(f"\nLoading Abstractive Model ({MODEL_NAME})...")
|
| 18 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
pegasus_tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
|
| 20 |
+
pegasus_model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
|
| 21 |
+
print(f"✓ Legal-PEGASUS loaded on {device.upper()}")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _pegasus_generate(text: str, max_length: int = 300, min_length: int = 100) -> str:
    """Run Legal-PEGASUS on a block of text and return the decoded summary.

    Args:
        text: Source text; tokenized with truncation at 1024 tokens.
        max_length: Upper bound passed to ``generate``.
        min_length: Lower bound passed to ``generate``.

    Returns:
        The decoded summary with ``<n>`` markers replaced by spaces and runs
        of whitespace collapsed.
    """
    inputs = pegasus_tokenizer(
        [text],
        max_length=1024,
        truncation=True,
        padding=True,
        return_tensors="pt"
    ).to(device)

    outputs = pegasus_model.generate(
        input_ids=inputs["input_ids"],
        # Pass the tokenizer's mask explicitly so generation does not have to
        # infer which positions are padding.
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        num_beams=4,  # Reduced from 8 for 2x speedup on CPU
        length_penalty=1.2,
        no_repeat_ngram_size=3,
        repetition_penalty=1.3,
        early_stopping=True,
    )
    decoded = pegasus_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return re.sub(r'\s+', ' ', decoded.replace("<n>", " ")).strip()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def summarize(judgment_file: str) -> dict:
    """
    Speed-Optimized Two-Pass Pipeline:
    1. Case Overview: Legal-BERT (Extraction) -> Legal-PEGASUS (Abstraction) [1 Pass]
    2. Detailed Sections: Legal-BERT (Extraction) -> Direct Output [No Abstraction pass to save 5+ minutes]

    Args:
        judgment_file: Path to a UTF-8 text file holding one judgment.

    Returns:
        A dict with an ``"overview"`` entry and, for every non-unknown
        section type the segmenter found, an entry holding that type's three
        highest-scoring sentences in document order.
    """
    text = Path(judgment_file).read_text(encoding="utf-8", errors="ignore")

    # ── Step 1: Global sentence extraction ─────────────────────────────────────
    all_sentences = [s for s in split_sentences(text) if len(s.strip()) > 40]
    if not all_sentences:
        return {"overview": "Could not extract readable text."}

    ranker = ImportanceRanker("outputs/summarization/final")
    scores = ranker.score(all_sentences)

    # Token-Aware Global Overview Extract (Limit to ~950 tokens for Pegasus)
    indexed = list(enumerate(zip(all_sentences, scores)))
    sorted_by_score = sorted(indexed, key=lambda x: x[1][1], reverse=True)

    selected_indices = []
    current_tokens = 0
    MAX_TOKENS = 950

    for idx, (sentence, _score) in sorted_by_score:
        tokens = len(pegasus_tokenizer.encode(sentence, add_special_tokens=False))
        # A sentence that does not fit is skipped; a shorter, lower-ranked one
        # may still fill the remaining budget.
        if current_tokens + tokens > MAX_TOKENS:
            continue

        selected_indices.append(idx)
        current_tokens += tokens
        # Stop once the budget is nearly used up.
        if current_tokens >= MAX_TOKENS - 20:
            break

    # Restore chronological order
    global_extract = " ".join(all_sentences[i] for i in sorted(selected_indices))

    # ── Pass 1: Abstractive Overview (The only heavy pass) ────────────────────
    print("Generating Case Overview (Abstractive)...")
    overview = _pegasus_generate(global_extract, max_length=250, min_length=80)

    # ── Pass 2: Extractive Section Breakdown (Instant) ────────────────────────
    segmenter = JudgmentSegmenter()
    paragraphs = [p.strip() for p in text.split("\n\n") if len(p.strip()) > 20]
    sections = segmenter.segment(paragraphs)

    final_summary = {"overview": overview}

    print("Generating Section Breakdowns (Extractive - Instant)...")

    # The segmenter can emit the same section type several times (e.g. facts,
    # arguments, facts). Pool the sentences per type first, so a later section
    # no longer overwrites an earlier one of the same type.
    sentences_by_type = {}
    for section in sections:
        sec_type = section.type.lower()
        if sec_type == 'unknown':
            continue

        sentences = [s for s in split_sentences(section.text) if len(s.strip()) > 40]
        if sentences:
            sentences_by_type.setdefault(sec_type, []).extend(sentences)

    for sec_type, sentences in sentences_by_type.items():
        sec_scores = ranker.score(sentences)
        # Select top 3 per section type for readability, then restore document order
        ranked = sorted(range(len(sentences)), key=lambda i: sec_scores[i], reverse=True)[:3]

        # We use original sentences here to save 5-10 minutes of CPU time
        final_summary[sec_type] = " ".join(sentences[i] for i in sorted(ranked))

    return final_summary
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
    # Sorted so the same judgment is picked on every run; glob yields paths in
    # no guaranteed order.
    candidates = sorted(Path("data/processed/extracted/texts").glob("*.txt"))
    if not candidates:
        # Exit with a clear message instead of an IndexError on an empty list.
        raise SystemExit("No .txt judgments found in data/processed/extracted/texts")

    file = candidates[0]
    print(f"\nProcessing {file.name}...")
    result = summarize(file)

    print("\n\nCOMPREHENSIVE LEGAL SUMMARY (Global Legal-BERT + Legal-PEGASUS)\n" + "=" * 80)
    print("\n[CASE OVERVIEW]")
    print(result.get("overview", ""))
    for sec in ['facts', 'issues', 'arguments', 'analysis', 'decision']:
        if sec in result:
            print(f"\n[{sec.upper()}]")
            print(result[sec])
    print("\n" + "=" * 80)
|
src/summarization/model.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/model.py
|
| 2 |
+
from transformers import AutoModel
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
class SentenceRanker(nn.Module):
    """Pretrained encoder followed by a single-logit linear head.

    ``forward`` returns a dict with ``"logits"`` (one value per input row)
    and ``"loss"`` (binary cross-entropy on the logits when ``labels`` is
    given, otherwise ``None``).
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        # Hidden state at position 0 of every sequence feeds the head.
        first_token = encoded.last_hidden_state[:, 0]
        logits = self.classifier(first_token).squeeze(-1)

        if labels is None:
            loss = None
        else:
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())

        return {"loss": loss, "logits": logits}
|
src/summarization/ranker.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/ranker.py
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from src.summarization.model import SentenceRanker
|
| 5 |
+
|
| 6 |
+
class ImportanceRanker:
    """Score sentences with a fine-tuned ``SentenceRanker``.

    The tokenizer and architecture come from ``base_model``; fine-tuned
    weights are read from ``<model_dir>/model.safetensors`` when that file
    exists, otherwise a warning is printed and the model is used as built.
    """

    def __init__(self, model_dir, base_model="nlpaueb/legal-bert-base-uncased"):
        # Load the tokenizer from the base model
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)

        # Initialize the custom architecture with base model
        self.model = SentenceRanker(base_model)

        # Load fine-tuned weights
        import os
        from safetensors.torch import load_file

        weights_path = os.path.join(model_dir, "model.safetensors")
        if os.path.exists(weights_path):
            state_dict = load_file(weights_path)
            self.model.load_state_dict(state_dict)
        else:
            print(f"Warning: Could not find {weights_path}")

        self.model.eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def score(self, sentences, batch_size=32):
        """Return one sigmoid score per sentence, in input order.

        Args:
            sentences: A list of sentence strings (a single string is treated
                as a one-element list).
            batch_size: Number of sentences tokenized and scored per forward
                pass; bounds peak memory on long documents.

        Returns:
            A list of floats; empty when ``sentences`` is empty.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        sentences = list(sentences)
        # The tokenizer cannot build a tensor from an empty batch.
        if not sentences:
            return []

        scores = []
        with torch.no_grad():
            for start in range(0, len(sentences), batch_size):
                batch = sentences[start:start + batch_size]
                inputs = self.tokenizer(
                    batch,
                    truncation=True,
                    padding=True,
                    return_tensors="pt"
                ).to(self.device)

                logits = self.model(**inputs)["logits"]
                scores.extend(logits.sigmoid().cpu().tolist())

        return scores
|
src/summarization/train.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/train.py
|
| 2 |
+
|
| 3 |
+
from datasets import load_from_disk
|
| 4 |
+
from transformers import Trainer, TrainingArguments
|
| 5 |
+
from model import SentenceRanker
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
# Checkpoint used to initialise the encoder inside SentenceRanker.
MODEL_NAME = "nlpaueb/legal-bert-base-uncased"


def main():
    """Train a ``SentenceRanker`` on the on-disk dataset and save the result.

    Steps: load the dataset from ``data/processed/summarization_dataset``,
    split it 90/10 into train/test, train for two epochs with the Hugging
    Face ``Trainer``, then write the final model to
    ``outputs/summarization/final``.
    """
    # Load dataset and hold out 10% of it.
    splits = load_from_disk(
        "data/processed/summarization_dataset"
    ).train_test_split(test_size=0.1)

    ranker = SentenceRanker(MODEL_NAME)

    hyperparams = {
        "output_dir": "outputs/summarization",
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 16,
        "num_train_epochs": 2,
        "learning_rate": 2e-5,
        "logging_steps": 500,
        "save_steps": 2000,
        "save_total_limit": 2,
        "report_to": "none",
        # Mixed precision is only requested when a CUDA device is present.
        "fp16": torch.cuda.is_available(),
    }

    # NOTE(review): an eval dataset is supplied but no evaluation schedule is
    # set here; whether evaluation runs depends on TrainingArguments defaults.
    trainer = Trainer(
        model=ranker,
        args=TrainingArguments(**hyperparams),
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
    )

    trainer.train()
    trainer.save_model("outputs/summarization/final")


if __name__ == "__main__":
    main()
|
src/summarization/utils.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/summarization/utils.py
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
# Whitespace run that directly follows a sentence terminator (., ! or ?).
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')


def split_sentences(text):
    """Split ``text`` into sentences at whitespace following ``.``, ``!`` or ``?``.

    Surrounding whitespace is stripped first, so the result never contains
    empty strings or a first sentence with leading whitespace.

    Args:
        text: The text to split. ``None`` and empty or whitespace-only
            strings are accepted.

    Returns:
        A list of sentence strings, each keeping its terminating
        punctuation. Empty input yields ``[]``.

    Note:
        The split is purely punctuation-based: a period inside an
        abbreviation (e.g. ``"v. State"``) is treated as a sentence end.

    >>> split_sentences("First one. Second one? ")
    ['First one.', 'Second one?']
    """
    if not text:
        return []

    stripped = text.strip()
    if not stripped:
        return []

    return _SENTENCE_BOUNDARY.split(stripped)
|