"""Vietnamese toxic / hate-speech comment classification API.

Serves the ``tarudesu/ViHateT5-base-HSD`` seq2seq model behind a small
FastAPI app:

* ``GET /``        -- returns a JSON description of how to call the API.
* ``POST /predict`` -- runs the model on one comment and returns its label.
"""

import warnings

import torch  # noqa: F401  (kept from the original file; not referenced directly)
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Silence miscellaneous library warnings. Note this filter is process-wide.
warnings.filterwarnings("ignore")

# Load the tokenizer and model once at import time; every request reuses them.
MODEL_NAME = "tarudesu/ViHateT5-base-HSD"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Create the FastAPI application.
app = FastAPI(title="Vietnamese Toxic Comment Detection API")

# Enable CORS for every origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any site issue credentialed cross-origin requests. This API uses no cookies
# or auth, so allow_credentials=False is probably sufficient -- confirm no
# client relies on sending credentials before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class CommentInput(BaseModel):
    """Request body for ``POST /predict``."""

    # The comment to classify.
    text: str
    # Task prefix prepended to the comment before it is fed to the model
    # (e.g. "hate-speech-detection" or "toxic-speech-detection").
    prefix: str = "hate-speech-detection"


def predict_vihatet5(
    comment: str,
    prefix: str = "hate-speech-detection",
    max_input_length: int = 128,
    max_output_length: int = 50,
) -> str:
    """Run the ViHateT5 model on one comment and return its decoded output.

    Args:
        comment: Text to classify.
        prefix: Task prefix; the model input is ``"<prefix>: <comment>"``.
        max_input_length: Token limit for the tokenized input. Longer inputs
            are truncated.
        max_output_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The generated sequence decoded to text with special tokens removed.
    """
    input_text = f"{prefix}: {comment}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_input_length,
    )
    output_ids = model.generate(**inputs, max_length=max_output_length)
    # A single input was encoded, so decode the first (only) sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


@app.get("/")
async def root():
    """Describe how to call the prediction endpoint."""
    return {
        "endpoints": {
            "POST https://haiss123-check-comment.hf.space/predict": {
                # Must match the CommentInput schema (field is "text").
                "Request body": {
                    "text": "string",
                    "prefix": "hate-speech-detection or toxic-speech-detection",
                },
                "Response": {
                    "input": "string",
                    "prefix": "hate-speech-detection or toxic-speech-detection",
                    "prediction": "offensive,hate,...",
                },
            }
        }
    }


@app.post("/predict")
def predict(input_data: CommentInput):
    """Classify one comment and echo the input alongside the prediction.

    Declared with plain ``def`` (not ``async def``). FastAPI therefore runs it
    in its worker threadpool, so blocking model inference does not stall the
    event loop.
    """
    result = predict_vihatet5(input_data.text, prefix=input_data.prefix)
    return {
        "input": input_data.text,
        "prefix": input_data.prefix,
        "prediction": result,
    }


# Run the app directly. Port 7860 is presumably chosen for the Hugging Face
# Spaces host referenced in the root docs -- confirm before changing.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)