add batch predict
Browse files
main.py
CHANGED
|
@@ -4,23 +4,19 @@ import re
|
|
| 4 |
from transformers import BertTokenizer, BertForSequenceClassification
|
| 5 |
from fastapi import FastAPI
|
| 6 |
from pydantic import BaseModel
|
| 7 |
-
from typing import Dict
|
| 8 |
|
| 9 |
# ====================================================================
|
| 10 |
-
# 1. KELAS LOGIKA ANDA (Tidak ada perubahan)
|
| 11 |
# ====================================================================
|
| 12 |
|
| 13 |
class TextCleaner:
|
| 14 |
def __init__(self):
    """Prepare the list of characters whose long repeats get collapsed."""
    # Punctuation and symbols handled by repeatcharClean.
    punctuation = ['.', ',', ';', ':', '?', '!', '(', ')', '[', ']', '{', '}', '<', '>', '"', '/', '\'', '-', '@']
    # Every lowercase ASCII letter (a-z) is cleaned the same way.
    lowercase = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    self.character = punctuation + lowercase
|
| 19 |
|
| 20 |
def repeatcharClean(self, text):
    """Collapse runs of 3+ identical characters (for chars in self.character) to one.

    Example: 'heloooo' -> 'helo'.
    """
    # Build one character class instead of compiling one pattern per
    # character: a single pass over the text replaces the original loop of
    # ~45 separate compile+sub passes. Behavior is identical because the
    # replacement always leaves one character, so runs can never merge.
    char_class = ''.join(re.escape(c) for c in self.character)
    pattern = re.compile('([' + char_class + r'])\1{2,}')
    return pattern.sub(r'\1', text)
|
|
@@ -29,21 +25,17 @@ class TextCleaner:
|
|
| 29 |
text = text.lower()
|
| 30 |
text = re.sub(r'\s+', ' ', text)
|
| 31 |
text = re.sub(r'[^\x00-\x7F]+', ' ', text)
|
| 32 |
-
|
| 33 |
new_text = []
|
| 34 |
for word in text.split(" "):
|
| 35 |
word = '@USER' if word.startswith('@') and len(word) > 1 else word
|
| 36 |
word = 'HTTPURL' if word.startswith('http') else word
|
| 37 |
new_text.append(word)
|
| 38 |
text = " ".join(new_text)
|
| 39 |
-
|
| 40 |
text = emoji.demojize(text)
|
| 41 |
text = re.sub(r':[A-Za-z_-]+:', ' ', text)
|
| 42 |
text = re.sub(r"([xX;:]'?[dDpPvVoO3)(])", ' ', text)
|
| 43 |
text = re.sub(r'["#$%&()*+,./:;<=>\[\]\\^_`{|}~]', ' ', text)
|
| 44 |
text = self.repeatcharClean(text)
|
| 45 |
-
|
| 46 |
-
# Membersihkan spasi berlebih yang mungkin muncul setelah pembersihan
|
| 47 |
text = re.sub(r'\s+', ' ', text).strip()
|
| 48 |
return text
|
| 49 |
|
|
@@ -53,10 +45,8 @@ class SentimentPredictor:
|
|
| 53 |
self.model = model
|
| 54 |
self.device = torch.device("cpu")
|
| 55 |
self.model.to(self.device)
|
| 56 |
-
# --- [DIUBAH] --- Definisikan mapping label di sini agar mudah digunakan
|
| 57 |
self.label_mapping = {0: 'Positif', 1: 'Netral', 2: 'Negatif'}
|
| 58 |
|
| 59 |
-
# --- [DIUBAH] --- Tipe data kembalian (return type) diubah
|
| 60 |
def predict(self, text: str) -> (str, float, Dict[str, float]):
|
| 61 |
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=280)
|
| 62 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
@@ -65,18 +55,13 @@ class SentimentPredictor:
|
|
| 65 |
outputs = self.model(**inputs)
|
| 66 |
|
| 67 |
logits = outputs.logits
|
|
|
|
| 68 |
|
| 69 |
-
# Hitung probabilitas untuk semua kelas
|
| 70 |
-
probabilities = torch.softmax(logits, dim=1)[0] # Ambil hasil pertama dari batch
|
| 71 |
-
|
| 72 |
-
# Dapatkan label dan skor kepercayaan dari probabilitas tertinggi
|
| 73 |
confidence_score = probabilities.max().item()
|
| 74 |
predicted_label_id = probabilities.argmax().item()
|
| 75 |
sentiment = self.label_mapping[predicted_label_id]
|
| 76 |
|
| 77 |
-
# --- [DIUBAH] --- Buat dictionary untuk semua skor probabilitas
|
| 78 |
all_scores = {self.label_mapping[i]: prob.item() for i, prob in enumerate(probabilities)}
|
| 79 |
-
|
| 80 |
return sentiment, confidence_score, all_scores
|
| 81 |
|
| 82 |
# ====================================================================
|
|
@@ -107,14 +92,20 @@ app = FastAPI(
|
|
| 107 |
class TextInput(BaseModel):
    """Request body for single-text prediction."""

    text: str
|
| 109 |
|
| 110 |
-
# --- [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
class PredictionOutput(BaseModel):
    """Response schema: predicted label, its confidence, and per-label scores."""

    sentiment: str
    confidence: float
    all_scores: Dict[str, float]
|
| 115 |
|
| 116 |
# ====================================================================
|
| 117 |
-
# 4. BUAT ENDPOINT
|
| 118 |
# ====================================================================
|
| 119 |
|
| 120 |
@app.get("/")
|
|
@@ -124,13 +115,22 @@ def read_root():
|
|
| 124 |
@app.post("/predict", response_model=PredictionOutput)
def predict_sentiment(request: TextInput):
    """Clean the incoming text, run the sentiment model, and return all scores."""
    normalized = text_cleaner.clean_review(request.text)
    label, score, per_label_scores = sentiment_predictor.predict(normalized)
    return PredictionOutput(
        sentiment=label,
        confidence=score,
        all_scores=per_label_scores,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from transformers import BertTokenizer, BertForSequenceClassification
|
| 5 |
from fastapi import FastAPI
|
| 6 |
from pydantic import BaseModel
|
| 7 |
+
from typing import Dict, List
|
| 8 |
|
| 9 |
# ====================================================================
|
| 10 |
+
# 1. KELAS LOGIKA ANDA (Tidak ada perubahan)
|
| 11 |
# ====================================================================
|
| 12 |
|
| 13 |
class TextCleaner:
|
| 14 |
def __init__(self):
    """Prepare the list of characters whose long repeats get collapsed."""
    # Punctuation and symbols handled by repeatcharClean.
    punctuation = ['.', ',', ';', ':', '?', '!', '(', ')', '[', ']', '{', '}', '<', '>', '"', '/', '\'', '-', '@']
    # Every lowercase ASCII letter (a-z) is cleaned the same way.
    lowercase = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    self.character = punctuation + lowercase
|
| 17 |
|
| 18 |
def repeatcharClean(self, text):
    """Collapse runs of 3+ identical characters (for chars in self.character) to one.

    Example: 'heloooo' -> 'helo'.
    """
    # Build one character class instead of compiling one pattern per
    # character: a single pass over the text replaces the original loop of
    # ~45 separate compile+sub passes. Behavior is identical because the
    # replacement always leaves one character, so runs can never merge.
    char_class = ''.join(re.escape(c) for c in self.character)
    pattern = re.compile('([' + char_class + r'])\1{2,}')
    return pattern.sub(r'\1', text)
|
|
|
|
| 25 |
text = text.lower()
|
| 26 |
text = re.sub(r'\s+', ' ', text)
|
| 27 |
text = re.sub(r'[^\x00-\x7F]+', ' ', text)
|
|
|
|
| 28 |
new_text = []
|
| 29 |
for word in text.split(" "):
|
| 30 |
word = '@USER' if word.startswith('@') and len(word) > 1 else word
|
| 31 |
word = 'HTTPURL' if word.startswith('http') else word
|
| 32 |
new_text.append(word)
|
| 33 |
text = " ".join(new_text)
|
|
|
|
| 34 |
text = emoji.demojize(text)
|
| 35 |
text = re.sub(r':[A-Za-z_-]+:', ' ', text)
|
| 36 |
text = re.sub(r"([xX;:]'?[dDpPvVoO3)(])", ' ', text)
|
| 37 |
text = re.sub(r'["#$%&()*+,./:;<=>\[\]\\^_`{|}~]', ' ', text)
|
| 38 |
text = self.repeatcharClean(text)
|
|
|
|
|
|
|
| 39 |
text = re.sub(r'\s+', ' ', text).strip()
|
| 40 |
return text
|
| 41 |
|
|
|
|
| 45 |
self.model = model
|
| 46 |
self.device = torch.device("cpu")
|
| 47 |
self.model.to(self.device)
|
|
|
|
| 48 |
self.label_mapping = {0: 'Positif', 1: 'Netral', 2: 'Negatif'}
|
| 49 |
|
|
|
|
| 50 |
def predict(self, text: str) -> (str, float, Dict[str, float]):
|
| 51 |
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=280)
|
| 52 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
| 55 |
outputs = self.model(**inputs)
|
| 56 |
|
| 57 |
logits = outputs.logits
|
| 58 |
+
probabilities = torch.softmax(logits, dim=1)[0]
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
confidence_score = probabilities.max().item()
|
| 61 |
predicted_label_id = probabilities.argmax().item()
|
| 62 |
sentiment = self.label_mapping[predicted_label_id]
|
| 63 |
|
|
|
|
| 64 |
all_scores = {self.label_mapping[i]: prob.item() for i, prob in enumerate(probabilities)}
|
|
|
|
| 65 |
return sentiment, confidence_score, all_scores
|
| 66 |
|
| 67 |
# ====================================================================
|
|
|
|
| 92 |
class TextInput(BaseModel):
    """Request body for single-text prediction."""

    text: str
|
| 94 |
|
| 95 |
+
# --- [FIX] --- Definition of BatchTextInput ---
# This model tells FastAPI that the batch endpoint accepts a JSON
# object with a single key "texts" holding a list of strings.
class BatchTextInput(BaseModel):
    """Request body for batch prediction: a list of raw texts."""

    texts: List[str]
# -----------------------------------------------------------
|
| 101 |
+
|
| 102 |
class PredictionOutput(BaseModel):
    """Response schema: predicted label, its confidence, and per-label scores."""

    sentiment: str
    confidence: float
    all_scores: Dict[str, float]
|
| 106 |
|
| 107 |
# ====================================================================
|
| 108 |
+
# 4. BUAT ENDPOINT (Tidak ada perubahan logika)
|
| 109 |
# ====================================================================
|
| 110 |
|
| 111 |
@app.get("/")
|
|
|
|
| 115 |
@app.post("/predict", response_model=PredictionOutput)
def predict_sentiment(request: TextInput):
    """Clean the incoming text, run the sentiment model, and return all scores."""
    normalized = text_cleaner.clean_review(request.text)
    label, score, per_label_scores = sentiment_predictor.predict(normalized)
    return PredictionOutput(
        sentiment=label,
        confidence=score,
        all_scores=per_label_scores,
    )
|
| 124 |
+
|
| 125 |
+
@app.post("/predict-batch", response_model=List[PredictionOutput])
def predict_sentiment_batch(request: BatchTextInput):
    """Run the single-text pipeline over every entry in request.texts.

    NOTE(review): each text is tokenized and classified individually;
    true batched model inference would need SentimentPredictor support.
    """
    def _predict_one(raw_text):
        # Same clean -> predict pipeline as the /predict endpoint.
        label, score, per_label = sentiment_predictor.predict(
            text_cleaner.clean_review(raw_text)
        )
        return PredictionOutput(sentiment=label, confidence=score, all_scores=per_label)

    return [_predict_one(t) for t in request.texts]
|