Ringkas-In / src /core /classifier.py
anthonysigid's picture
deploy SummAIrizer apps to spaces
2a16478
Raw
History Blame Contribute Delete
2.77 kB
import os
from typing import List, Dict, Union
from transformers import pipeline
import json
class DocumentClassifier:
    """Classify documents into a configurable set of categories.

    For the capstone baseline this uses zero-shot classification with a
    multilingual NLI model, so it works immediately without any training.
    A fine-tuned BERT or BiLSTM model may be plugged in here later; any
    ``model_type`` other than ``"zero-shot"`` is currently a placeholder and
    leaves ``self.classifier`` as None.
    """

    def __init__(self, model_type: str = "zero-shot", max_chars: int = 1500):
        """Initialise the document classifier.

        Args:
            model_type: ``"zero-shot"`` loads a Hugging Face
                ``zero-shot-classification`` pipeline (model name taken from the
                ``CLASSIFIER_MODEL`` environment variable, with a multilingual
                default). Any other value is a placeholder for a future
                fine-tuned model and leaves ``self.classifier`` as None.
            max_chars: Maximum number of characters of (whitespace-stripped)
                input text passed to the model, to limit memory/token usage.

        Raises:
            ValueError: If ``max_chars`` is not positive.
        """
        if max_chars <= 0:
            raise ValueError(f"max_chars must be positive, got {max_chars}")
        self.model_type = model_type
        self.max_chars = max_chars
        self.categories = [
            "Laporan",
            "Surat Resmi",
            "Berita",
            "Invoice",
            "Pendidikan",
            "Keuangan",
            "IT",
        ]
        if self.model_type == "zero-shot":
            # Using a multilingual zero-shot model that works well for Indonesian
            model_name = os.getenv("CLASSIFIER_MODEL", "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
            print(f"Loading zero-shot classification model: {model_name}...")
            try:
                self.classifier = pipeline("zero-shot-classification", model=model_name)
            except Exception as e:
                # Best-effort load: keep the object usable so classify() reports
                # "Model not initialized" instead of the whole app crashing.
                print(f"Error loading model: {e}")
                self.classifier = None
        else:
            # Placeholder for future fine-tuned BERT or BiLSTM
            print(f"Model type {model_type} selected. Ensure weights are available in models/ dir.")
            self.classifier = None

    def set_categories(self, categories: List[str]) -> None:
        """Replace the candidate labels used for zero-shot classification.

        The labels are copied, so later mutation of the caller's list does not
        silently change this classifier's categories.
        """
        self.categories = list(categories)

    def classify(self, text: str, return_all_scores: bool = False) -> Union[str, Dict]:
        """Classify a document into one of ``self.categories``.

        Args:
            text: Document text. Surrounding whitespace is stripped and only
                the first ``self.max_chars`` characters are sent to the model.
            return_all_scores: If True, return the pipeline's ranked
                ``"labels"`` and ``"scores"`` instead of only the top label.

        Returns:
            The top predicted label, or a ``{"labels": [...], "scores": [...]}``
            dict when ``return_all_scores`` is True. The following sentinel
            strings are returned regardless of ``return_all_scores``:
            ``"Tidak Diketahui"`` for empty text or text shorter than 10
            characters after stripping; ``"Model not initialized"`` when no
            zero-shot model is loaded; ``"Error"`` if the model call fails.
        """
        stripped = text.strip() if text else ""
        if len(stripped) < 10:
            return "Tidak Diketahui"
        if self.model_type != "zero-shot" or self.classifier is None:
            return "Model not initialized"

        # Limit text length to prevent memory/token errors. Stripping first
        # keeps leading whitespace from consuming the character budget.
        chunk = stripped[: self.max_chars]
        try:
            result = self.classifier(chunk, self.categories)
            if return_all_scores:
                return {"labels": result["labels"], "scores": result["scores"]}
            # Pipeline labels are ranked by score; the first is the top prediction.
            return result["labels"][0]
        except Exception as e:
            print(f"Classification error: {e}")
            return "Error"
# Example usage:
# classifier = DocumentClassifier()
# label = classifier.classify("Tagihan pembelian server bulan Maret Rp 10.000.000")
# print(label) # Likely "Invoice"