File size: 2,773 Bytes
2a16478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from typing import List, Dict, Union
from transformers import pipeline
import json

class DocumentClassifier:
    """Assign a document to one label out of a configurable category list.

    Only the ``"zero-shot"`` model type is backed by a real model; any other
    value leaves the instance without a classifier, in which case
    :meth:`classify` reports ``"Model not initialized"``.
    """

    # Labels used until ``set_categories`` replaces them.
    DEFAULT_CATEGORIES = (
        "Laporan",
        "Surat Resmi",
        "Berita",
        "Invoice",
        "Pendidikan",
        "Keuangan",
        "IT",
    )
    # Inputs whose stripped length is below this are reported as unknown.
    MIN_TEXT_CHARS = 10
    # Upper bound on the characters handed to the model, to limit
    # memory/token errors on long documents.
    MAX_TEXT_CHARS = 1500

    def __init__(self, model_type: str = "zero-shot"):
        """Initialise the document classifier.

        The capstone baseline uses zero-shot classification on a multilingual
        model so that it works without any training step. A fine-tuned BERT
        or BiLSTM model may be plugged in here later.

        Args:
            model_type: ``"zero-shot"`` loads a zero-shot pipeline whose model
                name comes from the ``CLASSIFIER_MODEL`` environment variable
                (default ``MoritzLaurer/mDeBERTa-v3-base-mnli-xnli``). Any
                other value is a placeholder and loads nothing.

        A failure while loading the model is printed, not raised; the
        instance is still created with ``self.classifier`` set to ``None``.
        """
        self.model_type = model_type
        # Fresh list per instance so instances never share label state.
        self.categories = list(self.DEFAULT_CATEGORIES)
        self.classifier = None

        if self.model_type != "zero-shot":
            # Placeholder for future fine-tuned BERT or BiLSTM
            print(f"Model type {model_type} selected. Ensure weights are available in models/ dir.")
            return

        # Using a multilingual zero-shot model that works well for Indonesian
        model_name = os.getenv("CLASSIFIER_MODEL", "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
        print(f"Loading zero-shot classification model: {model_name}...")
        try:
            self.classifier = pipeline("zero-shot-classification", model=model_name)
        except Exception as e:
            # Best-effort load: keep the object usable and let classify()
            # report that no model is available.
            print(f"Error loading model: {e}")
            self.classifier = None

    def set_categories(self, categories: List[str]):
        """Replace the candidate labels used for zero-shot classification.

        Args:
            categories: The new labels. A copy is stored, so mutating the
                caller's list afterwards does not change the classifier.
        """
        if isinstance(categories, str):
            # Strings are immutable, so there is nothing to copy; keep the
            # value untouched instead of exploding it into characters.
            self.categories = categories
        else:
            self.categories = list(categories)

    def classify(self, text: str, return_all_scores: bool = False) -> Union[str, Dict]:
        """Classify a document into one of the configured categories.

        Args:
            text: The document text. Surrounding whitespace is ignored and at
                most ``MAX_TEXT_CHARS`` characters are sent to the model.
            return_all_scores: When true, return every label with its score
                instead of only the best label.

        Returns:
            * ``"Tidak Diketahui"`` when ``text`` is empty or shorter than
              ``MIN_TEXT_CHARS`` characters after stripping.
            * ``"Model not initialized"`` when the model type is not
              ``"zero-shot"`` or no classifier was loaded.
            * ``"Error"`` when the classifier call raises; the exception is
              printed.
            * Otherwise the first entry of the result's ``"labels"``, or a
              dict with the result's ``"labels"`` and ``"scores"`` when
              ``return_all_scores`` is true.
        """
        stripped = text.strip() if text else ""
        if len(stripped) < self.MIN_TEXT_CHARS:
            return "Tidak Diketahui"

        if self.model_type != "zero-shot" or self.classifier is None:
            return "Model not initialized"

        # Truncate the *stripped* text so leading whitespace cannot use up
        # the character budget and leave the model with a blank chunk.
        chunk = stripped[:self.MAX_TEXT_CHARS]

        try:
            result = self.classifier(chunk, self.categories)
            if return_all_scores:
                return {
                    "labels": result["labels"],
                    "scores": result["scores"],
                }
            # Return the top predicted label
            return result["labels"][0]
        except Exception as e:
            print(f"Classification error: {e}")
            return "Error"

# Example usage:
# classifier = DocumentClassifier()
# label = classifier.classify("Tagihan pembelian server bulan Maret Rp 10.000.000")
# print(label) # Likely "Invoice"