File size: 7,255 Bytes
6dcb6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
Indonesian Court Document Summarization using Pre-trained Models

This module uses pre-trained multilingual models (mT5 or mBART) that already
understand the Indonesian language and can generate summaries without extensive training.
"""

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline
)


class IndonesianCourtSummarizer:
    """
    Summarizer for Indonesian court documents using pre-trained models.

    Wraps a HuggingFace seq2seq checkpoint together with its tokenizer and
    exposes helpers for summarizing one document or a list of documents.

    Supports multiple model backends:
    - google/mt5-small (lightweight, good for quick inference)
    - google/mt5-base (balanced performance and quality)
    - csebuetnlp/mT5_multilingual_XLSum (fine-tuned for summarization)
    - facebook/mbart-large-50 (multilingual BART)

    NOTE(review): per the HuggingFace mT5 documentation the plain
    google/mt5-* checkpoints are pre-trained only, so they presumably need
    fine-tuning before they produce usable summaries -- confirm before
    relying on them.
    """

    def __init__(self, model_name="csebuetnlp/mT5_multilingual_XLSum", device=None):
        """
        Initialize the summarizer with a pre-trained model.

        Loads the tokenizer and model weights, moves the model to the chosen
        device and switches it to evaluation mode. Progress is printed to
        stdout.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run on ('cuda', 'cpu', or None for auto-detect)
        """
        print(f"Loading model: {model_name}")

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        print(f"Using device: {self.device}")

        # Load the tokenizer, preferring the default (fast) implementation.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as exc:
            # Deliberately broad: any failure of the first attempt triggers a
            # retry with the slow tokenizer. The cause is reported instead of
            # being swallowed so that a failing retry is easier to diagnose.
            print(
                f"Fast tokenizer failed ({type(exc).__name__}: {exc}), "
                "trying slow tokenizer..."
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()  # inference only

        print("Model loaded successfully!\n")

    def summarize(self,
                  document,
                  max_length=256,
                  min_length=50,
                  num_beams=4,
                  length_penalty=1.0,
                  early_stopping=True,
                  add_prefix=True,
                  max_input_length=1024):
        """
        Generate summary for an Indonesian court document.

        Args:
            document: The full court document text
            max_length: Forwarded to generate() as the maximum summary length
            min_length: Forwarded to generate() as the minimum summary length
            num_beams: Number of beams for beam search (higher = better quality, slower)
            length_penalty: Length penalty forwarded to generate(); larger
                values favor longer summaries
            early_stopping: Forwarded to generate() as the beam-search
                stopping flag
            add_prefix: Prepend "summarize: " to the input (used by some models)
            max_input_length: Inputs are truncated to this many tokens
                before generation

        Returns:
            Generated summary text with surrounding whitespace stripped, or
            an empty string when the document is None, empty or
            whitespace-only (the model is not invoked in that case).
        """
        # Nothing to summarize: avoid running the model on the bare prefix,
        # which would only produce made-up text.
        if document is None or not str(document).strip():
            return ""

        input_text = f"summarize: {document}" if add_prefix else document

        # Tokenize, truncating over-long documents to the input limit.
        inputs = self.tokenizer(
            input_text,
            max_length=max_input_length,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        # Generate the summary without tracking gradients.
        with torch.no_grad():
            summary_ids = self.model.generate(
                inputs["input_ids"],
                # Pass the mask explicitly rather than letting generate()
                # infer it; .get() tolerates tokenizers that do not emit one.
                attention_mask=inputs.get("attention_mask"),
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                no_repeat_ngram_size=3  # Avoid repetition
            )

        # Only one sequence was submitted, so decode the first result.
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return summary.strip()

    def summarize_batch(self, documents, **kwargs):
        """
        Summarize multiple documents one after another.

        Documents are processed sequentially through summarize(); a progress
        line is printed for each one.

        Args:
            documents: List of court document texts
            **kwargs: Additional arguments passed to summarize()

        Returns:
            List of generated summaries, in the same order as the input
        """
        total = len(documents)
        summaries = []
        for position, doc in enumerate(documents, start=1):
            print(f"Summarizing document {position}/{total}...")
            summaries.append(self.summarize(doc, **kwargs))
        return summaries


def create_summarizer(model_choice="balanced"):
    """
    Build an IndonesianCourtSummarizer from a preset name or a model id.

    Args:
        model_choice: One of the preset names below, or any other string,
            which is then used verbatim as the HuggingFace model identifier.
            - "fast": google/mt5-small
            - "balanced": csebuetnlp/mT5_multilingual_XLSum
            - "quality": google/mt5-base

    Returns:
        IndonesianCourtSummarizer instance

    NOTE(review): per the HuggingFace mT5 documentation the plain
    google/mt5-* checkpoints are pre-trained only, so the "fast" and
    "quality" presets presumably need fine-tuning to summarize well --
    confirm before relying on them.
    """
    presets = {
        "fast": "google/mt5-small",
        "balanced": "csebuetnlp/mT5_multilingual_XLSum",
        "quality": "google/mt5-base",
    }

    # Unknown choices fall through unchanged so callers can pass a raw model id.
    if model_choice in presets:
        resolved_name = presets[model_choice]
    else:
        resolved_name = model_choice

    return IndonesianCourtSummarizer(model_name=resolved_name)


# Convenience function for quick summarization
def quick_summarize(document, model_choice="balanced", **kwargs):
    """
    Summarize a single document in one call.

    A new summarizer is created through create_summarizer() on every
    invocation, so the model is loaded each time; build a summarizer once
    and reuse it when summarizing many documents.

    Args:
        document: Court document text
        model_choice: Model to use ("fast", "balanced", or "quality")
        **kwargs: Additional arguments for summarize()

    Returns:
        Summary text
    """
    return create_summarizer(model_choice).summarize(document, **kwargs)


def main():
    """Run a demo: summarize a sample court ruling and print the result."""

    rule = "=" * 70

    def banner(*titles):
        # Print the given title lines framed by a rule above and below.
        print(rule)
        for title in titles:
            print(title)
        print(rule)

    # Sample Indonesian court document used as demo input
    court_document = """PUTUSAN
Nomor 245/Pid.B/2024/PN.Jkt.Pst

DEMI KEADILAN BERDASARKAN KETUHANAN YANG MAHA ESA

Pengadilan Negeri Jakarta Pusat yang mengadili perkara pidana dengan acara pemeriksaan
biasa dalam tingkat pertama menjatuhkan putusan sebagai berikut dalam perkara Terdakwa
BUDI SANTOSO yang didakwa melakukan tindak pidana penggelapan. Terdakwa menerima uang
sejumlah Rp 350.000.000 dari PT Sejahtera Abadi untuk pembelian bahan baku namun
menggunakan uang tersebut untuk keperluan pribadi. Berdasarkan keterangan saksi-saksi
dan barang bukti, Majelis Hakim memutuskan Terdakwa terbukti secara sah dan meyakinkan
bersalah melakukan tindak pidana penggelapan sebagaimana diatur dalam Pasal 372 KUHP.
Terdakwa dijatuhi pidana penjara selama 2 tahun 6 bulan. Barang bukti berupa dokumen
perjanjian kerjasama, buku tabungan, dan bukti transfer dirampas untuk negara. Terdakwa
ditetapkan tetap ditahan dan dibebankan biaya perkara sebesar Rp 5.000."""

    banner(
        "Indonesian Court Document Summarization Demo",
        "Using Pre-trained Multilingual Models",
    )
    print()

    # Load the model behind the "balanced" preset
    print("Creating summarizer with balanced model...")
    summarizer = create_summarizer("balanced")

    banner("ORIGINAL DOCUMENT:")
    print(court_document)
    print()

    banner("GENERATING SUMMARY...")

    # Shorter length bounds than the defaults, suited to the short sample
    generation_options = {"max_length": 150, "min_length": 30, "num_beams": 4}
    summary = summarizer.summarize(court_document, **generation_options)

    print()
    banner("GENERATED SUMMARY:")
    print(summary)
    print(rule)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()