"""FastAPI inference service for ViHateT5: toxicity, hate-speech, and hate-span tasks."""

import os
import re

# Pin HF cache locations in case the runtime overrides them (matches the Dockerfile).
os.environ.setdefault("HF_HOME", "/data/hf")
os.environ.setdefault("HF_HUB_CACHE", "/data/hf/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/hf/transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/data/.cache")

import math

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_ID = os.getenv("MODEL_ID", "DungSon/ViHateT5-base-HSD-Clone")
# If the model repo is private, add this secret in Settings → Repository secrets.
HF_TOKEN = os.getenv("HF_TOKEN", None)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, token=HF_TOKEN)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
torch.set_num_threads(1)

app = FastAPI(title="ViHateT5 API")


class Item(BaseModel):
    # Raw input text to classify.
    text: str


def score_labels(enc, labels):
    """Score each candidate label string by its generation log-probability.

    Args:
        enc: tokenizer output (BatchEncoding with ``input_ids`` / ``attention_mask``)
            already moved to the model's device.
        labels: iterable of label strings (e.g. ``["NONE", "TOXIC"]``).

    Returns:
        dict mapping each label to the summed log-probability of the model
        emitting that label's tokens given the input.

    Each label is scored with a single teacher-forced forward pass: the
    decoder input is ``[start, t0, ..., t_{k-2}]`` and we sum the log-softmax
    probability of each target token at its position. This is mathematically
    identical to decoding token-by-token (the previous implementation) but
    issues one model call per label instead of one per token.
    """
    out = {}
    start_id = model.config.decoder_start_token_id
    for lab in labels:
        ids = tok(lab, add_special_tokens=False).input_ids
        # Teacher-forced decoder input: start token followed by all but the last target.
        dec = torch.tensor([[start_id] + ids[:-1]], device=enc.input_ids.device)
        with torch.no_grad():
            logits = model(**enc, decoder_input_ids=dec).logits  # (1, len(ids), vocab)
            logprobs = torch.log_softmax(logits, dim=-1)
        # Position i of the decoder output predicts target token t_i.
        out[lab] = sum(logprobs[0, i, t].item() for i, t in enumerate(ids))
    return out


def softmax_logs(d):
    """Normalize a dict of log-scores into a probability distribution.

    Subtracts the max log-score before exponentiating for numerical stability.
    """
    m = max(d.values())
    ex = {k: math.exp(v - m) for k, v in d.items()}
    z = sum(ex.values())
    return {k: v / z for k, v in ex.items()}


def generate_text(prompt: str, max_new_tokens: int = 64):
    """Generate model output for *prompt* with deterministic beam search.

    If the model requires a task prefix, add it here, e.g.:
        prompt = f"hate-spans-detection: {prompt}"
    """
    enc = tok(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            num_beams=4,
            do_sample=False,
            early_stopping=True,
        )
    return tok.decode(out[0], skip_special_tokens=True)


def extract_hate_spans(output_text: str):
    """Extract hate spans from the model's generated text.

    Supports both observed output formats:
      1. ``[hate]...[/hate]`` — explicit closing tag (tried first);
      2. ``[hate]...[hate]``  — the same tag used to open and close.

    The closing-tag format is tried first and the paired-tag pattern is used
    only as a fallback. Running both patterns unconditionally (as before)
    produced garbled duplicates such as ``"a[/hate]"`` whenever the output
    used ``[/hate]`` with more than one span, because the lazy paired-tag
    regex matched across a closing tag to the next opening tag.
    """
    flags = re.IGNORECASE | re.DOTALL
    spans = re.findall(r"\[hate\](.*?)\[/hate\]", output_text, flags=flags)
    if not spans:
        spans = re.findall(r"\[hate\](.*?)\[hate\]", output_text, flags=flags)
    # Drop empty/whitespace-only captures.
    return [s.strip() for s in spans if s.strip()]


@app.get("/health")
def health():
    """Liveness probe; also reports which device the model runs on."""
    return {"status": "ok", "device": str(device)}


@app.post("/predict")
def predict(item: Item):
    """Run all three ViHateT5 tasks on one input text and return their results."""
    text = item.text.strip()
    enc = tok(text, return_tensors="pt", truncation=True, max_length=512).to(device)

    # Binary toxicity: rank the two label strings by generation likelihood.
    tox_probs = softmax_logs(score_labels(enc, ["NONE", "TOXIC"]))
    tox_label = max(tox_probs, key=tox_probs.get)

    # Three-way hate-speech classification, same label-scoring scheme.
    hsd_probs = softmax_logs(score_labels(enc, ["CLEAN", "OFFENSIVE", "HATE"]))
    hsd_label = max(hsd_probs, key=hsd_probs.get)

    # Span detection via free generation over the raw text.
    gen = generate_text(text, max_new_tokens=64)
    spans = extract_hate_spans(gen)

    return {
        "toxic-speech-detection": {"label": tox_label, "probs": tox_probs},
        "hate-speech-detection": {"label": hsd_label, "probs": hsd_probs},
        "hate-spans-detection": {
            "spans": spans if spans else [],
            "raw": gen,  # keep the raw output so the span format can be debugged
        },
    }