hai2131 committed on
Commit
4d177d5
·
verified ·
1 Parent(s): 6cf61c3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
3
+ import torch
4
+
5
# Page-level configuration is issued before any other element is rendered,
# followed by the visible page heading.
st.set_page_config(layout="wide", page_title="ABSA App")
st.title("Aspect-Based Sentiment Analysis (E2E-ABSA)")
7
+
8
@st.cache_resource
def load_models():
    """Load the tokenizer/model pairs used by the app.

    Two checkpoints are fetched with ``from_pretrained``:

    * ``hai2131/abte-bert`` -- token-classification model that tags aspect terms.
    * ``hai2131/absa-bert`` -- sequence-classification model that scores the
      polarity of a (sentence, aspect) pair.

    The function is wrapped in ``st.cache_resource`` so the loaded objects are
    reused rather than rebuilt on every script rerun.

    Returns:
        A 4-tuple ``(ner_tokenizer, ner_model, cls_tokenizer, cls_model)``.
    """
    extraction_repo = "hai2131/abte-bert"
    polarity_repo = "hai2131/absa-bert"

    # Aspect-term extraction checkpoint (tokenizer first, then weights).
    extraction_tokenizer = AutoTokenizer.from_pretrained(extraction_repo)
    extraction_model = AutoModelForTokenClassification.from_pretrained(extraction_repo)

    # Aspect-polarity classification checkpoint.
    polarity_tokenizer = AutoTokenizer.from_pretrained(polarity_repo)
    polarity_model = AutoModelForSequenceClassification.from_pretrained(polarity_repo)

    return (
        extraction_tokenizer,
        extraction_model,
        polarity_tokenizer,
        polarity_model,
    )
17
+
18
# Shared handles for the whole script: both tokenizer/model pairs come from
# the cached loader.
(
    ner_tokenizer,
    ner_model,
    cls_tokenizer,
    cls_model,
) = load_models()

# Tag names for the token-classification head, as stored in the model config.
id2label = ner_model.config.id2label

# Class index -> sentiment name for the polarity classifier.
label2sentiment = dict(enumerate(("negative", "neutral", "positive")))
21
+
22
def _merge_aspect_tokens(tokens, labels, special_tokens=()):
    """Merge BIO-tagged WordPiece tokens into aspect-term strings.

    Args:
        tokens: WordPiece tokens; continuation pieces carry a ``##`` prefix.
        labels: One tag per token. A ``B-`` tag opens a span, an ``I-`` tag
            extends an open span, any other tag closes it.
        special_tokens: Tokens (such as ``[CLS]`` / ``[SEP]``) that must never
            become part of an aspect.

    Returns:
        Aspect strings without duplicates, in order of first appearance.
    """
    skipped = set(special_tokens)
    spans = []
    current = ""

    for token, label in zip(tokens, labels):
        if token in skipped:
            # Special tokens close any open span and are never emitted.
            if current:
                spans.append(current)
                current = ""
            continue

        is_continuation = token.startswith("##")
        piece = token[2:] if is_continuation else token

        if is_continuation and current:
            # A continuation piece may be tagged differently from the head of
            # its word; glue it onto the open span regardless of its tag so
            # the word is neither truncated nor split into two aspects.
            current += piece
        elif label.startswith("B-"):
            if current:
                spans.append(current)
            current = piece
        elif label.startswith("I-") and current:
            current += " " + piece
        else:
            if current:
                spans.append(current)
            current = ""

    if current:
        spans.append(current)

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is the same on every run (a set would not guarantee that).
    return list(dict.fromkeys(spans))


def extract_aspect_terms(text):
    """Return the aspect terms the token-classification model finds in *text*.

    The text is tokenized (truncated to the tokenizer's limit), each token is
    assigned its highest-scoring tag, and consecutive ``B-``/``I-`` tokens are
    merged back into words and phrases.

    Args:
        text: Sentence to analyse.

    Returns:
        A list of unique aspect strings in order of first appearance; empty
        when *text* is empty/whitespace or no aspect is tagged.
    """
    # Nothing to analyse: skip the model call entirely.
    if not text or not text.strip():
        return []

    inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = ner_model(**inputs)

    # Single input sentence, so take row 0 of the per-token argmax.
    predictions = torch.argmax(outputs.logits, dim=2)[0]
    tokens = ner_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label[index] for index in predictions.tolist()]

    return _merge_aspect_tokens(tokens, labels, ner_tokenizer.all_special_tokens)
46
+
47
def classify_polarity(text, aspect):
    """Predict the sentiment expressed towards *aspect* within *text*.

    The sentence and the aspect are encoded together as a text pair
    (truncated to the tokenizer's limit) and passed through the
    sequence-classification model without gradient tracking.

    Args:
        text: The full sentence.
        aspect: The aspect term to evaluate inside that sentence.

    Returns:
        The ``label2sentiment`` entry for the highest-scoring class.
    """
    encoded_pair = cls_tokenizer(
        text,
        aspect,
        return_tensors="pt",
        truncation=True,
    )

    with torch.no_grad():
        class_scores = cls_model(**encoded_pair).logits

    # One (text, aspect) pair per call, so the argmax holds a single index.
    best_class = torch.argmax(class_scores, dim=1).item()
    return label2sentiment[best_class]
53
+
54
# Input box, pre-filled with an example sentence.
text = st.text_area(
    "✍️ Nhập một câu tiếng Anh để phân tích:",
    "The food was amazing, but the service was terrible.",
)

if st.button("🚀 Phân tích"):
    # Run extraction, then one polarity prediction per extracted aspect,
    # while the spinner is displayed.
    with st.spinner("Đang phân tích..."):
        aspects = extract_aspect_terms(text)
        results = []
        for term in aspects:
            results.append((term, classify_polarity(text, term)))

    if not results:
        st.warning("⚠️ Không tìm thấy khía cạnh nào.")
    else:
        st.markdown("## 🎯 Kết quả:")
        # Marker shown per polarity; unknown values fall back to "🔹".
        icons = {"positive": "✅", "negative": "❌", "neutral": "😐"}
        for term, sentiment in results:
            icon = icons.get(sentiment, "🔹")
            st.markdown(f"- {icon} **{term}** — *{sentiment}*")