|
|
import streamlit as st |
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification |
|
|
import torch |
|
|
|
|
|
# Page setup: browser-tab title plus wide layout, then the visible heading.
st.set_page_config(layout="wide", page_title="ABSA App")
st.title("Aspect-Based Sentiment Analysis (E2E-ABSA)")
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load the aspect-extraction and polarity-classification models.

    Wrapped in ``st.cache_resource`` so the load happens once and the same
    objects are reused across Streamlit reruns instead of being reloaded on
    every interaction.

    Returns:
        A 4-tuple ``(ner_tokenizer, ner_model, cls_tokenizer, cls_model)``:
        the tokenizer and token-classification model from
        ``hai2131/abte-bert``, followed by the tokenizer and
        sequence-classification model from ``hai2131/absa-bert``.
    """
    extractor_repo = "hai2131/abte-bert"
    classifier_repo = "hai2131/absa-bert"

    # Aspect-term extractor (per-token tagging).
    extractor_tok = AutoTokenizer.from_pretrained(extractor_repo)
    extractor = AutoModelForTokenClassification.from_pretrained(extractor_repo)

    # Polarity classifier (sentence + aspect pair -> class).
    classifier_tok = AutoTokenizer.from_pretrained(classifier_repo)
    classifier = AutoModelForSequenceClassification.from_pretrained(classifier_repo)

    return extractor_tok, extractor, classifier_tok, classifier
|
|
|
|
|
# Obtain the four model objects (served from Streamlit's resource cache after
# the first run).
ner_tokenizer, ner_model, cls_tokenizer, cls_model = load_models()

# Class index -> tag name for the aspect extractor, read from its config.
# extract_aspect_terms decodes these tags by their "B-" / "I-" prefixes.
id2label = ner_model.config.id2label

# Class index -> polarity name for the sentiment classifier.
# NOTE(review): hard-coded; assumed to match absa-bert's output order — confirm.
label2sentiment = dict(enumerate(("negative", "neutral", "positive")))
|
|
|
|
|
def extract_aspect_terms(text):
    """Return the aspect terms that the token-classification model tags in *text*.

    The sentence is tokenized with ``ner_tokenizer`` and tagged by
    ``ner_model``. The per-token predictions are mapped through ``id2label``
    and decoded as BIO spans.

    Decoding works on whole words, not raw WordPiece tokens. Sub-word pieces
    (``##...``) are glued back onto the preceding word, and each word takes
    the label predicted for its FIRST piece. Labels predicted for continuation
    pieces are ignored, so a stray ``B-`` or ``O`` on a ``##`` piece can no
    longer split one word into two aspects or cut an aspect mid-word.

    Framing special tokens (e.g. ``[CLS]``/``[SEP]``) are skipped. The unknown
    token is kept because it stands for real input text.

    An ``I-`` label with no open span starts a new aspect (conlleval-style)
    instead of being silently dropped.

    Args:
        text: The input sentence.

    Returns:
        The distinct aspect strings in order of first appearance. The words of
        a multi-word aspect are joined with single spaces.
    """
    inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = ner_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Framing tokens carry no text; [UNK] represents actual input, so keep it.
    skip_tokens = set(ner_tokenizer.all_special_tokens) - {ner_tokenizer.unk_token}

    # Rebuild words from WordPiece pieces; a word keeps its first piece's label.
    words = []  # list of [word_text, label]
    for token, pred in zip(tokens, predictions):
        if token in skip_tokens:
            continue
        if token.startswith("##") and words:
            words[-1][0] += token[2:]
        else:
            words.append([token, id2label[pred]])

    # BIO decoding over whole words.
    aspects = []
    current = []
    for word, label in words:
        if label.startswith("B-") or (label.startswith("I-") and not current):
            if current:
                aspects.append(" ".join(current))
            current = [word]
        elif label.startswith("I-"):
            current.append(word)
        elif current:
            aspects.append(" ".join(current))
            current = []
    if current:
        aspects.append(" ".join(current))

    # dict.fromkeys de-duplicates while preserving order; set() would reorder
    # the displayed results between runs because of string hash randomization.
    return list(dict.fromkeys(aspects))
|
|
|
|
|
def classify_polarity(text, aspect):
    """Classify the sentiment expressed toward *aspect* in *text*.

    The sentence and the aspect are encoded as a sentence pair with
    ``cls_tokenizer``. The arg-max class of ``cls_model`` is returned as a
    polarity name.

    Args:
        text: The full input sentence.
        aspect: One aspect term extracted from *text*.

    Returns:
        ``"negative"``, ``"neutral"`` or ``"positive"`` for classes covered by
        ``label2sentiment``. For any other class index, the returned string is
        the model's own ``config.id2label`` name, or the index itself if even
        that is missing. This avoids a ``KeyError`` when the classifier head has
        more classes than the hard-coded map.
    """
    inputs = cls_tokenizer(text, aspect, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = cls_model(**inputs).logits
    prediction = torch.argmax(logits, dim=1).item()

    if prediction in label2sentiment:
        return label2sentiment[prediction]
    # Class not covered by the hard-coded map: fall back to the model config.
    return cls_model.config.id2label.get(prediction, str(prediction))
|
|
|
|
|
# Input box, pre-filled with an example sentence.
text = st.text_area("Nhập một câu tiếng Anh để phân tích:",
                    "The food was amazing, but the service was terrible.")

# Marker displayed next to each aspect, keyed by polarity name. Any other
# polarity string falls back to "🔹".
POLARITY_EMOJI = {"positive": "✅", "negative": "❌", "neutral": "😐"}

if st.button("Phân tích"):
    if not text.strip():
        # Nothing to analyse: skip running both models on an empty input.
        st.warning("⚠️ Vui lòng nhập một câu để phân tích.")
    else:
        with st.spinner(" Đang phân tích..."):
            aspects = extract_aspect_terms(text)
            results = [(asp, classify_polarity(text, asp)) for asp in aspects]

        if results:
            st.markdown("## Kết quả:")
            for asp, polarity in results:
                emoji = POLARITY_EMOJI.get(polarity, "🔹")
                st.markdown(f"- {emoji} **{asp}** — *{polarity}*")
        else:
            st.warning("⚠️ Không tìm thấy khía cạnh nào.")
|
|
|