reyhanadr committed on
Commit
f165c76
·
1 Parent(s): abf6e68

add another probability score

Browse files
Files changed (1) hide show
  1. main.py +28 -26
main.py CHANGED
@@ -4,9 +4,10 @@ import re
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
 
7
 
8
  # ====================================================================
9
- # 1. KELAS LOGIKA ANDA (Disalin dari kode Anda)
10
  # ====================================================================
11
 
12
  class TextCleaner:
@@ -52,8 +53,11 @@ class SentimentPredictor:
52
  self.model = model
53
  self.device = torch.device("cpu")
54
  self.model.to(self.device)
 
 
55
 
56
- def predict(self, text: str) -> (str, float):
 
57
  inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=280)
58
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
59
 
@@ -61,43 +65,36 @@ class SentimentPredictor:
61
  outputs = self.model(**inputs)
62
 
63
  logits = outputs.logits
64
- predicted_label = torch.argmax(logits, dim=1).item()
65
 
66
- probabilities = torch.softmax(logits, dim=1)
67
- confidence_score = probabilities[0][predicted_label].item()
68
-
69
- if predicted_label == 2:
70
- sentiment = 'Negatif'
71
- elif predicted_label == 1:
72
- sentiment = 'Netral'
73
- else: # predicted_label == 0
74
- sentiment = 'Positif'
 
75
 
76
- return sentiment, confidence_score
77
 
78
  # ====================================================================
79
- # 2. INISIALISASI MODEL & APLIKASI FASTAPI
80
- # (Ini hanya dijalankan sekali saat API pertama kali startet)
81
  # ====================================================================
82
 
83
  print("Memuat model dan tokenizer...")
84
- # Muat tokenizer dan model dasar
85
  tokenizer = BertTokenizer.from_pretrained('indolem/indobertweet-base-uncased')
86
  model = BertForSequenceClassification.from_pretrained('indolem/indobertweet-base-uncased', num_labels=3)
87
-
88
- # Muat bobot model yang sudah Anda latih
89
  model_path = 'model_indoBERTweet_100Epochs_sentiment.pth'
90
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
91
  model.load_state_dict(state_dict, strict=False)
92
  model.eval()
93
  print("Model berhasil dimuat.")
94
 
95
- # Buat instance dari kelas-kelas Anda
96
  text_cleaner = TextCleaner()
97
  sentiment_predictor = SentimentPredictor(tokenizer, model)
98
 
99
- # Inisialisasi aplikasi FastAPI
100
- # Baris ini ditambahkan untuk memaksa build ulang
101
  app = FastAPI(
102
  title="API Klasifikasi Sentimen",
103
  description="Sebuah API untuk menganalisis sentimen teks Bahasa Indonesia."
@@ -110,9 +107,11 @@ app = FastAPI(
110
  class TextInput(BaseModel):
111
  text: str
112
 
 
113
  class PredictionOutput(BaseModel):
114
  sentiment: str
115
  confidence: float
 
116
 
117
  # ====================================================================
118
  # 4. BUAT ENDPOINT PREDIKSI
@@ -124,11 +123,14 @@ def read_root():
124
 
125
  @app.post("/predict", response_model=PredictionOutput)
126
  def predict_sentiment(request: TextInput):
127
- # Langkah 1: Bersihkan teks input
128
  cleaned_text = text_cleaner.clean_review(request.text)
129
 
130
- # Langkah 2: Lakukan prediksi pada teks yang sudah bersih
131
- sentiment, confidence = sentiment_predictor.predict(cleaned_text)
132
 
133
- # Langkah 3: Kembalikan hasil prediksi
134
- return PredictionOutput(sentiment=sentiment, confidence=confidence)
 
 
 
 
 
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
+ from typing import Dict
8
 
9
  # ====================================================================
10
+ # 1. KELAS LOGIKA ANDA (Tidak ada perubahan di TextCleaner)
11
  # ====================================================================
12
 
13
  class TextCleaner:
 
53
  self.model = model
54
  self.device = torch.device("cpu")
55
  self.model.to(self.device)
56
+ # --- [DIUBAH] --- Definisikan mapping label di sini agar mudah digunakan
57
+ self.label_mapping = {0: 'Positif', 1: 'Netral', 2: 'Negatif'}
58
 
59
+ # --- [DIUBAH] --- Tipe data kembalian (return type) diubah
60
+ def predict(self, text: str) -> (str, float, Dict[str, float]):
61
  inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=280)
62
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
63
 
 
65
  outputs = self.model(**inputs)
66
 
67
  logits = outputs.logits
 
68
 
69
+ # Hitung probabilitas untuk semua kelas
70
+ probabilities = torch.softmax(logits, dim=1)[0] # Ambil hasil pertama dari batch
71
+
72
+ # Dapatkan label dan skor kepercayaan dari probabilitas tertinggi
73
+ confidence_score = probabilities.max().item()
74
+ predicted_label_id = probabilities.argmax().item()
75
+ sentiment = self.label_mapping[predicted_label_id]
76
+
77
+ # --- [DIUBAH] --- Buat dictionary untuk semua skor probabilitas
78
+ all_scores = {self.label_mapping[i]: prob.item() for i, prob in enumerate(probabilities)}
79
 
80
+ return sentiment, confidence_score, all_scores
81
 
82
  # ====================================================================
83
+ # 2. INISIALISASI MODEL & APLIKASI FASTAPI (Tidak ada perubahan)
 
84
  # ====================================================================
85
 
86
  print("Memuat model dan tokenizer...")
 
87
  tokenizer = BertTokenizer.from_pretrained('indolem/indobertweet-base-uncased')
88
  model = BertForSequenceClassification.from_pretrained('indolem/indobertweet-base-uncased', num_labels=3)
 
 
89
  model_path = 'model_indoBERTweet_100Epochs_sentiment.pth'
90
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
91
  model.load_state_dict(state_dict, strict=False)
92
  model.eval()
93
  print("Model berhasil dimuat.")
94
 
 
95
  text_cleaner = TextCleaner()
96
  sentiment_predictor = SentimentPredictor(tokenizer, model)
97
 
 
 
98
  app = FastAPI(
99
  title="API Klasifikasi Sentimen",
100
  description="Sebuah API untuk menganalisis sentimen teks Bahasa Indonesia."
 
107
  class TextInput(BaseModel):
108
  text: str
109
 
110
+ # --- [DIUBAH] --- Model output diperbarui untuk menyertakan semua skor
111
  class PredictionOutput(BaseModel):
112
  sentiment: str
113
  confidence: float
114
+ all_scores: Dict[str, float]
115
 
116
  # ====================================================================
117
  # 4. BUAT ENDPOINT PREDIKSI
 
123
 
124
  @app.post("/predict", response_model=PredictionOutput)
125
  def predict_sentiment(request: TextInput):
 
126
  cleaned_text = text_cleaner.clean_review(request.text)
127
 
128
+ # --- [DIUBAH] --- Tangkap tiga nilai yang dikembalikan oleh metode predict
129
+ sentiment, confidence, all_scores = sentiment_predictor.predict(cleaned_text)
130
 
131
+ # --- [DIUBAH] --- Kembalikan hasil prediksi dalam struktur yang baru
132
+ return PredictionOutput(
133
+ sentiment=sentiment,
134
+ confidence=confidence,
135
+ all_scores=all_scores
136
+ )