"""Streamlit app for end-to-end aspect-based sentiment analysis (E2E-ABSA).

The pipeline has two stages:

1. A token-classification model tags aspect terms in the input sentence
   using B-/I-/O style labels.
2. A sequence-classification model receives the sentence paired with each
   extracted aspect and predicts that aspect's polarity.
"""

import streamlit as st
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
)

st.set_page_config(page_title="ABSA App", layout="wide")
st.title("Aspect-Based Sentiment Analysis (E2E-ABSA)")


@st.cache_resource
def load_models():
    """Load both tokenizer/model pairs once and share them across reruns.

    Returns:
        A tuple ``(ner_tokenizer, ner_model, cls_tokenizer, cls_model)``.
    """
    ner_tokenizer = AutoTokenizer.from_pretrained("hai2131/abte-bert")
    ner_model = AutoModelForTokenClassification.from_pretrained("hai2131/abte-bert")
    cls_tokenizer = AutoTokenizer.from_pretrained("hai2131/absa-bert")
    cls_model = AutoModelForSequenceClassification.from_pretrained("hai2131/absa-bert")
    return ner_tokenizer, ner_model, cls_tokenizer, cls_model


ner_tokenizer, ner_model, cls_tokenizer, cls_model = load_models()

# Class index -> tag name, as stored in the token-classification model config.
id2label = ner_model.config.id2label
# Class index -> polarity name for the sequence-classification model.
label2sentiment = {0: "negative", 1: "neutral", 2: "positive"}
# Polarity name -> bullet icon used when rendering results.
POLARITY_EMOJI = {"positive": "✅", "negative": "❌", "neutral": "😐"}


def _merge_wordpieces(tokens, labels, special_tokens=()):
    """Merge tagged word-piece tokens into a list of aspect strings.

    Args:
        tokens: Word-piece tokens; continuation pieces start with ``##``.
        labels: One tag per token; ``B-`` opens an aspect, ``I-`` extends it,
            anything else closes it.
        special_tokens: Tokens (such as ``[CLS]``/``[SEP]``) that must never
            become part of an aspect.

    Returns:
        Unique aspect strings in order of first appearance.
    """
    aspects = []
    current = ""
    for token, label in zip(tokens, labels):
        if token in special_tokens:
            # Tokenizer bookkeeping tokens are not text; they only terminate
            # whatever aspect is currently open.
            if current:
                aspects.append(current)
                current = ""
            continue

        if token.startswith("##"):
            # A continuation piece belongs to the same word as the previous
            # token, so it follows that word's decision instead of its own
            # tag. This avoids half-words when the model tags pieces of one
            # word inconsistently.
            if current:
                current += token[2:]
            continue

        if label.startswith("B-"):
            if current:
                aspects.append(current)
            current = token
        elif label.startswith("I-") and current:
            current += " " + token
        else:
            # Any other tag (or an "I-" with no open aspect) ends the aspect.
            if current:
                aspects.append(current)
            current = ""

    if current:
        aspects.append(current)

    # dict.fromkeys de-duplicates while keeping first-appearance order, so the
    # rendered list is stable between runs (a set would not be).
    return list(dict.fromkeys(aspects))


def extract_aspect_terms(text: str) -> list:
    """Return the unique aspect terms the tagging model finds in ``text``.

    Args:
        text: The sentence to analyse.

    Returns:
        Aspect strings rebuilt from the tokenizer's word pieces, in order of
        first appearance. Empty when ``text`` is blank or nothing is tagged.
    """
    if not text or not text.strip():
        # Nothing to tag; skip the forward pass entirely.
        return []

    inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = ner_model(**inputs)

    # Single-sentence batch: take the best class per token of item 0.
    predictions = torch.argmax(outputs.logits, dim=2)[0]
    tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label[i.item()] for i in predictions]

    return _merge_wordpieces(tokens, labels, set(ner_tokenizer.all_special_tokens))


def classify_polarity(text: str, aspect: str) -> str:
    """Predict the polarity of ``aspect`` within ``text``.

    Args:
        text: The full sentence.
        aspect: An aspect term, passed to the tokenizer as the second segment.

    Returns:
        The ``label2sentiment`` entry for the highest-scoring class.
    """
    inputs = cls_tokenizer(text, aspect, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = cls_model(**inputs).logits
    prediction = torch.argmax(logits, dim=1).item()
    return label2sentiment[prediction]


text = st.text_area(
    "Nhập một câu tiếng Anh để phân tích:",
    "The food was amazing, but the service was terrible.",
)

if st.button("Phân tích"):
    with st.spinner(" Đang phân tích..."):
        aspects = extract_aspect_terms(text)
        results = [(asp, classify_polarity(text, asp)) for asp in aspects]

    if results:
        st.markdown("## Kết quả:")
        for asp, polarity in results:
            emoji = POLARITY_EMOJI.get(polarity, "🔹")
            st.markdown(f"- {emoji} **{asp}** — *{polarity}*")
    else:
        st.warning("⚠️ Không tìm thấy khía cạnh nào.")