File size: 5,388 Bytes
968e24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# src/summarization/inference.py
import sys
import os
import re
from pathlib import Path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from src.summarization.ranker import ImportanceRanker
from src.summarization.utils import split_sentences
from src.segmentation.judgement_segmenter import JudgmentSegmenter
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
import torch

# ── Model ──────────────────────────────────────────────────────────────────────
# NOTE: everything below runs at import time, so importing this module loads
# the tokenizer and model weights and prints progress messages.
MODEL_NAME = "nsi319/legal-pegasus"
print(f"\nLoading Abstractive Model ({MODEL_NAME})...")
# Use the GPU when torch reports one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pegasus_tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
# The check mark was previously stored as mis-decoded bytes ("βœ“").
print(f"✓ Legal-PEGASUS loaded on {device.upper()}")


def _pegasus_generate(text: str, max_length: int = 300, min_length: int = 100) -> str:
    """Summarize ``text`` with the module-level Legal-PEGASUS model.

    Args:
        text: Source passage. The tokenized input is truncated to 1024 tokens.
        max_length: Upper bound on the generated length, forwarded to
            ``generate``.
        min_length: Lower bound on the generated length, forwarded to
            ``generate``.

    Returns:
        The decoded summary with ``<n>`` markers replaced by spaces and runs
        of whitespace collapsed to a single space. Returns ``""`` when
        ``text`` is empty or contains only whitespace.
    """
    if not text or not text.strip():
        # Nothing to summarize: skip the model instead of asking it for at
        # least ``min_length`` tokens with no source material.
        return ""

    inputs = pegasus_tokenizer(
        [text],
        max_length=1024,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(device)

    outputs = pegasus_model.generate(
        inputs["input_ids"],
        # Forward the mask the tokenizer produced so padded positions (if any)
        # are never attended to.
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        num_beams=4,                # Reduced from 8 for 2x speedup on CPU
        length_penalty=1.2,
        no_repeat_ngram_size=3,
        repetition_penalty=1.3,
        early_stopping=True,
    )
    decoded = pegasus_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Replace the "<n>" markers with spaces and normalise whitespace.
    return re.sub(r'\s+', ' ', decoded.replace("<n>", " ")).strip()


def summarize(judgment_file: str) -> dict:
    """Build an abstractive overview plus extractive section summaries.

    Pipeline:
      1. Split the file into sentences, keep those longer than 40 characters
         after stripping, and score them with ``ImportanceRanker``.
      2. Greedily pick the highest-scored sentences that fit a 950-token
         budget (counted with the Pegasus tokenizer), restore document order
         and abstract them into the overview via ``_pegasus_generate``.
      3. Segment the paragraphs with ``JudgmentSegmenter`` and, for every
         section whose type is not ``unknown``, keep its 3 best-scored
         sentences in document order. No abstractive pass is run on sections
         (the original author notes this saves several minutes of CPU time).

    Args:
        judgment_file: Path to the judgment text file. It is read as UTF-8
            and undecodable bytes are ignored.

    Returns:
        A dict with an ``"overview"`` entry and one entry per lower-cased
        section type. When no usable sentence is found, only ``"overview"``
        is present and holds a fixed error message.
    """
    text = Path(judgment_file).read_text(encoding="utf-8", errors="ignore")

    # ── Step 1: Global sentence extraction ─────────────────────────────────────
    all_sentences = [s for s in split_sentences(text) if len(s.strip()) > 40]
    if not all_sentences:
        return {"overview": "Could not extract readable text."}

    ranker = ImportanceRanker("outputs/summarization/final")
    scores = ranker.score(all_sentences)

    # Each entry is (position, (sentence, score)).
    indexed = list(enumerate(zip(all_sentences, scores)))
    sorted_by_score = sorted(indexed, key=lambda x: x[1][1], reverse=True)

    # Token-aware selection: stay below the 1024-token truncation limit used
    # by _pegasus_generate.
    max_tokens = 950
    selected_indices = []
    current_tokens = 0

    for idx, (sentence, _score) in sorted_by_score:
        tokens = len(pegasus_tokenizer.encode(sentence, add_special_tokens=False))
        if current_tokens + tokens > max_tokens:
            # Too long for the remaining budget; a shorter, lower-ranked
            # sentence may still fit.
            continue

        selected_indices.append(idx)
        current_tokens += tokens
        if current_tokens >= max_tokens - 20:
            # Budget is effectively full; stop tokenizing further sentences.
            break

    if not selected_indices and sorted_by_score:
        # Every sentence exceeds the budget on its own (e.g. the splitter
        # returned one huge sentence). Fall back to the best-ranked one and
        # let the tokenizer truncate it, rather than summarizing "".
        selected_indices.append(sorted_by_score[0][0])

    # Restore chronological order
    top_in_order = [indexed[i] for i in sorted(selected_indices)]
    global_extract = " ".join(s for _, (s, _) in top_in_order)

    # ── Pass 1: Abstractive Overview (The only heavy pass) ────────────────────
    print("Generating Case Overview (Abstractive)...")
    overview = _pegasus_generate(global_extract, max_length=250, min_length=80)

    # ── Pass 2: Extractive Section Breakdown (Instant) ────────────────────────
    segmenter = JudgmentSegmenter()
    paragraphs = [p.strip() for p in text.split("\n\n") if len(p.strip()) > 20]
    sections = segmenter.segment(paragraphs)

    final_summary = {"overview": overview}

    print("Generating Section Breakdowns (Extractive - Instant)...")
    for section in sections:
        sec_type = section.type.lower()
        if sec_type == 'unknown':
            continue

        sentences = [s for s in split_sentences(section.text) if len(s.strip()) > 40]
        if not sentences:
            continue

        sec_scores = ranker.score(sentences)
        # Select top 3 per section for readability
        s_indexed = list(enumerate(zip(sentences, sec_scores)))
        top_k = sorted(s_indexed, key=lambda x: x[1][1], reverse=True)[:3]
        top_k_ordered = sorted(top_k, key=lambda x: x[0])

        # We use original sentences here to save 5-10 minutes of CPU time.
        # NOTE(review): if the segmenter yields the same type more than once,
        # the later section replaces the earlier entry — confirm intended.
        final_summary[sec_type] = " ".join(s for _, (s, _) in top_k_ordered)

    return final_summary


if __name__ == "__main__":
    # Demo entry point: summarize the first extracted judgment found on disk
    # and print the overview followed by the known sections.
    input_dir = Path("data/processed/extracted/texts")
    # Take the first match lazily instead of listing the whole directory.
    input_path = next(input_dir.glob("*.txt"), None)
    if input_path is None:
        # Exit with a readable message rather than a bare IndexError.
        raise SystemExit(f"No .txt files found in {input_dir}")

    print(f"\nProcessing {input_path.name}...")
    result = summarize(str(input_path))

    print("\n\nCOMPREHENSIVE LEGAL SUMMARY (Global Legal-BERT + Legal-PEGASUS)\n" + "=" * 80)
    print("\n[CASE OVERVIEW]")
    print(result.get("overview", ""))
    # Fixed display order; sections missing from the result are skipped.
    for sec in ['facts', 'issues', 'arguments', 'analysis', 'decision']:
        if sec in result:
            print(f"\n[{sec.upper()}]")
            print(result[sec])
    print("\n" + "=" * 80)