ABSA / app.py
hai2131's picture
Update app.py
89f7c20 verified
import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
import torch
# Page-level Streamlit setup: browser-tab title and wide layout, followed by
# the heading shown at the top of the page.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first st.* call in a script — keep it above everything else; confirm against
# the installed Streamlit version.
st.set_page_config(page_title="ABSA App", layout="wide")
st.title("Aspect-Based Sentiment Analysis (E2E-ABSA)")
@st.cache_resource
def load_models():
    """Load the tokenizer/model pairs for both stages of the pipeline.

    The aspect-term extractor (token classification) comes from the
    "hai2131/abte-bert" checkpoint and the polarity classifier (sequence
    classification) from "hai2131/absa-bert".

    Returns:
        A 4-tuple ``(ner_tokenizer, ner_model, cls_tokenizer, cls_model)``.
    """
    abte_checkpoint = "hai2131/abte-bert"
    absa_checkpoint = "hai2131/absa-bert"
    # Tuple elements are evaluated left to right, so the load order is:
    # extractor tokenizer, extractor model, classifier tokenizer, classifier model.
    return (
        AutoTokenizer.from_pretrained(abte_checkpoint),
        AutoModelForTokenClassification.from_pretrained(abte_checkpoint),
        AutoTokenizer.from_pretrained(absa_checkpoint),
        AutoModelForSequenceClassification.from_pretrained(absa_checkpoint),
    )
# Obtain both tokenizer/model pairs when the script runs; load_models is
# wrapped in st.cache_resource.
ner_tokenizer, ner_model, cls_tokenizer, cls_model = load_models()
# Index -> tag-name table taken from the token-classification model's config;
# extract_aspect_terms checks these tag names for "B-" / "I-" prefixes.
id2label = ner_model.config.id2label
# Class index -> sentiment name used by classify_polarity.
# NOTE(review): hard-coded; assumes the classifier head has exactly these three
# classes in this order — confirm against cls_model.config.id2label.
label2sentiment = {0: "negative", 1: "neutral", 2: "positive"}
def extract_aspect_terms(text):
    """Extract aspect terms from ``text`` with the token-classification model.

    The text is tokenized, every token receives the highest-scoring tag from
    ``ner_model``, and runs of ``B-``/``I-`` tagged tokens are merged into
    aspect strings: word pieces (``##...``) are glued onto the preceding
    piece, separate tokens are joined with a single space.

    Args:
        text: Input sentence. Empty or whitespace-only input yields no
            aspects and skips the model call.

    Returns:
        A list of unique aspect strings, in order of first appearance.
    """
    if not text or not text.strip():
        return []

    inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only
        outputs = ner_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)[0]
    tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label[i.item()] for i in predictions]
    # Special markers added by the tokenizer are not part of the sentence;
    # whatever tag the model assigns to them must never end up in an aspect.
    special_tokens = set(ner_tokenizer.all_special_tokens)

    aspects = []
    current = ""
    for token, label in zip(tokens, labels):
        if token in special_tokens:
            # Treat a special token as a hard boundary and drop it.
            if current:
                aspects.append(current)
                current = ""
            continue

        is_piece = token.startswith("##")
        if is_piece and current:
            # A continuation piece belongs to the word already being
            # collected; glue it on regardless of its own tag so a word is
            # never cut in half by a stray "O" or "B-" on a sub-token.
            current += token[2:]
        elif label.startswith("B-"):
            if current:
                aspects.append(current)
            current = token[2:] if is_piece else token
        elif label.startswith("I-") and current:
            current += " " + token
        else:
            # "O" tag, or an "I-" tag with no aspect open: close any open aspect.
            if current:
                aspects.append(current)
            current = ""
    if current:
        aspects.append(current)

    # dict.fromkeys de-duplicates while keeping first-appearance order;
    # set() would make the result order vary between runs.
    return list(dict.fromkeys(aspects))
def classify_polarity(text, aspect):
    """Predict the sentiment expressed in ``text`` towards ``aspect``.

    The sentence and the aspect term are encoded together as a text pair and
    scored by the sequence-classification model; the highest-scoring class
    index is translated through ``label2sentiment``.

    Args:
        text: The full input sentence.
        aspect: The aspect term to judge within that sentence.

    Returns:
        The sentiment name stored in ``label2sentiment`` for the predicted
        class index.

    Raises:
        KeyError: If the predicted class index is not a key of
            ``label2sentiment``.
    """
    encoded = cls_tokenizer(text, aspect, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        scores = cls_model(**encoded).logits
    best_class = scores.argmax(dim=1).item()
    return label2sentiment[best_class]
# Input box, pre-filled with an example sentence.
text = st.text_area(
    "Nhập một câu tiếng Anh để phân tích:",
    "The food was amazing, but the service was terrible.",
)

# Run the two-stage pipeline only when the button is pressed.
if st.button("Phân tích"):
    # Marker shown next to each polarity; unknown polarities fall back to "🔹".
    polarity_icons = {"positive": "✅", "negative": "❌", "neutral": "😐"}

    with st.spinner(" Đang phân tích..."):
        aspects = extract_aspect_terms(text)
        results = []
        for asp in aspects:
            results.append((asp, classify_polarity(text, asp)))

    # Render after the spinner block: one bullet per aspect, or a warning
    # when no aspect was found.
    if not results:
        st.warning("⚠️ Không tìm thấy khía cạnh nào.")
    else:
        st.markdown("## Kết quả:")
        for asp, polarity in results:
            emoji = polarity_icons.get(polarity, "🔹")
            st.markdown(f"- {emoji} **{asp}** — *{polarity}*")