import os
import re
# Set cache env vars in case the runtime overrides them (must match the Dockerfile)
os.environ.setdefault("HF_HOME", "/data/hf")
os.environ.setdefault("HF_HUB_CACHE", "/data/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/hf/transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")
import math, torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MODEL_ID = os.getenv("MODEL_ID", "DungSon/ViHateT5-base-HSD-Clone")
HF_TOKEN = os.getenv("HF_TOKEN", None)  # if the model repo is private, add this secret under Settings → Repository secrets
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, token=HF_TOKEN)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
torch.set_num_threads(1)
app = FastAPI(title="ViHateT5 API")
class Item(BaseModel):
    text: str
def score_labels(enc, labels):
    """Score each candidate label by teacher-forcing its tokens through the decoder
    and summing the log-probabilities."""
    out = {}
    for lab in labels:
        ids = tok(lab, add_special_tokens=False).input_ids
        # Start decoding from the model's decoder start token
        dec = torch.tensor([[model.config.decoder_start_token_id]], device=enc.input_ids.device)
        logp = 0.0
        with torch.no_grad():
            for t in ids:
                # Log-probability of the next label token given the input and the tokens decoded so far
                logits = model(**enc, decoder_input_ids=dec).logits[:, -1, :]
                logp += torch.log_softmax(logits, dim=-1)[0, t].item()
                dec = torch.cat([dec, torch.tensor([[t]], device=dec.device)], dim=1)
        out[lab] = logp
    return out
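
# Example (hypothetical values; real log-scores depend on the checkpoint and input):
#   score_labels(enc, ["NONE", "TOXIC"]) -> {"NONE": -0.4, "TOXIC": -2.9}
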
def softmax_logs(d):
    """Convert a dict of log-scores into normalized probabilities (numerically stable softmax)."""
    m = max(d.values())
    ex = {k: math.exp(v - m) for k, v in d.items()}
    Z = sum(ex.values())
    return {k: ex[k] / Z for k in ex}
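
# For example (values rounded; a sanity check on the math, not model output):
#   softmax_logs({"NONE": -1.0, "TOXIC": -2.0}) -> {"NONE": 0.731, "TOXIC": 0.269}
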
def generate_text(prompt: str, max_new_tokens: int = 64):
    # If the model expects a task prefix, add it here, e.g.:
    # prompt = f"hate-spans-detection: {prompt}"
    enc = tok(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            num_beams=4,
            do_sample=False,
            early_stopping=True,
        )
    return tok.decode(out[0], skip_special_tokens=True)
def extract_hate_spans(output_text: str):
    # Supports both output formats: [hate]... [hate] and [hate]...[/hate]
    spans = []
    # Format 1: [hate]... [hate]
    spans += re.findall(r"\[hate\](.*?)\[hate\]", output_text, flags=re.IGNORECASE | re.DOTALL)
    # Format 2: [hate]...[/hate]
    spans += re.findall(r"\[hate\](.*?)\[/hate\]", output_text, flags=re.IGNORECASE | re.DOTALL)
    # Clean up whitespace and drop empty matches
    spans = [s.strip() for s in spans if s.strip()]
    return spans
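
# For example (hypothetical model outputs):
#   extract_hate_spans("this is [hate] bad phrase [hate] here")  -> ["bad phrase"]
#   extract_hate_spans("this is [hate]bad phrase[/hate] here")   -> ["bad phrase"]
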
@app.get("/health")
def health():
    return {"status": "ok", "device": str(device)}
@app.post("/predict")
def predict(item: Item):
    text = item.text.strip()
    enc = tok(text, return_tensors="pt", truncation=True, max_length=512).to(device)

    # Binary toxicity classification via label scoring
    tox_labels = ["NONE", "TOXIC"]
    tox_probs = softmax_logs(score_labels(enc, tox_labels))
    tox_label = max(tox_probs, key=tox_probs.get)

    # Three-way hate-speech classification via label scoring
    hsd_labels = ["CLEAN", "OFFENSIVE", "HATE"]
    hsd_probs = softmax_logs(score_labels(enc, hsd_labels))
    hsd_label = max(hsd_probs, key=hsd_probs.get)

    # Hate-span extraction via free-form generation
    span_prompt = text
    gen = generate_text(span_prompt, max_new_tokens=64)
    spans = extract_hate_spans(gen)

    return {
        "toxic-speech-detection": {"label": tox_label, "probs": tox_probs},
        "hate-speech-detection": {"label": hsd_label, "probs": hsd_probs},
        "hate-spans-detection": {
            "spans": spans if spans else [],
            "raw": gen,  # keep the raw output so you can debug the format
        },
    }
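
# Usage sketch (assumptions: this file is saved as app.py, uvicorn is installed, and the
# service listens on port 7860 as in a typical HF Spaces Docker setup; adjust to your Dockerfile):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -s http://localhost:7860/health
#   curl -s -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "text to classify"}'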