File size: 2,704 Bytes
4d177d5 2e4bd0b 4d177d5 89f7c20 2e4bd0b 4d177d5 2e4bd0b 4d177d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
import torch
# Page chrome: browser-tab title, wide layout, and the on-page heading.
st.set_page_config(
    page_title="ABSA App",
    layout="wide",
)
st.title("Aspect-Based Sentiment Analysis (E2E-ABSA)")
@st.cache_resource
def load_models():
    """Load the aspect-extraction and polarity-classification checkpoints.

    Wrapped in ``st.cache_resource``, so the tokenizers and models are
    created through that cache instead of on every call.

    Returns:
        A 4-tuple ``(ner_tokenizer, ner_model, cls_tokenizer, cls_model)``:
        the tokenizer and token-classification model from
        ``hai2131/abte-bert``, followed by the tokenizer and
        sequence-classification model from ``hai2131/absa-bert``.
    """
    extraction_repo = "hai2131/abte-bert"
    polarity_repo = "hai2131/absa-bert"

    # Same load order as before: extraction pair first, then polarity pair.
    extraction_tokenizer = AutoTokenizer.from_pretrained(extraction_repo)
    extraction_model = AutoModelForTokenClassification.from_pretrained(
        extraction_repo
    )
    polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_repo)
    polarity_model = AutoModelForSequenceClassification.from_pretrained(
        polarity_repo
    )

    return (
        extraction_tokenizer,
        extraction_model,
        polarity_tokenizer,
        polarity_model,
    )
# Class index -> sentiment name; classify_polarity() looks its argmax up here.
label2sentiment = {
    0: "negative",
    1: "neutral",
    2: "positive",
}

# Module-level handles shared by extract_aspect_terms() and classify_polarity().
(ner_tokenizer, ner_model, cls_tokenizer, cls_model) = load_models()

# Tag-id -> tag-name table taken from the extraction model's own config.
id2label = ner_model.config.id2label
def _strip_wordpiece_prefix(token):
    """Return *token* without a leading ``##`` continuation marker."""
    # Only the prefix is a marker; any later "##" belongs to the token text.
    return token[2:] if token.startswith("##") else token


def _merge_aspect_tokens(tokens, labels, special_tokens=()):
    """Group tagged sub-word tokens into aspect strings.

    Args:
        tokens: Token strings, aligned one-to-one with *labels*.
        labels: Tag names; a ``B-`` prefix opens an aspect, an ``I-`` prefix
            extends the open one, anything else closes it.
        special_tokens: Token strings that must never become part of an
            aspect (e.g. the tokenizer's special tokens). Each one closes
            the aspect currently being built, whatever its predicted tag.

    Returns:
        The distinct aspect strings in order of first appearance.
    """
    ignored = set(special_tokens)
    aspects = []
    current = ""

    for token, label in zip(tokens, labels):
        usable = token not in ignored
        if usable and label.startswith("B-"):
            # A new aspect begins; emit the one that was in progress, if any.
            if current:
                aspects.append(current)
            current = _strip_wordpiece_prefix(token)
        elif usable and label.startswith("I-") and current:
            # Continuation pieces are glued on; whole words get a space.
            if token.startswith("##"):
                current += token[2:]
            else:
                current += " " + token
        else:
            # Outside tag, orphan "I-" tag, or special token: close the span.
            if current:
                aspects.append(current)
            current = ""

    if current:
        aspects.append(current)

    # dict.fromkeys de-duplicates while keeping first-seen order, so repeated
    # runs on the same text list the aspects in the same order.
    return list(dict.fromkeys(aspects))


def extract_aspect_terms(text):
    """Extract the aspect terms mentioned in *text*.

    The text is tokenized with ``ner_tokenizer`` (truncated to the model's
    limit), tagged by ``ner_model``, and the per-token argmax tags are
    mapped to names through ``id2label`` and merged into aspect strings.

    Args:
        text: The sentence to analyse.

    Returns:
        A list of distinct aspect strings in order of first appearance;
        empty when no token is tagged as the beginning of an aspect.
    """
    inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = ner_model(**inputs)

    # Batch size is one, so take row 0 of the per-token argmax.
    predictions = torch.argmax(outputs.logits, dim=2)[0]
    tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label[label_id] for label_id in predictions.tolist()]

    return _merge_aspect_tokens(
        tokens, labels, ner_tokenizer.all_special_tokens
    )
def classify_polarity(text, aspect):
    """Return the sentiment name predicted for *aspect* within *text*.

    *text* and *aspect* are encoded together as a pair with
    ``cls_tokenizer`` (truncation enabled) and scored by ``cls_model``.
    The highest-scoring class index is looked up in ``label2sentiment``.

    Args:
        text: The full sentence.
        aspect: The aspect term whose polarity is wanted.

    Returns:
        The ``label2sentiment`` entry for the predicted class index.

    Raises:
        KeyError: If the predicted index is not a key of ``label2sentiment``.
    """
    encoded_pair = cls_tokenizer(
        text,
        aspect,
        return_tensors="pt",
        truncation=True,
    )
    with torch.no_grad():
        scores = cls_model(**encoded_pair).logits
    best_class = scores.argmax(dim=1).item()
    return label2sentiment[best_class]
# Input box, pre-filled with an example sentence.
text = st.text_area(
    "Nhập một câu tiếng Anh để phân tích:",
    "The food was amazing, but the service was terrible.",
)

if st.button("Phân tích"):
    # Run extraction, then one polarity classification per aspect, while a
    # spinner is displayed.
    with st.spinner(" Đang phân tích..."):
        aspects = extract_aspect_terms(text)
        results = []
        for asp in aspects:
            results.append((asp, classify_polarity(text, asp)))

    if not results:
        st.warning("⚠️ Không tìm thấy khía cạnh nào.")
    else:
        st.markdown("## Kết quả:")
        # Built once; polarities missing from the table get the fallback icon.
        polarity_icons = {"positive": "✅", "negative": "❌", "neutral": "😐"}
        for asp, polarity in results:
            emoji = polarity_icons.get(polarity, "🔹")
            st.markdown(f"- {emoji} **{asp}** — *{polarity}*")
|