# Check_Comment / app.py
# (Hugging Face Spaces file — uploaded by Haiss123, commit 35275d1 "Update app.py", 1.64 kB)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import warnings
import uvicorn
# Silence noisy library warnings (transformers/torch emit many deprecation notices).
warnings.filterwarnings("ignore")
# Load the fine-tuned ViHateT5 model for Vietnamese hate-speech detection.
# NOTE: this downloads/loads at import time, so app startup blocks until done.
MODEL_NAME = "tarudesu/ViHateT5-base-HSD"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Create the FastAPI application.
app = FastAPI(title="Vietnamese Toxic Comment Detection API")
# Enable CORS for all domains.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentials require an explicit
# origin) — confirm whether credentialed requests are actually needed.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],  # allow every origin
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request schema for the /predict endpoint.
class CommentInput(BaseModel):
    """Request body: the comment to classify plus an optional task prefix."""
    # Raw comment text to classify (presumably Vietnamese — matches the model's training data).
    text: str
    # T5 task prefix prepended to the input; defaults to hate-speech detection.
    prefix: str = "hate-speech-detection"
# Inference helper
def predict_vihatet5(comment: str, prefix: str = "hate-speech-detection"):
    """Run the ViHateT5 model on a single comment.

    The model expects input formatted as "<prefix>: <text>"; the decoded
    generation output is returned as the prediction string.
    """
    # T5-style task formatting: prefix, colon, then the raw comment.
    model_input = f"{prefix}: {comment}"
    encoded = tokenizer(
        model_input,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    # Generate and decode the single output sequence, dropping special tokens.
    generated = model.generate(**encoded, max_length=50)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# API route
@app.post("/predict")
def predict(input_data: CommentInput):
    """Classify one comment and echo the request fields back with the prediction."""
    text, task = input_data.text, input_data.prefix
    return {
        "input": text,
        "prefix": task,
        "prediction": predict_vihatet5(text, prefix=task),
    }
# Run the app directly (port 7860 is the standard Hugging Face Spaces port).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)