Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import httpx | |
| from transformers import AutoTokenizer, AutoModel | |
| from fastapi import FastAPI | |
| from dotenv import load_dotenv | |
| from huggingface_hub import hf_hub_download | |
| load_dotenv() | |
| HF_API_KEY = os.getenv("HF_API_KEY", "") | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| # ========================= | |
| # CONFIG | |
| # ========================= | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| LOCAL_MODEL_PATH = os.path.join(BASE_DIR, "model", "bertpho", "model.pt") | |
| HF_MODEL_REPO = "Hieu18012005/newsmind-phobert" # Model weights trên HF Hub | |
| BASE_MODEL = "vinai/phobert-base" | |
| MAX_LEN = 512 | |
| HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct" | |
| HF_MODEL_ROUTED = "meta-llama/Llama-3.1-8B-Instruct:cerebras" | |
| HF_API_URL = "https://router.huggingface.co/v1/chat/completions" | |
| def get_model_path(): | |
| """Tìm model.pt: local trước, nếu không có thì download từ HF Hub.""" | |
| if os.path.exists(LOCAL_MODEL_PATH): | |
| print(f"📂 Dùng model local: {LOCAL_MODEL_PATH}") | |
| return LOCAL_MODEL_PATH | |
| # Auto-download từ HuggingFace Hub | |
| print(f"⬇️ Đang download model từ HF Hub: {HF_MODEL_REPO}...") | |
| cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache") | |
| path = hf_hub_download( | |
| repo_id=HF_MODEL_REPO, | |
| filename="model.pt", | |
| cache_dir=cache_dir, | |
| ) | |
| print(f"✅ Model đã download: {path}") | |
| return path | |
| label_map = {0: "cong-nghe", 1: "kinh-doanh", 2: "the-gioi", 3: "the-thao"} | |
| label_map_vi = { | |
| "cong-nghe": "Công nghệ", | |
| "kinh-doanh": "Kinh doanh", | |
| "the-gioi": "Thế giới", | |
| "the-thao": "Thể thao", | |
| } | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # ========================= | |
| # EXPAND POSITION EMBEDDING | |
| # ========================= | |
| def expand_position_embeddings(model, new_max_length=512): | |
| pos_emb_layer = model.embeddings.position_embeddings | |
| old_weight = pos_emb_layer.weight.data | |
| old_max, hidden = old_weight.shape | |
| new_max = new_max_length + 2 | |
| if old_max == new_max: | |
| return | |
| new_weight = F.interpolate( | |
| old_weight.T.unsqueeze(0), size=new_max, mode="linear", align_corners=False | |
| ).squeeze(0).T | |
| new_emb = nn.Embedding(new_max, hidden) | |
| new_emb.weight = nn.Parameter(new_weight) | |
| model.embeddings.position_embeddings = new_emb | |
| model.embeddings.register_buffer( | |
| "position_ids", torch.arange(new_max).expand((1, -1)), persistent=False | |
| ) | |
| model.embeddings.register_buffer( | |
| "token_type_ids", torch.zeros((1, new_max), dtype=torch.long), persistent=False | |
| ) | |
| # ========================= | |
| # MODEL | |
| # ========================= | |
| class PhoBERTClassifier(nn.Module): | |
| def __init__(self, model_name, num_classes): | |
| super().__init__() | |
| self.bert = AutoModel.from_pretrained(model_name) | |
| expand_position_embeddings(self.bert, 512) | |
| hidden = self.bert.config.hidden_size | |
| self.dropout = nn.Dropout(0.1) | |
| self.classifier = nn.Linear(hidden, num_classes) | |
| def forward(self, input_ids, attention_mask): | |
| outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
| pooled = outputs.pooler_output | |
| pooled = self.dropout(pooled) | |
| return self.classifier(pooled) | |
| # ========================= | |
| # LOAD MODEL (1 lần khi khởi động) | |
| # ========================= | |
| print("Đang load model PhoBERT...") | |
| MODEL_PATH = get_model_path() | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False) | |
| model = PhoBERTClassifier(BASE_MODEL, 4) | |
| state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True) | |
| model.load_state_dict(state_dict) | |
| model.to(device) | |
| model.eval() | |
| print(f"Model sẵn sàng trên {device}") | |
| # ========================= | |
| # PREDICT | |
| # ========================= | |
| def predict(text: str): | |
| inputs = tokenizer( | |
| text, | |
| truncation=True, | |
| padding="max_length", | |
| max_length=MAX_LEN, | |
| return_tensors="pt", | |
| ) | |
| input_ids = inputs["input_ids"].to(device) | |
| attention_mask = inputs["attention_mask"].to(device) | |
| with torch.no_grad(): | |
| logits = model(input_ids, attention_mask) | |
| probs = torch.softmax(logits, dim=1) | |
| pred = torch.argmax(probs, dim=1).item() | |
| label = label_map[pred] | |
| confidence = round(probs[0][pred].item(), 4) | |
| all_probs = { | |
| label_map[i]: round(probs[0][i].item() * 100, 2) for i in range(4) | |
| } | |
| return label, confidence, all_probs | |
| # ========================= | |
| # FASTAPI APP | |
| # ========================= | |
| app = FastAPI(title="NewsMind API") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| app.mount("/static", StaticFiles(directory="."), name="static") | |
| # ========================= | |
| # SCHEMAS | |
| # ========================= | |
| class TextIn(BaseModel): | |
| text: str | |
| class ChatIn(BaseModel): | |
| messages: list[dict] | |
| system: str = "" | |
| # ========================= | |
| # ROUTES | |
| # ========================= | |
| def root(): | |
| return FileResponse("index.html") | |
| def health(): | |
| return {"status": "ok", "device": str(device)} | |
| def config(): | |
| # Không trả key ra client nữa — chỉ báo có hay không | |
| return {"has_hf_key": bool(HF_API_KEY)} | |
| def classify(body: TextIn): | |
| if not body.text.strip(): | |
| return {"error": "Text rỗng"} | |
| label, confidence, all_probs = predict(body.text) | |
| words = body.text.strip().split() | |
| word_count = len(words) | |
| read_time = max(1, round(word_count / 200)) | |
| if confidence >= 0.80: | |
| tier = "HIGH" | |
| elif confidence >= 0.55: | |
| tier = "MED" | |
| else: | |
| tier = "LOW" | |
| sentences = [s.strip() for s in body.text.replace("!", ".").replace("?", ".").split(".") if len(s.strip()) > 20] | |
| top_sentences = ". ".join(sentences[:2]) + ("." if sentences else "") | |
| return { | |
| "label": label, | |
| "label_vi": label_map_vi[label], | |
| "confidence": confidence, | |
| "confidence_tier": tier, | |
| "all_probs": all_probs, | |
| "word_count": word_count, | |
| "read_time": read_time, | |
| "top_sentences": top_sentences, | |
| } | |
| def classify_batch(items: list[TextIn]): | |
| results = [] | |
| for i, item in enumerate(items): | |
| if not item.text.strip(): | |
| results.append({"index": i, "error": "Text rỗng"}) | |
| continue | |
| label, confidence, all_probs = predict(item.text) | |
| words = item.text.strip().split() | |
| tier = "HIGH" if confidence >= 0.80 else ("MED" if confidence >= 0.55 else "LOW") | |
| results.append({ | |
| "index": i, | |
| "label": label, | |
| "label_vi": label_map_vi[label], | |
| "confidence": confidence, | |
| "confidence_tier": tier, | |
| "all_probs": all_probs, | |
| "word_count": len(words), | |
| "preview": item.text[:80] + ("..." if len(item.text) > 80 else ""), | |
| }) | |
| return results | |
| async def chat(body: ChatIn): | |
| """Proxy LLM call qua backend — tránh CORS và ẩn API key.""" | |
| if not HF_API_KEY: | |
| return {"error": "Chưa có HF_API_KEY trong file .env"} | |
| msgs = [] | |
| if body.system: | |
| msgs.append({"role": "system", "content": body.system}) | |
| msgs.extend(body.messages) | |
| try: | |
| async with httpx.AsyncClient(timeout=30) as client: | |
| r = await client.post( | |
| HF_API_URL, | |
| headers={ | |
| "Authorization": f"Bearer {HF_API_KEY}", | |
| "Content-Type": "application/json", | |
| }, | |
| json={ | |
| "model": HF_MODEL_ROUTED, | |
| "messages": msgs, | |
| "max_tokens": 500, | |
| "temperature": 0.65, | |
| }, | |
| ) | |
| if not r.is_success: | |
| try: | |
| err = r.json() | |
| if isinstance(err, dict): | |
| msg = err.get("error", {}) | |
| if isinstance(msg, dict): | |
| msg = msg.get("message", str(err)) | |
| else: | |
| msg = str(msg) | |
| else: | |
| msg = str(err) | |
| except Exception: | |
| msg = r.text[:300] | |
| return {"error": f"HF API {r.status_code}: {msg}"} | |
| try: | |
| data = r.json() | |
| reply = data["choices"][0]["message"]["content"].strip() | |
| except Exception as e: | |
| return {"error": f"Parse lỗi response HF: {e} | Raw: {r.text[:200]}"} | |
| return {"reply": reply} | |
| except httpx.TimeoutException: | |
| return {"error": "Timeout khi gọi HF API (>30s). Thử lại sau."} | |
| except Exception as e: | |
| return {"error": str(e)} |