import os
import math
import re
from typing import List, Tuple, Optional
import gradio as gr
import numpy as np
from sklearn.cluster import KMeans
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoConfig,
)

# -----------------------------
# Defaults (feel free to change)
# -----------------------------
DEFAULT_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # fast & solid
# For legal focus, you can try: "nlpaueb/legal-bert-base-uncased" with mean-pooled embeddings (slower)
DEFAULT_ABS_MODEL = "pszemraj/led-large-book-summary" # good LED variant for long docs
FALLBACK_ABS_MODEL = "allenai/led-base-16384"
MAX_INPUT_TOKENS = 12000 # safety cap before chunking for LED (note: currently unused below)
WINDOW_TOKENS = 3500 # per chunk for LED/Long models
OVERLAP_TOKENS = 250 # overlap between chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Utilities
# -----------------------------
def simple_sentence_split(text: str) -> List[str]:
    """
    Robust-enough sentence splitter without external downloads.
    Splits after . ! ? followed by whitespace, then re-merges tiny
    fragments so abbreviations and section numbers stay modestly safe.
    """
    # Normalize whitespace
    text = re.sub(r"\s+", " ", text).strip()
    # Split after punctuation that likely ends a sentence
    candidates = re.split(r"(?<=[.!?])\s+", text)
    # Merge tiny fragments back (e.g., "No." followed by "23.")
    merged = []
    buf = ""
    for c in candidates:
        frag = c.strip()
        if not frag:
            continue
        if not buf:
            buf = frag
        else:
            # If the fragment is very short (like a section number), attach it back
            if len(frag) <= 3 and re.match(r"^[\(\)\[\]\dA-Za-z\-:;+.,]+$", frag):
                buf += " " + frag
            else:
                merged.append(buf)
                buf = frag
    if buf:
        merged.append(buf)
    # Drop empty leftovers
    merged = [s.strip() for s in merged if s.strip()]
    return merged

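# Illustrative behavior of the splitter (example is ours, not exercised in the app):
#   simple_sentence_split("See Order No. 23. The appeal is allowed.")
#   -> ["See Order No. 23.", "The appeal is allowed."]
# The three-character fragment "23." is re-attached to the preceding buffer
# rather than surviving as a standalone "sentence".
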
def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings with the attention mask."""
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts

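# Shape reference: last_hidden_state is (batch, seq_len, hidden) and
# attention_mask is (batch, seq_len); the result is (batch, hidden), i.e.
# the average of hidden states over non-padding tokens only.
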
def load_embedder(model_name: str):
    """
    Load an embedding model.
    Sentence-transformers checkpoints load fine via AutoModel; we mean-pool manually.
    """
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    mdl = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl

@torch.inference_mode()
def embed_sentences(sentences: List[str], tok, mdl, batch_size: int = 32) -> np.ndarray:
    vecs = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        enc = tok(
            batch,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        enc = {k: v.to(DEVICE) for k, v in enc.items()}
        out = mdl(**enc)
        sent_emb = mean_pooling(out.last_hidden_state, enc["attention_mask"])
        sent_emb = torch.nn.functional.normalize(sent_emb, p=2, dim=1)
        vecs.append(sent_emb.cpu().numpy())
    return np.vstack(vecs)

def choose_k(n_sent: int, user_k: Optional[int]) -> int:
    if user_k and user_k > 0:
        return min(user_k, n_sent)
    # Heuristic: sqrt(N), clamped to at least 5 and at most N
    k = max(5, int(math.sqrt(max(1, n_sent))))
    return min(k, n_sent)

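# Worked example of the heuristic (ours): with 100 sentences and no user K,
# sqrt(100) = 10, so K = 10; with 9 sentences, sqrt(9) = 3 is clamped up to 5,
# and min(5, 9) keeps K = 5.
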
def kmeans_select(sentences: List[str], embeddings: np.ndarray, k: int, pick: int = 1) -> List[int]:
    """
    Cluster sentence embeddings with KMeans and pick the `pick` sentences
    closest to each centroid. Returns sorted indices to preserve reading order.
    """
    if k <= 0 or len(sentences) == 0:
        return []
    # Edge case: fewer sentences than k
    k = min(k, len(sentences))
    kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42)
    labels = kmeans.fit_predict(embeddings)
    chosen = []
    for c in range(k):
        idxs = np.where(labels == c)[0]
        # Distances of this cluster's members to its centroid
        dists = np.linalg.norm(embeddings[idxs] - kmeans.cluster_centers_[c], axis=1)
        local_order = np.argsort(dists)[:pick]
        chosen.extend(idxs[local_order].tolist())
    return sorted(set(chosen))

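# Note: n_init="auto" needs scikit-learn >= 1.2; on older versions pass an
# integer instead (e.g., n_init=10).
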
def load_abstractive_model(model_name: str):
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl, cfg

def chunk_tokens(tokens: List[int], window: int, overlap: int) -> List[List[int]]:
    if window <= 0:
        return [tokens]
    chunks = []
    i = 0
    while i < len(tokens):
        chunks.append(tokens[i:i + window])
        i += max(1, window - overlap)
    return chunks

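# Illustrative stride (ours): with window=5 and overlap=2 the start index
# advances by 3 each step, so a 12-token list yields chunks starting at
# 0, 3, 6, and 9, with consecutive chunks sharing two tokens.
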
@torch.inference_mode()
def run_abstractive(
    text: str,
    model_name: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    min_len: int = 40,
    window_tokens: int = WINDOW_TOKENS,
    overlap_tokens: int = OVERLAP_TOKENS,
) -> str:
    tok, mdl, cfg = load_abstractive_model(model_name)
    # Tokenize the full text; truncation=False because we only need the token
    # count here, and any chunking is done manually below.
    enc = tok(text, return_tensors="pt", truncation=False)
    input_ids = enc["input_ids"].squeeze(0).tolist()
    # do_sample=True is required for temperature/top_p to actually take effect
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        min_length=min_len,
    )
    if len(input_ids) <= window_tokens:
        enc = tok(text, return_tensors="pt", truncation=True, max_length=window_tokens).to(DEVICE)
        gen = mdl.generate(**enc, **gen_kwargs)
        return tok.decode(gen[0], skip_special_tokens=True)
    # Sliding window over the token ids
    parts = []
    for chunk in chunk_tokens(input_ids, window_tokens, overlap_tokens):
        enc_chunk = {
            "input_ids": torch.tensor([chunk]).to(DEVICE),
            "attention_mask": torch.ones((1, len(chunk)), dtype=torch.long, device=DEVICE),
        }
        gen = mdl.generate(**enc_chunk, **gen_kwargs)
        parts.append(tok.decode(gen[0], skip_special_tokens=True))
    # Simple merge; re-summarize the stitched text once more if it is still long
    stitched = "\n".join(parts)
    if len(stitched.split()) > 600:
        enc2 = tok(stitched, return_tensors="pt", truncation=True, max_length=window_tokens).to(DEVICE)
        gen2 = mdl.generate(**enc2, **gen_kwargs)
        return tok.decode(gen2[0], skip_special_tokens=True)
    return stitched

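# Possible refinement (not wired in here): LED checkpoints generally accept a
# global_attention_mask with a 1 on the first token, e.g.
#   gmask = torch.zeros_like(enc_chunk["input_ids"]); gmask[:, 0] = 1
#   mdl.generate(**enc_chunk, global_attention_mask=gmask, ...)
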
def pipeline(
    text: str,
    embed_model: str,
    abs_model: Optional[str],
    k_clusters: Optional[int],
    pick_per_cluster: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    min_len: int,
) -> Tuple[str, str, str]:
    """
    Returns: (extractive_core, abstractive_summary, debug_info)
    """
    if not text or not text.strip():
        return "", "", "No input text."
    # 1) sentence split
    sentences = simple_sentence_split(text)
    if len(sentences) == 0:
        return "", "", "No sentences detected after splitting."
    # 2) embeddings
    etok, emdl = load_embedder(embed_model)
    embs = embed_sentences(sentences, etok, emdl, batch_size=32)
    # 3) clustering + representative pick
    k = choose_k(len(sentences), k_clusters)
    chosen_idx = kmeans_select(sentences, embs, k, pick=pick_per_cluster)
    extractive = " ".join(sentences[i] for i in chosen_idx)
    # 4) abstractive refinement (optional)
    abstractive = ""
    model_used = abs_model or ""
    if abs_model and abs_model.strip().lower() != "none":
        try:
            abstractive = run_abstractive(
                extractive if extractive else text,
                model_name=abs_model,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                min_len=min_len,
            )
        except Exception as e:
            # Fall back to a smaller LED if the requested model fails
            if abs_model != FALLBACK_ABS_MODEL:
                try:
                    abstractive = run_abstractive(
                        extractive if extractive else text,
                        model_name=FALLBACK_ABS_MODEL,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        min_len=min_len,
                    )
                    model_used = f"{abs_model} -> fell back to {FALLBACK_ABS_MODEL}"
                except Exception as e2:
                    abstractive = ""
                    model_used = f"{abs_model} failed ({e}); fallback failed ({e2})"
            else:
                abstractive = ""
                model_used = f"{abs_model} failed ({e})"
    # Debug info
    dbg = (
        f"Device: {DEVICE}\n"
        f"Embedder: {embed_model}\n"
        f"Abstractive model: {model_used or 'None'}\n"
        f"Sentences: {len(sentences)} | K: {k} | Pick/cluster: {pick_per_cluster}\n"
        f"Chosen indices (sorted): {chosen_idx[:50]}{'...' if len(chosen_idx) > 50 else ''}\n"
    )
    return extractive, (abstractive or extractive), dbg

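# Headless usage sketch (ours; "judgment.txt" is a hypothetical file, and the
# Gradio UI below is the real entry point):
#   core, summary, dbg = pipeline(
#       text=open("judgment.txt").read(),
#       embed_model=DEFAULT_EMBED_MODEL, abs_model="none",
#       k_clusters=0, pick_per_cluster=1,
#       max_new_tokens=256, temperature=0.7, top_p=0.9, min_len=40,
#   )
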
# -----------------------------
# Gradio UI
# -----------------------------
EXAMPLE_LEGAL = """IN THE SUPREME COURT OF INDIA
Civil Appeal No. 1234 of 2021
The appellant contends that the High Court erred in overlooking binding precedent on limitation.
The respondent argues that the delay is inordinate and unexplained. The core issue is whether
sufficient cause exists under Section 5 of the Limitation Act. After hearing the parties and
perusing the record, we find that the appellant was prevented by bona fide reasons. Accordingly,
the delay is condoned subject to costs of Rs. 10,000. The matter is remanded to the High Court
for disposal on merits in accordance with law."""
with gr.Blocks(title="Legal Summarizer (K-Means + BERT + LED)") as demo:
    gr.Markdown(
        """
# ⚖️ Legal Text Summarizer — K-Means + BERT + LED
Upload/paste a judgment or order. We cluster sentences with BERT embeddings (the extractive core), then optionally refine with an LED/Long model into an abstractive final summary.
- **Embedding model**: any BERT/sentence-transformers model
- **Abstractive model**: LED/LongT5/T5 from Hugging Face (long documents are handled by chunking)
- Works well on Indian legal text; swap in `nlpaueb/legal-bert-base-uncased` for domain flavor (slower).
"""
    )
    with gr.Row():
        text_in = gr.Textbox(
            label="Paste Legal Text",
            lines=18,
            placeholder="Paste a long judgment, order, or legal article…",
            value=EXAMPLE_LEGAL,
        )
    with gr.Accordion("Models & Settings", open=True):
        with gr.Row():
            embed_model = gr.Textbox(
                label="Embedding Model (BERT/Sentence-Transformers)",
                value=DEFAULT_EMBED_MODEL,
                info="E.g., sentence-transformers/all-MiniLM-L6-v2 or nlpaueb/legal-bert-base-uncased",
            )
            abs_model = gr.Textbox(
                label="Abstractive Model (LED/LongT5/T5) or 'none'",
                value=DEFAULT_ABS_MODEL,
                info="Try: pszemraj/led-large-book-summary, allenai/led-base-16384, google/long-t5-tglobal-base",
            )
        with gr.Row():
            k_clusters = gr.Number(label="K (clusters). Leave 0 for auto (≈√N)", value=0, precision=0)
            pick_per = gr.Slider(label="Sentences picked per cluster", minimum=1, maximum=3, value=1, step=1)
        with gr.Row():
            max_new = gr.Slider(label="Max new tokens (abstractive)", minimum=64, maximum=1024, value=256, step=16)
            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.2, value=0.7, step=0.05)
            top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.05)
            min_len = gr.Slider(label="Min summary length (tokens)", minimum=10, maximum=200, value=40, step=5)
    run_btn = gr.Button("Summarize 🚀", variant="primary")
    extractive_out = gr.Textbox(label="Extractive Core (cluster representatives)", lines=10)
    abstractive_out = gr.Textbox(label="Final Summary (abstractive if model provided)", lines=10)
    debug_out = gr.Textbox(label="Debug Info", lines=8)

    def _go(text, e_model, a_model, k, pick, mx, temp, topp, minl):
        k = int(k) if k else 0
        return pipeline(
            text=text,
            embed_model=e_model.strip(),
            abs_model=a_model.strip() if a_model else "none",
            k_clusters=k,
            pick_per_cluster=int(pick),
            max_new_tokens=int(mx),
            temperature=float(temp),
            top_p=float(topp),
            min_len=int(minl),
        )

    run_btn.click(
        _go,
        inputs=[text_in, embed_model, abs_model, k_clusters, pick_per, max_new, temperature, top_p, min_len],
        outputs=[extractive_out, abstractive_out, debug_out],
    )

if __name__ == "__main__":
    # For Spaces / Colab inline preview
    demo.launch()