# Source: sentiment-analysis / utils.py
# Uploaded by TinTinDo ("Upload 5 files", commit 9262cf2, verified), 1.57 kB
# utils.py
import functools
import re
import time

import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Hugging Face Hub identifier passed to from_pretrained() for both the
# tokenizer and the classification model.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
# Maps the predicted class index to the label string returned to callers.
LABELS = {0: "Negative", 1: "Positive"}
# Maximum sequence length (in tokens) handed to the tokenizer.
MAX_LENGTH = 256
@functools.lru_cache(maxsize=1)  # Cache the model so it is only loaded once
def load_model():
    """Load the tokenizer and classification model and move the model to a device.

    The result is memoized, so repeated calls within the same process return
    the same ``(tokenizer, model, device)`` tuple instead of reloading the
    weights. A standard-library cache is used because this module never
    imports ``st``; the previous ``@st.cache_resource`` decorator raised
    ``NameError`` as soon as the module was imported.

    Returns:
        A tuple ``(tokenizer, model, device)`` where ``model`` has been put
        in eval mode and moved to ``device`` (CUDA when available, else CPU).
    """
    print("🔄 Loading model from HuggingFace Hub...")
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
    # Inference only: switch off dropout and other training-time behaviour.
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"✅ Model loaded on {device}")
    return tokenizer, model, device
def clean_text(text: str) -> str:
    """Normalize raw text to lowercase ASCII letters and whitespace.

    Steps, in order:
      1. Drop anything that looks like an HTML tag (``<...>`` on one line).
      2. Drop every character that is not an ASCII letter or whitespace
         (digits and punctuation are removed, not replaced by a space).
      3. Trim surrounding whitespace and lowercase the result.

    Args:
        text: The raw input string.

    Returns:
        The cleaned string; may be empty.
    """
    tag_pattern = re.compile(r'<.*?>')
    non_letter_pattern = re.compile(r'[^a-zA-Z\s]')

    without_tags = tag_pattern.sub('', text)
    letters_only = non_letter_pattern.sub('', without_tags)
    return letters_only.strip().lower()
def predict_sentiment(text: str, tokenizer, model, device) -> dict:
    """Classify the sentiment of a single piece of text.

    Args:
        text: The text to classify.
        tokenizer: Tokenizer callable that returns PyTorch tensors
            (as produced by ``load_model``).
        model: Sequence-classification model whose output exposes ``logits``.
        device: Torch device the model lives on; inputs are moved there.

    Returns:
        A dict with:
          * ``"sentiment"``: label looked up in ``LABELS`` for the top class,
          * ``"confidence"``: softmax probability of that class, 4 decimals,
          * ``"inference_time_ms"``: elapsed time for tokenization plus the
            forward pass, in milliseconds, 2 decimals.
    """
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump (NTP, DST), which corrupts short latency figures.
    start = time.perf_counter()

    # A single sequence needs no padding. Padding to MAX_LENGTH made the model
    # process up to 256 positions for every input even though the attention
    # mask hides the padded ones, so it only added latency.
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

    # Batch size is 1, so argmax yields a single index we can turn into an int.
    pred = torch.argmax(probs, dim=-1).item()
    conf = probs[0][pred].item()

    return {
        "sentiment": LABELS[pred],
        "confidence": round(conf, 4),
        "inference_time_ms": round((time.perf_counter() - start) * 1000, 2),
    }