reyhanadr commited on
Commit
99f3ba3
·
1 Parent(s): f165c76

add batch predict

Browse files
Files changed (1) hide show
  1. main.py +24 -24
main.py CHANGED
@@ -4,23 +4,19 @@ import re
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
- from typing import Dict
8
 
9
  # ====================================================================
10
- # 1. KELAS LOGIKA ANDA (Tidak ada perubahan di TextCleaner)
11
  # ====================================================================
12
 
13
  class TextCleaner:
14
  def __init__(self):
15
- # Daftar karakter ini saya sederhanakan karena loop Anda sudah menangani huruf a-z
16
  self.character = ['.', ',', ';', ':', '?', '!', '(', ')', '[', ']', '{', '}', '<', '>', '"', '/', '\'', '-', '@']
17
- # Menambahkan semua huruf ke dalam daftar karakter untuk pembersihan
18
  self.character.extend([chr(i) for i in range(ord('a'), ord('z') + 1)])
19
 
20
  def repeatcharClean(self, text):
21
  for char_to_clean in self.character:
22
- # Menggunakan regex untuk mengganti 3 atau lebih karakter berulang menjadi satu
23
- # Contoh: 'heloooo' -> 'helo'
24
  pattern = re.compile(re.escape(char_to_clean) + r'{3,}')
25
  text = pattern.sub(char_to_clean, text)
26
  return text
@@ -29,21 +25,17 @@ class TextCleaner:
29
  text = text.lower()
30
  text = re.sub(r'\s+', ' ', text)
31
  text = re.sub(r'[^\x00-\x7F]+', ' ', text)
32
-
33
  new_text = []
34
  for word in text.split(" "):
35
  word = '@USER' if word.startswith('@') and len(word) > 1 else word
36
  word = 'HTTPURL' if word.startswith('http') else word
37
  new_text.append(word)
38
  text = " ".join(new_text)
39
-
40
  text = emoji.demojize(text)
41
  text = re.sub(r':[A-Za-z_-]+:', ' ', text)
42
  text = re.sub(r"([xX;:]'?[dDpPvVoO3)(])", ' ', text)
43
  text = re.sub(r'["#$%&()*+,./:;<=>\[\]\\^_`{|}~]', ' ', text)
44
  text = self.repeatcharClean(text)
45
-
46
- # Membersihkan spasi berlebih yang mungkin muncul setelah pembersihan
47
  text = re.sub(r'\s+', ' ', text).strip()
48
  return text
49
 
@@ -53,10 +45,8 @@ class SentimentPredictor:
53
  self.model = model
54
  self.device = torch.device("cpu")
55
  self.model.to(self.device)
56
- # --- [DIUBAH] --- Definisikan mapping label di sini agar mudah digunakan
57
  self.label_mapping = {0: 'Positif', 1: 'Netral', 2: 'Negatif'}
58
 
59
- # --- [DIUBAH] --- Tipe data kembalian (return type) diubah
60
  def predict(self, text: str) -> (str, float, Dict[str, float]):
61
  inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=280)
62
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -65,18 +55,13 @@ class SentimentPredictor:
65
  outputs = self.model(**inputs)
66
 
67
  logits = outputs.logits
 
68
 
69
- # Hitung probabilitas untuk semua kelas
70
- probabilities = torch.softmax(logits, dim=1)[0] # Ambil hasil pertama dari batch
71
-
72
- # Dapatkan label dan skor kepercayaan dari probabilitas tertinggi
73
  confidence_score = probabilities.max().item()
74
  predicted_label_id = probabilities.argmax().item()
75
  sentiment = self.label_mapping[predicted_label_id]
76
 
77
- # --- [DIUBAH] --- Buat dictionary untuk semua skor probabilitas
78
  all_scores = {self.label_mapping[i]: prob.item() for i, prob in enumerate(probabilities)}
79
-
80
  return sentiment, confidence_score, all_scores
81
 
82
  # ====================================================================
@@ -107,14 +92,20 @@ app = FastAPI(
107
  class TextInput(BaseModel):
108
  text: str
109
 
110
- # --- [DIUBAH] --- Model output diperbarui untuk menyertakan semua skor
 
 
 
 
 
 
111
  class PredictionOutput(BaseModel):
112
  sentiment: str
113
  confidence: float
114
  all_scores: Dict[str, float]
115
 
116
  # ====================================================================
117
- # 4. BUAT ENDPOINT PREDIKSI
118
  # ====================================================================
119
 
120
  @app.get("/")
@@ -124,13 +115,22 @@ def read_root():
124
  @app.post("/predict", response_model=PredictionOutput)
125
  def predict_sentiment(request: TextInput):
126
  cleaned_text = text_cleaner.clean_review(request.text)
127
-
128
- # --- [DIUBAH] --- Tangkap tiga nilai yang dikembalikan oleh metode predict
129
  sentiment, confidence, all_scores = sentiment_predictor.predict(cleaned_text)
130
-
131
- # --- [DIUBAH] --- Kembalikan hasil prediksi dalam struktur yang baru
132
  return PredictionOutput(
133
  sentiment=sentiment,
134
  confidence=confidence,
135
  all_scores=all_scores
136
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
+ from typing import Dict, List
8
 
9
  # ====================================================================
10
+ # 1. KELAS LOGIKA ANDA (Tidak ada perubahan)
11
  # ====================================================================
12
 
13
  class TextCleaner:
14
  def __init__(self):
 
15
  self.character = ['.', ',', ';', ':', '?', '!', '(', ')', '[', ']', '{', '}', '<', '>', '"', '/', '\'', '-', '@']
 
16
  self.character.extend([chr(i) for i in range(ord('a'), ord('z') + 1)])
17
 
18
  def repeatcharClean(self, text):
19
  for char_to_clean in self.character:
 
 
20
  pattern = re.compile(re.escape(char_to_clean) + r'{3,}')
21
  text = pattern.sub(char_to_clean, text)
22
  return text
 
25
  text = text.lower()
26
  text = re.sub(r'\s+', ' ', text)
27
  text = re.sub(r'[^\x00-\x7F]+', ' ', text)
 
28
  new_text = []
29
  for word in text.split(" "):
30
  word = '@USER' if word.startswith('@') and len(word) > 1 else word
31
  word = 'HTTPURL' if word.startswith('http') else word
32
  new_text.append(word)
33
  text = " ".join(new_text)
 
34
  text = emoji.demojize(text)
35
  text = re.sub(r':[A-Za-z_-]+:', ' ', text)
36
  text = re.sub(r"([xX;:]'?[dDpPvVoO3)(])", ' ', text)
37
  text = re.sub(r'["#$%&()*+,./:;<=>\[\]\\^_`{|}~]', ' ', text)
38
  text = self.repeatcharClean(text)
 
 
39
  text = re.sub(r'\s+', ' ', text).strip()
40
  return text
41
 
 
45
  self.model = model
46
  self.device = torch.device("cpu")
47
  self.model.to(self.device)
 
48
  self.label_mapping = {0: 'Positif', 1: 'Netral', 2: 'Negatif'}
49
 
 
50
  def predict(self, text: str) -> (str, float, Dict[str, float]):
51
  inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=280)
52
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
55
  outputs = self.model(**inputs)
56
 
57
  logits = outputs.logits
58
+ probabilities = torch.softmax(logits, dim=1)[0]
59
 
 
 
 
 
60
  confidence_score = probabilities.max().item()
61
  predicted_label_id = probabilities.argmax().item()
62
  sentiment = self.label_mapping[predicted_label_id]
63
 
 
64
  all_scores = {self.label_mapping[i]: prob.item() for i, prob in enumerate(probabilities)}
 
65
  return sentiment, confidence_score, all_scores
66
 
67
  # ====================================================================
 
92
  class TextInput(BaseModel):
93
  text: str
94
 
95
+ # --- [PERBAIKAN] --- Menambahkan definisi BatchTextInput ---
96
+ # Model ini memberitahu FastAPI bahwa endpoint batch akan menerima
97
+ # sebuah objek JSON dengan satu key "texts" yang berisi daftar string.
98
+ class BatchTextInput(BaseModel):
99
+ texts: List[str]
100
+ # -----------------------------------------------------------
101
+
102
  class PredictionOutput(BaseModel):
103
  sentiment: str
104
  confidence: float
105
  all_scores: Dict[str, float]
106
 
107
  # ====================================================================
108
+ # 4. BUAT ENDPOINT (Tidak ada perubahan logika)
109
  # ====================================================================
110
 
111
  @app.get("/")
 
115
  @app.post("/predict", response_model=PredictionOutput)
116
  def predict_sentiment(request: TextInput):
117
  cleaned_text = text_cleaner.clean_review(request.text)
 
 
118
  sentiment, confidence, all_scores = sentiment_predictor.predict(cleaned_text)
 
 
119
  return PredictionOutput(
120
  sentiment=sentiment,
121
  confidence=confidence,
122
  all_scores=all_scores
123
  )
124
+
125
+ @app.post("/predict-batch", response_model=List[PredictionOutput])
126
+ def predict_sentiment_batch(request: BatchTextInput):
127
+ results = []
128
+ for text in request.texts:
129
+ cleaned_text = text_cleaner.clean_review(text)
130
+ sentiment, confidence, all_scores = sentiment_predictor.predict(cleaned_text)
131
+ results.append(PredictionOutput(
132
+ sentiment=sentiment,
133
+ confidence=confidence,
134
+ all_scores=all_scores
135
+ ))
136
+ return results