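"""Legal text summarizer: K-Means over BERT sentence embeddings builds an
extractive core, which a long-input seq2seq model (LED) optionally refines
into an abstractive final summary. Served as a Gradio app."""
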
import math
import re
from typing import List, Tuple, Optional

import gradio as gr
import numpy as np
from sklearn.cluster import KMeans
import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoConfig,
)

# -----------------------------
# Defaults (feel free to change)
# -----------------------------
DEFAULT_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # fast & solid
# For legal focus, you can try "nlpaueb/legal-bert-base-uncased" with mean-pooled embeddings (slower).
DEFAULT_ABS_MODEL = "pszemraj/led-large-book-summary"  # good LED variant for long docs
FALLBACK_ABS_MODEL = "allenai/led-base-16384"
MAX_INPUT_TOKENS = 12000  # safety cap before chunking for LED
WINDOW_TOKENS = 3500      # tokens per chunk for LED/long models
OVERLAP_TOKENS = 250      # overlap between consecutive chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Utilities
# -----------------------------
def simple_sentence_split(text: str) -> List[str]:
    """
    Robust-enough sentence splitter without external downloads.
    Collapses all whitespace (newlines included), then splits after . ! ?
    while keeping abbreviations modestly safe.
    """
    # Normalize whitespace
    text = re.sub(r"\s+", " ", text).strip()
    # Split after punctuation that likely ends a sentence
    candidates = re.split(r"(?<=[.!?])\s+", text)
    # Merge tiny fragments back (e.g., "No." followed by "23.")
    merged = []
    buf = ""
    for c in candidates:
        frag = c.strip()
        if not frag:
            continue
        if not buf:
            buf = frag
        else:
            # Very short fragments (like section numbers) get attached back
            if len(frag) <= 3 and re.match(r"^[\(\)\[\]\dA-Za-z\-:;+.,]+$", frag):
                buf += " " + frag
            else:
                merged.append(buf)
                buf = frag
    if buf:
        merged.append(buf)
    # Drop empties
    merged = [s.strip() for s in merged if s.strip()]
    return merged

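# Illustrative behavior: short numeric fragments are re-attached, e.g.
#   simple_sentence_split("See Order No. 23. The court agreed.")
#   -> ["See Order No. 23.", "The court agreed."]
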
def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings with attention mask."""
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts

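# Per sentence this computes pooled = sum_t(mask_t * h_t) / sum_t(mask_t), so
# padded positions contribute nothing to the sentence vector.
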
def load_embedder(model_name: str):
    """
    Load an embedding model.
    If it's a sentence-transformers model, AutoModel works; we do manual mean pooling.
    """
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    mdl = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl

def embed_sentences(sentences: List[str], tok, mdl, batch_size: int = 32) -> np.ndarray:
    """Encode sentences in batches into L2-normalized, mean-pooled vectors."""
    vecs = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        enc = tok(
            batch,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        enc = {k: v.to(DEVICE) for k, v in enc.items()}
        with torch.no_grad():  # inference only; avoids holding activation graphs
            out = mdl(**enc)
        sent_emb = mean_pooling(out.last_hidden_state, enc["attention_mask"])
        sent_emb = torch.nn.functional.normalize(sent_emb, p=2, dim=1)
        vecs.append(sent_emb.cpu().numpy())
    return np.vstack(vecs)

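# Because the vectors are L2-normalized, the Euclidean distances KMeans uses
# are monotone in cosine distance (||a - b||^2 = 2 - 2*cos(a, b)), so the
# clustering behaves like cosine-similarity clustering.
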
def choose_k(n_sent: int, user_k: Optional[int]) -> int:
    if user_k and user_k > 0:
        return min(user_k, n_sent)
    # Heuristic: sqrt(N), clamped
    k = max(5, int(math.sqrt(max(1, n_sent))))
    return min(k, n_sent)

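# e.g., choose_k(100, None) -> 10; choose_k(9, None) -> 5 (lower clamp);
# choose_k(3, None) -> 3 (K can never exceed the sentence count).
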
def kmeans_select(sentences: List[str], embeddings: np.ndarray, k: int, pick: int = 1) -> List[int]:
    """
    Run KMeans and pick the `pick` sentences closest to each centroid.
    Returns sorted indices to preserve a logical reading flow.
    """
    if k <= 0 or len(sentences) == 0:
        return []
    # Edge case: fewer sentences than k
    k = min(k, len(sentences))
    kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42)
    labels = kmeans.fit_predict(embeddings)
    chosen = []
    for c in range(k):
        idxs = np.where(labels == c)[0]
        # Distances to this cluster's centroid
        dists = np.linalg.norm(embeddings[idxs] - kmeans.cluster_centers_[c], axis=1)
        local_order = np.argsort(dists)[:pick]
        chosen.extend(idxs[local_order].tolist())
    chosen = sorted(set(chosen))
    return chosen

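# Illustrative usage (hypothetical shapes): with 40 sentences embedded as a
# (40, 384) matrix (MiniLM's dimension), kmeans_select(sents, embs, k=8, pick=1)
# returns up to 8 sentence indices, one per cluster, in document order.
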
def load_abstractive_model(model_name: str):
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl, cfg

def chunk_tokens(tokens: List[int], window: int, overlap: int) -> List[List[int]]:
    if window <= 0:
        return [tokens]
    chunks = []
    i = 0
    while i < len(tokens):
        chunks.append(tokens[i:i + window])
        i += max(1, window - overlap)
    return chunks

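# e.g., chunk_tokens(list(range(10)), window=4, overlap=1)
#   -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]
# (each window re-reads the last `overlap` tokens of the previous one)
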
def run_abstractive(
    text: str,
    model_name: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    min_len: int = 40,
    window_tokens: int = WINDOW_TOKENS,
    overlap_tokens: int = OVERLAP_TOKENS,
) -> str:
    tok, mdl, cfg = load_abstractive_model(model_name)
    # Tokenize the full text so we can decide whether windowing is needed
    enc = tok(text, return_tensors="pt", truncation=False)
    input_ids = enc["input_ids"].squeeze(0).tolist()
    if len(input_ids) <= window_tokens:
        # Short enough for a single pass
        enc = tok(text, return_tensors="pt", truncation=True, max_length=window_tokens).to(DEVICE)
        gen = mdl.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # needed: temperature/top_p are ignored under greedy decoding
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            min_length=min_len,
        )
        return tok.decode(gen[0], skip_special_tokens=True)
    # Sliding window over long inputs
    parts = []
    for chunk in chunk_tokens(input_ids, window_tokens, overlap_tokens):
        enc_chunk = {
            "input_ids": torch.tensor([chunk]).to(DEVICE),
            "attention_mask": torch.ones((1, len(chunk)), dtype=torch.long, device=DEVICE),
        }
        gen = mdl.generate(
            **enc_chunk,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            min_length=min_len,
        )
        parts.append(tok.decode(gen[0], skip_special_tokens=True))
    # Simple merge; if the stitched text is still long, re-summarize it once more
    stitched = "\n".join(parts)
    if len(stitched.split()) > 600:
        enc2 = tok(stitched, return_tensors="pt", truncation=True, max_length=window_tokens).to(DEVICE)
        gen2 = mdl.generate(
            **enc2,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            min_length=min_len,
        )
        return tok.decode(gen2[0], skip_special_tokens=True)
    return stitched

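# The windowed path is a simple map-reduce summarization: summarize each window
# independently, stitch the partial summaries, and, if the result still exceeds
# ~600 words, compress it with one more pass through the same model.
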
def pipeline(
    text: str,
    embed_model: str,
    abs_model: Optional[str],
    k_clusters: Optional[int],
    pick_per_cluster: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    min_len: int,
) -> Tuple[str, str, str]:
    """
    Returns: (extractive_core, abstractive_summary, debug_info)
    """
    if not text or not text.strip():
        return "", "", "No input text."
    # 1) Sentence split
    sentences = simple_sentence_split(text)
    if len(sentences) == 0:
        return "", "", "No sentences detected after splitting."
    # 2) Embeddings
    etok, emdl = load_embedder(embed_model)
    embs = embed_sentences(sentences, etok, emdl, batch_size=32)
    # 3) Clustering + representative pick
    k = choose_k(len(sentences), k_clusters)
    chosen_idx = kmeans_select(sentences, embs, k, pick=pick_per_cluster)
    extractive = " ".join([sentences[i] for i in chosen_idx])
    # 4) Abstractive refinement (optional)
    abstractive = ""
    model_used = abs_model or ""
    if abs_model and abs_model.strip().lower() != "none":
        try:
            abstractive = run_abstractive(
                extractive if len(extractive) > 0 else text,
                model_name=abs_model,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                min_len=min_len,
            )
        except Exception:
            # Fall back to the smaller LED if the requested model fails
            if abs_model != FALLBACK_ABS_MODEL:
                try:
                    abstractive = run_abstractive(
                        extractive if len(extractive) > 0 else text,
                        model_name=FALLBACK_ABS_MODEL,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        min_len=min_len,
                    )
                    model_used = f"{abs_model} -> fell back to {FALLBACK_ABS_MODEL}"
                except Exception:
                    abstractive = ""
                    model_used = f"{abs_model} (failed) & fallback failed"
            else:
                abstractive = ""
                model_used = f"{abs_model} (failed)"
    # Debug info
    dbg = (
        f"Device: {DEVICE}\n"
        f"Embedder: {embed_model}\n"
        f"Abstractive model: {model_used or 'None'}\n"
        f"Sentences: {len(sentences)} | K: {k} | Pick/cluster: {pick_per_cluster}\n"
        f"Chosen indices (sorted): {chosen_idx[:50]}{'...' if len(chosen_idx) > 50 else ''}\n"
    )
    return extractive, (abstractive or extractive), dbg

# -----------------------------
# Gradio UI
# -----------------------------
EXAMPLE_LEGAL = """IN THE SUPREME COURT OF INDIA
Civil Appeal No. 1234 of 2021
The appellant contends that the High Court erred in overlooking binding precedent on limitation.
The respondent argues that the delay is inordinate and unexplained. The core issue is whether
sufficient cause exists under Section 5 of the Limitation Act. After hearing the parties and
perusing the record, we find that the appellant was prevented by bona fide reasons. Accordingly,
the delay is condoned subject to costs of Rs. 10,000. The matter is remanded to the High Court
for disposal on merits in accordance with law."""

with gr.Blocks(title="Legal Summarizer (K-Means + BERT + LED)") as demo:
    gr.Markdown(
        """
        # ⚖️ Legal Text Summarizer — K-Means + BERT + LED
        Upload/paste a judgment or order. We cluster sentences with BERT embeddings (extractive core),
        then optionally refine with an LED/Long model for an abstractive final summary.
        - **Embedding model**: any BERT/sentence-transformers model
        - **Abstractive model**: LED/LongT5/T5 from Hugging Face (can handle long docs with chunking)
        - Works well on Indian legal text; swap to `nlpaueb/legal-bert-base-uncased` for domain flavor (slower).
        """
    )
    with gr.Row():
        text_in = gr.Textbox(
            label="Paste Legal Text",
            lines=18,
            placeholder="Paste a long judgment, order, or legal article…",
            value=EXAMPLE_LEGAL,
        )
    with gr.Accordion("Models & Settings", open=True):
        with gr.Row():
            embed_model = gr.Textbox(
                label="Embedding Model (BERT/Sentence-Transformers)",
                value=DEFAULT_EMBED_MODEL,
                info="E.g., sentence-transformers/all-MiniLM-L6-v2 or nlpaueb/legal-bert-base-uncased",
            )
            abs_model = gr.Textbox(
                label="Abstractive Model (LED/LongT5/T5) or 'none'",
                value=DEFAULT_ABS_MODEL,
                info="Try: pszemraj/led-large-book-summary, allenai/led-base-16384, google/long-t5-tglobal-base",
            )
        with gr.Row():
            k_clusters = gr.Number(label="K (clusters). Leave 0 to auto (≈√N)", value=0, precision=0)
            pick_per = gr.Slider(label="Sentences picked per cluster", minimum=1, maximum=3, value=1, step=1)
        with gr.Row():
            max_new = gr.Slider(label="Max new tokens (abstractive)", minimum=64, maximum=1024, value=256, step=16)
            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.2, value=0.7, step=0.05)
            top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.05)
            min_len = gr.Slider(label="Min summary length (tokens)", minimum=10, maximum=200, value=40, step=5)
    run_btn = gr.Button("Summarize 🚀", variant="primary")
    extractive_out = gr.Textbox(label="Extractive Core (cluster representatives)", lines=10)
    abstractive_out = gr.Textbox(label="Final Summary (abstractive if model provided)", lines=10)
    debug_out = gr.Textbox(label="Debug Info", lines=8)

    def _go(text, e_model, a_model, k, pick, mx, temp, topp, minl):
        k = int(k) if k else 0
        return pipeline(
            text=text,
            embed_model=e_model.strip(),
            abs_model=a_model.strip() if a_model else "none",
            k_clusters=k,
            pick_per_cluster=int(pick),
            max_new_tokens=int(mx),
            temperature=float(temp),
            top_p=float(topp),
            min_len=int(minl),
        )

    run_btn.click(
        _go,
        inputs=[text_in, embed_model, abs_model, k_clusters, pick_per, max_new, temperature, top_p, min_len],
        outputs=[extractive_out, abstractive_out, debug_out],
    )

if __name__ == "__main__":
    # For Spaces/Colab inline preview
    demo.launch()