# NewsMind / api.py
# Uploaded to Hugging Face by Hieu18012005 with huggingface_hub (commit d3b4db0, verified).
import os

import httpx
import torch
import torch.nn as nn
import torch.nn.functional as F
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
# Load variables from a .env file into the process environment before any
# secret is read below.
load_dotenv()
# Server-side HF token used by the /chat proxy; "" means "not configured".
HF_API_KEY = os.environ.get("HF_API_KEY", "")
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
# =========================
# CONFIG
# =========================
# Directory that contains this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Locally bundled fine-tuned classifier weights (used when the file exists).
LOCAL_MODEL_PATH = os.path.join(BASE_DIR, "model", "bertpho", "model.pt")
# Model weights on the HF Hub, downloaded when no local copy is present.
HF_MODEL_REPO = "Hieu18012005/newsmind-phobert"
# Pretrained backbone / tokenizer the classifier is built on.
BASE_MODEL = "vinai/phobert-base"
# Token length used for tokenizer truncation/padding.
MAX_LEN = 512
# Chat LLM. The routed id appends the inference-provider suffix expected by the
# HF router; it is derived from HF_MODEL so the two ids cannot drift apart.
HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
HF_MODEL_ROUTED = f"{HF_MODEL}:cerebras"
# Chat-completions endpoint of the HF router.
HF_API_URL = "https://router.huggingface.co/v1/chat/completions"
def get_model_path():
    """Return the filesystem path of the fine-tuned ``model.pt``.

    The local file at ``LOCAL_MODEL_PATH`` wins when it exists. Otherwise
    ``model.pt`` is downloaded from the HF Hub repo ``HF_MODEL_REPO`` into the
    directory named by ``$HF_HOME`` (``/tmp/hf_cache`` when unset), and the
    path of the downloaded file is returned.
    """
    if not os.path.exists(LOCAL_MODEL_PATH):
        # No bundled weights: fetch them from the HuggingFace Hub.
        print(f"⬇️ Đang download model từ HF Hub: {HF_MODEL_REPO}...")
        # NOTE(review): huggingface_hub's own default cache lives under
        # $HF_HOME/hub; passing $HF_HOME itself as cache_dir uses a different
        # layout — confirm that is intended.
        downloaded = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename="model.pt",
            cache_dir=os.environ.get("HF_HOME", "/tmp/hf_cache"),
        )
        print(f"✅ Model đã download: {downloaded}")
        return downloaded

    print(f"📂 Dùng model local: {LOCAL_MODEL_PATH}")
    return LOCAL_MODEL_PATH
# Output index of the classifier -> category slug (index order = logit order).
label_map = dict(enumerate(("cong-nghe", "kinh-doanh", "the-gioi", "the-thao")))
# Category slug -> Vietnamese display name, in the same order as label_map.
label_map_vi = dict(
    zip(
        label_map.values(),
        ("Công nghệ", "Kinh doanh", "Thế giới", "Thể thao"),
    )
)
# Run inference on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# =========================
# EXPAND POSITION EMBEDDING
# =========================
def expand_position_embeddings(model, new_max_length=512):
    """Resize ``model.embeddings.position_embeddings`` to fit longer inputs.

    The existing table is linearly interpolated along the position axis to
    ``new_max_length + 2`` rows (the +2 presumably reserves the two leading
    rows RoBERTa-style position embeddings keep for padding offsets — confirm
    against the backbone). The new table replaces the old layer, and the
    ``position_ids`` / ``token_type_ids`` buffers on ``model.embeddings`` are
    re-registered with the new length. When ``model.config`` has a
    ``max_position_embeddings`` attribute it is updated to the new size.

    Args:
        model: module exposing ``embeddings.position_embeddings`` as an
            ``nn.Embedding`` (e.g. a Hugging Face PhoBERT/RoBERTa model).
        new_max_length: desired maximum number of input tokens.

    Returns:
        None. ``model`` is modified in place; nothing happens when the table
        already has ``new_max_length + 2`` rows.
    """
    embeddings = model.embeddings
    old_emb = embeddings.position_embeddings
    old_weight = old_emb.weight.data
    old_max, hidden = old_weight.shape
    new_max = new_max_length + 2
    if old_max == new_max:
        return

    # F.interpolate expects (batch, channels, length): treat every hidden unit
    # as a channel, resample along the position axis, then transpose back.
    new_weight = F.interpolate(
        old_weight.T.unsqueeze(0), size=new_max, mode="linear", align_corners=False
    ).squeeze(0).T.contiguous()

    # Carry over padding_idx so the resized layer treats the padding row the
    # same way the original layer did.
    new_emb = nn.Embedding(
        new_max, hidden, padding_idx=getattr(old_emb, "padding_idx", None)
    )
    new_emb.weight = nn.Parameter(new_weight)
    embeddings.position_embeddings = new_emb

    # Allocate the helper buffers next to the weights so this also works on a
    # model that has already been moved to a GPU.
    buffer_device = old_weight.device
    embeddings.register_buffer(
        "position_ids",
        torch.arange(new_max, device=buffer_device).expand((1, -1)),
        persistent=False,
    )
    embeddings.register_buffer(
        "token_type_ids",
        torch.zeros((1, new_max), dtype=torch.long, device=buffer_device),
        persistent=False,
    )

    # Keep the model config in sync with the real table size, if it has one.
    config = getattr(model, "config", None)
    if config is not None and hasattr(config, "max_position_embeddings"):
        config.max_position_embeddings = new_max
# =========================
# MODEL
# =========================
class PhoBERTClassifier(nn.Module):
    """Pretrained encoder + dropout + linear head for news-topic classification.

    Args:
        model_name: model id or path passed to ``AutoModel.from_pretrained``.
        num_classes: number of output logits.
        max_len: maximum input length in tokens; the encoder's position
            embeddings are resized to fit it. ``None`` (default) uses the
            module-level ``MAX_LEN`` so the model matches the tokenizer's
            truncation length.
    """

    def __init__(self, model_name, num_classes, max_len=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Resolved at call time (not in the signature) so the default follows
        # MAX_LEN instead of a separately hard-coded length.
        if max_len is None:
            max_len = MAX_LEN
        expand_position_embeddings(self.bert, max_len)
        hidden = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits for a batch of token ids."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The encoder's pooled representation feeds the classification head.
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)
# =========================
# LOAD MODEL (once, at startup)
# =========================
print("Đang load model PhoBERT...")
MODEL_PATH = get_model_path()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)

# Build the architecture first (this also resizes the position embeddings),
# then overwrite its parameters with the fine-tuned checkpoint. weights_only
# restricts torch.load to plain tensor data instead of arbitrary pickles.
model = PhoBERTClassifier(BASE_MODEL, len(label_map))
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Move to the inference device and switch to eval mode (disables dropout).
model = model.to(device).eval()
print(f"Model sẵn sàng trên {device}")
# =========================
# PREDICT
# =========================
def predict(text: str):
    """Classify a single news text with the fine-tuned model.

    Args:
        text: raw input text; it is truncated/padded to ``MAX_LEN`` tokens.

    Returns:
        ``(label, confidence, all_probs)``:
        ``label`` is the ``label_map`` slug of the most probable class,
        ``confidence`` is that probability on a 0..1 scale rounded to 4
        decimals, and ``all_probs`` maps every slug to its probability in
        percent rounded to 2 decimals.
    """
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        # Exactly one text was tokenized, so keep row 0 of the batch.
        probs = torch.softmax(logits, dim=1)[0]

    pred = int(torch.argmax(probs).item())
    # Copy all class probabilities to Python in one call instead of one
    # .item() (and one device sync) per class.
    prob_list = probs.tolist()

    label = label_map[pred]
    confidence = round(prob_list[pred], 4)
    # Iterate over the model's real number of outputs rather than a fixed 4.
    all_probs = {label_map[i]: round(p * 100, 2) for i, p in enumerate(prob_list)}
    return label, confidence, all_probs
# =========================
# FASTAPI APP
# =========================
app = FastAPI(title="NewsMind API")
# NOTE(review): wildcard CORS lets any website call /chat from a visitor's
# browser and spend the server-side HF_API_KEY; restrict allow_origins to the
# real frontend origin(s) once they are known.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class _PublicStaticFiles(StaticFiles):
    """StaticFiles that refuses hidden files and backend-only artifacts.

    The static mount exposes the application directory. Without this filter,
    a ``.env`` file there (holding HF_API_KEY), the Python sources and the
    model checkpoint would all be downloadable under ``/static/``.
    """

    # File types that belong to the server, never to the browser.
    _BLOCKED_SUFFIXES = frozenset({".py", ".pyc", ".pt", ".env"})

    async def get_response(self, path, scope):
        parts = [p for p in path.replace("\\", "/").split("/") if p]
        suffix = os.path.splitext(path)[1].lower()
        # Answer 404 (not 403) so the existence of these files is not revealed.
        if any(p.startswith(".") for p in parts) or suffix in self._BLOCKED_SUFFIXES:
            raise HTTPException(status_code=404)
        return await super().get_response(path, scope)


# Serve from the module directory, not the process CWD, so the served files do
# not depend on where the server was launched from.
app.mount("/static", _PublicStaticFiles(directory=BASE_DIR), name="static")
# =========================
# SCHEMAS
# =========================
class TextIn(BaseModel):
    """Request body for /classify_text and each item of /classify_batch."""

    text: str  # raw article text to classify


class ChatIn(BaseModel):
    """Request body for /chat."""

    # Conversation forwarded unchanged to the LLM API; presumably OpenAI-style
    # {"role": ..., "content": ...} dicts — their contents are not validated.
    messages: list[dict]
    # Optional system prompt; /chat prepends it as a system message when non-empty.
    system: str = ""
# =========================
# ROUTES
# =========================
@app.get("/")
def root():
    """Serve the frontend page ``index.html`` located next to this module."""
    # Resolve against BASE_DIR rather than the process CWD so the page is found
    # regardless of the directory the server was started from.
    return FileResponse(os.path.join(BASE_DIR, "index.html"))
@app.get("/health")
def health():
    """Liveness check: report that the API is up and the torch device in use."""
    return dict(status="ok", device=str(device))
@app.get("/config")
def config():
    """Tell the client whether an HF API key is configured on the server.

    The key itself is never sent to the client — only whether it is present.
    """
    key_configured = True if HF_API_KEY else False
    return {"has_hf_key": key_configured}
@app.post("/classify_text")
def classify(body: TextIn):
    """Classify one text and attach simple reading statistics.

    Blank input yields ``{"error": "Text rỗng"}`` ("empty text"). Otherwise the
    response holds the predicted slug and Vietnamese name, confidence and its
    tier (HIGH >= 0.80, MED >= 0.55, else LOW), per-class percentages, word
    count, reading time in minutes (200 words/minute, at least 1) and an
    extract built from the first two sentences longer than 20 characters.
    """
    raw = body.text
    if not raw.strip():
        return {"error": "Text rỗng"}

    label, confidence, all_probs = predict(raw)
    n_words = len(raw.strip().split())

    # First threshold that the confidence reaches decides the tier.
    tier = "LOW"
    for threshold, name in ((0.80, "HIGH"), (0.55, "MED")):
        if confidence >= threshold:
            tier = name
            break

    # "!" and "?" also end sentences; very short fragments are dropped.
    normalized = raw.translate(str.maketrans("!?", ".."))
    long_sentences = [
        chunk.strip() for chunk in normalized.split(".") if len(chunk.strip()) > 20
    ]
    extract = ". ".join(long_sentences[:2])
    if long_sentences:
        extract += "."

    return {
        "label": label,
        "label_vi": label_map_vi[label],
        "confidence": confidence,
        "confidence_tier": tier,
        "all_probs": all_probs,
        "word_count": n_words,
        "read_time": max(1, round(n_words / 200)),
        "top_sentences": extract,
    }
@app.post("/classify_batch")
def classify_batch(items: list[TextIn]):
    """Classify several texts in one request.

    Returns one dict per input, in input order, tagged with its ``index``.
    Blank texts give ``{"index": i, "error": "Text rỗng"}``; the others carry
    label, confidence, tier, per-class percentages, word count and a preview
    of the first 80 characters.
    """

    def _classify_one(idx, text):
        # Build the result entry for a single input text.
        if not text.strip():
            return {"index": idx, "error": "Text rỗng"}
        label, confidence, all_probs = predict(text)
        if confidence >= 0.80:
            tier = "HIGH"
        elif confidence >= 0.55:
            tier = "MED"
        else:
            tier = "LOW"
        preview = text[:80]
        if len(text) > 80:
            preview += "..."
        return {
            "index": idx,
            "label": label,
            "label_vi": label_map_vi[label],
            "confidence": confidence,
            "confidence_tier": tier,
            "all_probs": all_probs,
            "word_count": len(text.strip().split()),
            "preview": preview,
        }

    return [_classify_one(idx, item.text) for idx, item in enumerate(items)]
def _hf_error_message(response):
    """Extract a readable error message from a failed HF router response.

    Prefers ``{"error": {"message": ...}}``, then ``{"error": "..."}``, then
    the whole JSON body as text; if the body is not JSON, returns its first
    300 characters.
    """
    try:
        err = response.json()
    except Exception:
        return response.text[:300]
    if not isinstance(err, dict):
        return str(err)
    detail = err.get("error", {})
    if isinstance(detail, dict):
        return detail.get("message", str(err))
    return str(detail)


@app.post("/chat")
async def chat(body: ChatIn):
    """Proxy a chat-completion call through the backend.

    Keeps the HF API key on the server and avoids browser CORS problems.
    Returns ``{"reply": str}`` on success and ``{"error": str}`` otherwise;
    errors are reported in the JSON body, not through the HTTP status code.
    """
    if not HF_API_KEY:
        return {"error": "Chưa có HF_API_KEY trong file .env"}

    msgs = [{"role": "system", "content": body.system}] if body.system else []
    msgs.extend(body.messages)

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            r = await client.post(
                HF_API_URL,
                headers={
                    "Authorization": f"Bearer {HF_API_KEY}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": HF_MODEL_ROUTED,
                    "messages": msgs,
                    "max_tokens": 500,
                    "temperature": 0.65,
                },
            )
            if not r.is_success:
                return {"error": f"HF API {r.status_code}: {_hf_error_message(r)}"}
            try:
                data = r.json()
                reply = data["choices"][0]["message"]["content"].strip()
            except Exception as e:
                return {"error": f"Parse lỗi response HF: {e} | Raw: {r.text[:200]}"}
            return {"reply": reply}
    except httpx.TimeoutException:
        return {"error": "Timeout khi gọi HF API (>30s). Thử lại sau."}
    except Exception as e:
        # Some exceptions (e.g. certain httpx transport errors) have an empty
        # message; fall back to the exception type so "error" is never blank.
        return {"error": str(e) or type(e).__name__}