# indo_summary_AI / pretrained_summarizer.py
# Uploaded by Flippinjack — "Deploy to hf space" (commit 6dcb6b4)
"""
Indonesian Court Document Summarization using Pre-trained Models
This module uses pre-trained multilingual models (mT5 or mBART) that already
understand Indonesian language and can generate summaries without extensive training.
"""
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
pipeline
)
class IndonesianCourtSummarizer:
    """
    Summarizer for Indonesian court documents using pre-trained models.

    Supports multiple model backends:
    - google/mt5-small (lightweight, good for quick inference)
    - google/mt5-base (balanced performance and quality)
    - csebuetnlp/mT5_multilingual_XLSum (fine-tuned for summarization)
    - facebook/mbart-large-50 (multilingual BART)
    """

    def __init__(self, model_name="csebuetnlp/mT5_multilingual_XLSum", device=None):
        """
        Initialize the summarizer with a pre-trained model.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run on ('cuda', 'cpu', or None for auto-detect)
        """
        print(f"Loading model: {model_name}")
        # Auto-detect GPU only when the caller did not pin a device.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        print(f"Using device: {self.device}")
        # Load the tokenizer; some checkpoints lack a fast tokenizer, so fall
        # back to the slow (SentencePiece) implementation on failure.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            # Include the actual error so the fallback reason is visible.
            print(f"Fast tokenizer failed ({e}), trying slow tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()  # inference only — disables dropout etc.
        print("Model loaded successfully!\n")

    def summarize(self,
                  document,
                  max_length=256,
                  min_length=50,
                  num_beams=4,
                  length_penalty=1.0,
                  early_stopping=True,
                  add_prefix=True):
        """
        Generate summary for an Indonesian court document.

        Args:
            document: The full court document text
            max_length: Maximum length of the summary
            min_length: Minimum length of the summary
            num_beams: Number of beams for beam search (higher = better quality, slower)
            length_penalty: Length penalty (>1.0 = longer summaries, <1.0 = shorter)
            early_stopping: Stop when all beams reach EOS
            add_prefix: Add "summarize: " prefix for some models

        Returns:
            Generated summary text (empty string for empty/blank input)
        """
        # Guard: nothing to summarize — avoid feeding the model an empty prompt.
        if not document or not document.strip():
            return ""
        # Prepare input (T5-family models expect a task prefix)
        if add_prefix:
            input_text = f"summarize: {document}"
        else:
            input_text = document
        # Tokenize, truncating long court documents to the model's input budget
        inputs = self.tokenizer(
            input_text,
            max_length=1024,  # Max input length
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        # Generate summary
        with torch.no_grad():
            summary_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,  # mask pad tokens explicitly
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                no_repeat_ngram_size=3  # Avoid repetition
            )
        # Decode, stripping special tokens (<pad>, </s>, ...)
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary.strip()

    def summarize_batch(self, documents, **kwargs):
        """
        Summarize multiple documents in batch.

        Args:
            documents: List of court document texts
            **kwargs: Additional arguments passed to summarize()

        Returns:
            List of generated summaries
        """
        summaries = []
        for i, doc in enumerate(documents):
            print(f"Summarizing document {i+1}/{len(documents)}...")
            summaries.append(self.summarize(doc, **kwargs))
        return summaries
def create_summarizer(model_choice="balanced"):
    """
    Factory function to create a summarizer with recommended settings.

    Args:
        model_choice:
            - "fast": Fastest, good for testing (mt5-small)
            - "balanced": Good balance of speed and quality (mT5_multilingual_XLSum)
            - "quality": Best quality, slower (mt5-base)

    Returns:
        IndonesianCourtSummarizer instance
    """
    presets = {
        "fast": "google/mt5-small",
        "balanced": "csebuetnlp/mT5_multilingual_XLSum",
        "quality": "google/mt5-base",
    }
    # Anything outside the presets is treated as a raw HuggingFace model id.
    resolved = presets.get(model_choice, model_choice)
    return IndonesianCourtSummarizer(model_name=resolved)
# Convenience function for quick summarization
def quick_summarize(document, model_choice="balanced", **kwargs):
    """
    Quick one-line summarization function.

    Args:
        document: Court document text
        model_choice: Model to use ("fast", "balanced", or "quality")
        **kwargs: Additional arguments for summarize()

    Returns:
        Summary text
    """
    # Build a throwaway summarizer and run it once on the given document.
    return create_summarizer(model_choice).summarize(document, **kwargs)
def main():
    """Example usage"""
    # Example Indonesian court document
    court_document = """PUTUSAN
Nomor 245/Pid.B/2024/PN.Jkt.Pst
DEMI KEADILAN BERDASARKAN KETUHANAN YANG MAHA ESA
Pengadilan Negeri Jakarta Pusat yang mengadili perkara pidana dengan acara pemeriksaan
biasa dalam tingkat pertama menjatuhkan putusan sebagai berikut dalam perkara Terdakwa
BUDI SANTOSO yang didakwa melakukan tindak pidana penggelapan. Terdakwa menerima uang
sejumlah Rp 350.000.000 dari PT Sejahtera Abadi untuk pembelian bahan baku namun
menggunakan uang tersebut untuk keperluan pribadi. Berdasarkan keterangan saksi-saksi
dan barang bukti, Majelis Hakim memutuskan Terdakwa terbukti secara sah dan meyakinkan
bersalah melakukan tindak pidana penggelapan sebagaimana diatur dalam Pasal 372 KUHP.
Terdakwa dijatuhi pidana penjara selama 2 tahun 6 bulan. Barang bukti berupa dokumen
perjanjian kerjasama, buku tabungan, dan bukti transfer dirampas untuk negara. Terdakwa
ditetapkan tetap ditahan dan dibebankan biaya perkara sebesar Rp 5.000."""

    rule = "=" * 70  # shared separator line for all banners below

    print(rule)
    print("Indonesian Court Document Summarization Demo")
    print("Using Pre-trained Multilingual Models")
    print(rule)
    print()

    # Create summarizer
    print("Creating summarizer with balanced model...")
    summarizer = create_summarizer("balanced")

    print(rule)
    print("ORIGINAL DOCUMENT:")
    print(rule)
    print(court_document)
    print()

    print(rule)
    print("GENERATING SUMMARY...")
    print(rule)
    # Generate summary
    summary = summarizer.summarize(
        court_document,
        max_length=150,
        min_length=30,
        num_beams=4,
    )

    print()
    print(rule)
    print("GENERATED SUMMARY:")
    print(rule)
    print(summary)
    print(rule)