Spaces:
Sleeping
Sleeping
Upload api.py with huggingface_hub
Browse files
api.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_phobert_api.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
import torch
|
| 7 |
+
import re
|
| 8 |
+
import json
|
| 9 |
+
import emoji
|
| 10 |
+
from underthesea import word_tokenize
|
| 11 |
+
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
|
| 12 |
+
|
| 13 |
+
# Create the FastAPI application.
# The title/description/version strings are passed to FastAPI as application
# metadata, so they are kept exactly as written.
app = FastAPI(
    version="1.0",
    title="PhoBERT Emotion Classification API",
    description="API dự đoán cảm xúc của câu tiếng Việt sử dụng PhoBERT.",
)
|
| 19 |
+
|
| 20 |
+
###############################################################################
|
| 21 |
+
# TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN
|
| 22 |
+
###############################################################################
|
| 23 |
+
# Maps a single emoji character to the bracketed tag that replaces it in the
# text. The table is declared per tag and flattened into the
# {emoji: tag} dictionary that replace_emojis() looks characters up in.
emoji_mapping = {
    symbol: tag
    for tag, symbols in (
        ("[joy]", ("😀", "😃", "😄", "😁", "😆", "😅", "😂", "🤣")),
        ("[love]", ("🙂", "🙃", "😉", "😊", "😇", "🥰", "😍",
                    "🤩", "😘", "😗", "☺", "😚", "😙")),
        ("[satisfaction]", ("😋", "😛", "😜", "🤪", "😝", "🤑")),
        ("[neutral]", ("🤐", "🤨", "😐", "😑", "😶")),
        ("[sarcasm]", ("😏",)),
        ("[disappointment]", ("😒", "🙄", "😬")),
        ("[sadness]", ("😔", "😪", "😢", "😭", "😥", "😓")),
        ("[tiredness]", ("😩", "😫", "🥱")),
        ("[discomfort]", ("🤤", "🤢", "🤮", "🤧", "🥵",
                          "🥶", "🥴", "😵", "🤯")),
        ("[confused]", ("😕", "😟", "🙁", "☹")),
        ("[surprise]", ("😮", "😯", "😲", "😳")),
        ("[pleading]", ("🥺",)),
        ("[fear]", ("😦", "😧", "😨", "😰", "😱")),
        ("[confusion]", ("😖", "😣", "😞")),
        ("[anger]", ("😤", "😡", "😠", "🤬")),
        ("[mischievous]", ("😈", "👿")),
    )
    for symbol in symbols
}
|
| 42 |
+
|
| 43 |
+
###############################################################################
|
| 44 |
+
# HÀM XỬ LÝ (COPY TỪ FILE TRAIN)
|
| 45 |
+
###############################################################################
|
| 46 |
+
def replace_emojis(sentence, emoji_mapping):
    """Replace mapped emoji with their tags and drop every other emoji.

    The sentence is processed one character at a time:

    * a character found in ``emoji_mapping`` is replaced by its mapped value;
    * any other character that ``emoji.is_emoji`` reports as an emoji is
      removed;
    * everything else is kept unchanged.

    Args:
        sentence: Text to process.
        emoji_mapping: Mapping from a single character to its replacement.

    Returns:
        The processed text. Replacements are concatenated without any
        separator being inserted around them.
    """
    return ''.join(
        emoji_mapping.get(symbol, symbol)
        for symbol in sentence
        # The mapping is consulted first, so mapped characters never reach
        # the emoji.is_emoji() check.
        if symbol in emoji_mapping or not emoji.is_emoji(symbol)
    )
|
| 54 |
+
|
| 55 |
+
def remove_profanity(sentence):
    """Drop whitespace-separated tokens that appear in the profanity list.

    Tokens are compared case-insensitively and must match a list entry
    exactly (a token with attached punctuation is kept). The surviving
    tokens are re-joined with single spaces.

    Args:
        sentence: Text to filter.

    Returns:
        The text without the listed tokens.
    """
    # NOTE(review): "đù mé" contains a space, so it can never equal a single
    # whitespace-split token and is effectively inactive. Left untouched to
    # keep this function identical to the training-time preprocessing.
    blocklist = {
        "loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi",
    }
    return ' '.join(
        token for token in sentence.split() if token.lower() not in blocklist
    )
|
| 60 |
+
|
| 61 |
+
def remove_special_characters(sentence):
    """Delete a fixed set of symbol characters from the text.

    The removed characters are: ^ * @ # & $ % < > ~ { } | and backslash.
    Square brackets are not in the set.

    Args:
        sentence: Text to clean.

    Returns:
        The text with every occurrence of those characters removed.
    """
    symbol_pattern = re.compile(r"[\^\*@#&$%<>~{}|\\]")
    return symbol_pattern.sub("", sentence)
|
| 63 |
+
|
| 64 |
+
def normalize_whitespace(sentence):
    """Collapse every run of whitespace into one space and trim both ends.

    Args:
        sentence: Text to normalise.

    Returns:
        The whitespace-separated pieces of ``sentence`` joined by single
        spaces (an empty string when there are none).
    """
    pieces = sentence.split()
    return ' '.join(pieces)
|
| 66 |
+
|
| 67 |
+
def remove_repeated_characters(sentence):
    """Shrink any character repeated three or more times in a row to one.

    Runs of exactly two identical characters are left as they are.

    Args:
        sentence: Text to process.

    Returns:
        The text with long character runs collapsed.
    """
    # One captured character followed by at least two more copies of itself.
    long_run = re.compile(r"(.)\1{2,}")
    return long_run.sub(r"\1", sentence)
|
| 69 |
+
|
| 70 |
+
def replace_numbers(sentence):
    """Replace every run of digits with the literal token ``[number]``.

    Args:
        sentence: Text to process.

    Returns:
        The text with each maximal digit sequence substituted.
    """
    digit_run = re.compile(r"\d+")
    return digit_run.sub("[number]", sentence)
|
| 72 |
+
|
| 73 |
+
def tokenize_underthesea(sentence):
    """Segment the text with underthesea and return it as one string.

    Args:
        sentence: Text to segment.

    Returns:
        The tokens produced by ``word_tokenize``, joined by single spaces.
    """
    # NOTE(review): tokens are joined with plain spaces, so a token that
    # itself contains a space cannot be told apart from separate tokens in
    # the result - confirm this matches the training pipeline.
    return " ".join(word_tokenize(sentence))
|
| 76 |
+
|
| 77 |
+
# Load the optional abbreviation table from abbreviations.json.
# When the file does not exist the table is simply left empty. Any other
# failure (unreadable file, malformed JSON, ...) also falls back to an empty
# table so the service still starts, but it is reported instead of being
# silently ignored.
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except FileNotFoundError:
    # The file is optional: no file means no abbreviation expansion.
    abbreviations = {}
except Exception as exc:
    print(f"Warning: could not load abbreviations.json ({exc!r}); "
          "abbreviation expansion is disabled.")
    abbreviations = {}
else:
    # The table is used as a mapping keyed by word; any other JSON root
    # (list, string, number, ...) would only fail later, at request time.
    if not isinstance(abbreviations, dict):
        print("Warning: abbreviations.json does not contain a JSON object; "
              "abbreviation expansion is disabled.")
        abbreviations = {}
|
| 83 |
+
|
| 84 |
+
def preprocess_sentence(sentence):
    """Run the full text-cleaning pipeline on one sentence.

    Steps, in order: lowercase, emoji replacement, profanity removal,
    special-character removal, whitespace normalisation, abbreviation
    expansion, repeated-character collapsing, number replacement and word
    segmentation.

    Args:
        sentence: Raw input text.

    Returns:
        The cleaned, segmented text.
    """
    text = sentence.lower()
    text = replace_emojis(text, emoji_mapping)
    for cleaner in (remove_profanity, remove_special_characters, normalize_whitespace):
        text = cleaner(text)

    # Expand abbreviations word by word when the word is a key of the table.
    # assumes each abbreviation maps to a list of words - TODO confirm; a
    # plain string value would be spread into space-separated characters by
    # the join below.
    expanded = [
        " ".join(abbreviations[token]) if token in abbreviations else token
        for token in text.split()
    ]
    text = " ".join(expanded)

    for finisher in (remove_repeated_characters, replace_numbers, tokenize_underthesea):
        text = finisher(text)
    return text
|
| 103 |
+
|
| 104 |
+
###############################################################################
|
| 105 |
+
# LOAD CHECKPOINT
|
| 106 |
+
###############################################################################
|
| 107 |
+
# Directory the config, tokenizer and model are all loaded from.
checkpoint_dir = "./checkpoint"
# Run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)

# Fallback id -> label mapping, in the order provided for this model.
custom_id2label = dict(enumerate((
    'Anger',
    'Disgust',
    'Enjoyment',
    'Fear',
    'Other',
    'Sadness',
    'Surprise',
)))

# Use the labels stored in the config only when they exist and are not all
# "LABEL_"-prefixed names (presumably auto-generated placeholders); otherwise
# fall back to the custom mapping above.
config_labels = getattr(config, "id2label", None)
if not config_labels or all(name.startswith("LABEL_") for name in config_labels.values()):
    id2label = custom_id2label
else:
    # Keys are converted to int so predicted class indices can be looked up.
    id2label = {int(idx): name for idx, name in config_labels.items()}

print("id2label loaded:", id2label)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
# Move the model to the selected device and switch it to evaluation mode.
model.to(device)
model.eval()
|
| 141 |
+
|
| 142 |
+
###############################################################################
|
| 143 |
+
# HÀM PREDICT
|
| 144 |
+
###############################################################################
|
| 145 |
+
# Message appended to the prediction text for each emotion label.
# The strings are returned to API clients, so they are kept verbatim.
label2message = dict(
    Anger='Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
    Disgust='Hãy tránh xa những thứ khiến bạn không thích.',
    Enjoyment='Chúc mừng bạn có một ngày tuyệt vời!',
    Fear='Hãy đối mặt với nỗi sợ để vượt qua chúng.',
    Other='Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
    Sadness='Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
    Surprise='Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.',
)
|
| 154 |
+
|
| 155 |
+
def predict_text(text: str) -> str:
    """Classify the emotion of one sentence and describe it in a message.

    The text is preprocessed, tokenized (truncated to 256 tokens), passed
    through the model without gradient tracking, and the index of the
    largest logit is taken as the predicted class.

    Args:
        text: Raw input sentence.

    Returns:
        A message naming the predicted label, followed by the label's advice
        message when one is defined; or an "unknown label" message when the
        predicted index is not present in ``id2label``.
    """
    cleaned = preprocess_sentence(text)
    encoded = tokenizer(
        [cleaned],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits
        pred_id = logits.argmax(dim=-1).item()

    # Guard: the model produced a class index with no known label.
    if pred_id not in id2label:
        return f"Nhãn không xác định (id={pred_id})"

    label = id2label[pred_id]
    message = label2message.get(label, "")
    if not message:
        return f"Dự đoán cảm xúc: {label}."
    return f"Dự đoán cảm xúc: {label}. {message}"
|
| 178 |
+
|
| 179 |
+
###############################################################################
|
| 180 |
+
# ĐỊNH NGHĨA MODEL INPUT
|
| 181 |
+
###############################################################################
|
| 182 |
+
# Request body accepted by the /predict endpoint.
# Documented with comments rather than a class docstring so the generated
# schema is not affected.
class InputText(BaseModel):
    # The sentence whose emotion should be predicted.
    text: str
|
| 184 |
+
|
| 185 |
+
###############################################################################
|
| 186 |
+
# API ENDPOINT
|
| 187 |
+
###############################################################################
|
| 188 |
+
# POST /predict: takes one Vietnamese sentence and returns the emotion
# prediction message under the "result" key.
# The docstring below is left in its original wording because FastAPI uses a
# handler's docstring as the operation description in the generated API docs.
@app.post("/predict")
def predict(input_text: InputText):
    """
    Nhận một câu tiếng Việt và trả về dự đoán cảm xúc.
    """
    return {"result": predict_text(input_text.text)}
|
| 195 |
+
|
| 196 |
+
###############################################################################
|
| 197 |
+
# CHẠY API SERVER
|
| 198 |
+
###############################################################################
|
| 199 |
+
if __name__ == "__main__":
    # Start the API server only when this file is executed as a script.
    # uvicorn is imported here so importing the module does not require it.
    import uvicorn

    # Listen on every network interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
|