Spaces:
Running
Running
File size: 5,388 Bytes
968e24d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | # src/summarization/inference.py
import sys
import os
import re
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from src.summarization.ranker import ImportanceRanker
from src.summarization.utils import split_sentences
from src.segmentation.judgement_segmenter import JudgmentSegmenter
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
import torch
# ── Model ──────────────────────────────────────────────────────────────────────
# The abstractive model and its tokenizer are loaded once, at import time, and
# shared by every call in this module.
MODEL_NAME = "nsi319/legal-pegasus"

print(f"\nLoading Abstractive Model ({MODEL_NAME})...")
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pegasus_tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME)
pegasus_model = pegasus_model.to(device)
print(f"β Legal-PEGASUS loaded on {device.upper()}")
def _pegasus_generate(text: str, max_length: int = 300, min_length: int = 100) -> str:
    """Run Legal-PEGASUS on a block of text and return the decoded summary.

    Args:
        text: Source text to summarize. Input longer than 1024 tokens is
            truncated by the tokenizer.
        max_length: Maximum length of the generated sequence, passed to
            ``generate``.
        min_length: Minimum length of the generated sequence, passed to
            ``generate``.

    Returns:
        The decoded summary. PEGASUS ``<n>`` newline markers and runs of
        whitespace are collapsed to single spaces. Returns ``""`` when
        ``text`` is empty or whitespace-only.
    """
    if not text or not text.strip():
        # Beam search with a forced min_length would still emit that many
        # tokens for an empty prompt, which is pure hallucination. Skip the model.
        return ""

    inputs = pegasus_tokenizer(
        [text],
        max_length=1024,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(device)
    outputs = pegasus_model.generate(
        inputs["input_ids"],
        # Pass the tokenizer's mask explicitly rather than letting generate()
        # infer it (and warn) from pad tokens.
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        num_beams=4,  # Reduced from 8 for 2x speedup on CPU
        length_penalty=1.2,
        no_repeat_ngram_size=3,
        repetition_penalty=1.3,
        early_stopping=True,
    )
    decoded = pegasus_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # "<n>" is PEGASUS's newline marker; flatten it and normalize whitespace.
    return re.sub(r'\s+', ' ', decoded.replace("<n>", " ")).strip()
def _select_overview_sentences(sentences, scores, max_tokens: int = 950, slack: int = 20) -> list:
    """Greedily pick the top-scoring sentences that fit a PEGASUS token budget.

    Sentences are visited from highest to lowest score. A sentence that would
    push the running total past ``max_tokens`` is skipped, because a later and
    shorter one may still fit. Selection stops once the total is within
    ``slack`` tokens of the budget. If no single sentence fits, the
    highest-scoring one is returned anyway so the overview is never built from
    an empty prompt.

    Returns:
        The chosen sentences, in their original document order.
    """
    # Stable sort: equal scores keep document order.
    ranked = sorted(
        enumerate(zip(sentences, scores)), key=lambda item: item[1][1], reverse=True
    )
    chosen = []
    used = 0
    for idx, (sentence, _score) in ranked:
        n_tokens = len(pegasus_tokenizer.encode(sentence, add_special_tokens=False))
        if used + n_tokens > max_tokens:
            continue
        chosen.append(idx)
        used += n_tokens
        if used >= max_tokens - slack:
            break
    if not chosen and ranked:
        # Every sentence on its own exceeds the budget (e.g. text with no
        # sentence punctuation). Fall back to the single best sentence;
        # _pegasus_generate truncates its input at 1024 tokens.
        chosen = [ranked[0][0]]
    # Restore chronological order.
    return [sentences[i] for i in sorted(chosen)]


def _section_breakdown(sections, ranker, top_k: int = 3) -> dict:
    """Build an extractive summary for every labelled section.

    For each section whose lowercased ``type`` is not ``'unknown'``, keep its
    ``top_k`` highest-scoring sentences (those longer than 40 characters) in
    document order. When several sections share a type, their extracts are
    joined in document order rather than the later one overwriting the earlier.

    Returns:
        A dict mapping lowercased section type to the extract string.
    """
    breakdown = {}
    for section in sections:
        sec_type = section.type.lower()
        if sec_type == 'unknown':
            continue
        sentences = [s for s in split_sentences(section.text) if len(s.strip()) > 40]
        if not sentences:
            continue
        sec_scores = ranker.score(sentences)
        # Select top_k per section for readability, then restore order.
        s_indexed = list(enumerate(zip(sentences, sec_scores)))
        top = sorted(s_indexed, key=lambda x: x[1][1], reverse=True)[:top_k]
        # Original sentences are used (no abstraction) to save CPU time.
        extract = " ".join(s for _idx, (s, _score) in sorted(top, key=lambda x: x[0]))
        if sec_type in breakdown:
            breakdown[sec_type] += " " + extract
        else:
            breakdown[sec_type] = extract
    return breakdown


def summarize(judgment_file: str) -> dict:
    """
    Speed-Optimized Two-Pass Pipeline:
    1. Case Overview: Legal-BERT (Extraction) -> Legal-PEGASUS (Abstraction) [1 Pass]
    2. Detailed Sections: Legal-BERT (Extraction) -> Direct Output [No Abstraction pass to save 5+ minutes]

    Args:
        judgment_file: Path to a UTF-8 text file. Undecodable bytes are ignored.

    Returns:
        A dict with an ``"overview"`` key (abstractive summary) plus one key
        per lowercased, non-``'unknown'`` section type (extractive summary).
        If no usable sentences are found, returns only ``"overview"`` with an
        error message.
    """
    text = Path(judgment_file).read_text(encoding="utf-8", errors="ignore")

    # ── Step 1: Global sentence extraction ────────────────────────────────────
    all_sentences = [s for s in split_sentences(text) if len(s.strip()) > 40]
    if not all_sentences:
        return {"overview": "Could not extract readable text."}
    # NOTE(review): relative path, resolved against the current working directory.
    ranker = ImportanceRanker("outputs/summarization/final")
    scores = ranker.score(all_sentences)
    # Token-aware extract (~950 tokens) so PEGASUS's 1024-token input is not overrun.
    global_extract = " ".join(_select_overview_sentences(all_sentences, scores))

    # ── Pass 1: Abstractive Overview (The only heavy pass) ────────────────────
    print("Generating Case Overview (Abstractive)...")
    overview = _pegasus_generate(global_extract, max_length=250, min_length=80)

    # ── Pass 2: Extractive Section Breakdown (Instant) ────────────────────────
    segmenter = JudgmentSegmenter()
    # Split on blank lines, including lines holding only spaces/tabs (common
    # in PDF-extracted text), which a literal "\n\n" split would miss.
    paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if len(p.strip()) > 20]
    sections = segmenter.segment(paragraphs)
    final_summary = {"overview": overview}
    print("Generating Section Breakdowns (Extractive - Instant)...")
    final_summary.update(_section_breakdown(sections, ranker))
    return final_summary
if __name__ == "__main__":
    # Usage: python inference.py [path/to/judgment.txt]
    # Without an argument, the alphabetically first .txt file in the default
    # corpus directory is used. The list is sorted because glob() order is
    # filesystem-dependent.
    if len(sys.argv) > 1:
        judgment_path = Path(sys.argv[1])
    else:
        candidates = sorted(Path("data/processed/extracted/texts").glob("*.txt"))
        if not candidates:
            sys.exit("No .txt files found in data/processed/extracted/texts")
        judgment_path = candidates[0]
    print(f"\nProcessing {judgment_path.name}...")
    result = summarize(judgment_path)
    print("\n\nCOMPREHENSIVE LEGAL SUMMARY (Global Legal-BERT + Legal-PEGASUS)\n" + "=" * 80)
    print("\n[CASE OVERVIEW]")
    print(result.get("overview", ""))
    # Canonical sections first, then any other types the segmenter produced,
    # so none of the summary is silently hidden.
    canonical = ['facts', 'issues', 'arguments', 'analysis', 'decision']
    extras = [key for key in result if key != "overview" and key not in canonical]
    for sec in canonical + extras:
        if sec in result:
            print(f"\n[{sec.upper()}]")
            print(result[sec])
    print("\n" + "=" * 80)
|