EmmaL1 commited on
Commit
31f733b
·
verified ·
1 Parent(s): b4dc2a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -77
app.py CHANGED
@@ -1,84 +1,45 @@
 
1
  import streamlit as st
2
- import numpy as np
3
  from transformers import pipeline
 
4
 
5
- # 环境稳定化配置
6
- @st.cache_resource
7
- def load_pipelines():
8
- try:
9
- sentiment_pipe = pipeline(
10
- "text-classification",
11
- model="EmmaL1/CustomModel_amazon",
12
- return_all_scores=True,
13
- device="cpu",
14
- torch_dtype="float32"
15
- )
16
- qa_pipe = pipeline(
17
- "question-answering",
18
- model="distilbert/distilbert-base-cased-distilled-squad",
19
- device="cpu"
20
- )
21
- return sentiment_pipe, qa_pipe
22
- except Exception as e:
23
- st.error(f"Initialization Error: {str(e)}")
24
- st.stop()
25
 
26
- sentiment_pipeline, qa_pipeline = load_pipelines()
 
27
 
28
- def get_balanced_rating(prediction): probs = np.array([score['score'] for score in prediction[0][:5]])
29
- probs = probs / (probs.sum() + 1e-8) # 安全标准化
30
-
31
- # 优化后的阈值设置(可调整)
32
- rating_rules = [
33
- (0.95, 1), # 概率>95%才判为1星
34
- (0.85, 2),
35
- (0.70, 3),
36
- (0.50, 4),
37
- (0.00, 5) # 默认5星
38
- ]
39
-
40
- for threshold, rating in rating_rules:
41
- if probs[rating-1] > threshold:
42
- return rating, probs[rating-1]
43
-
44
- return 5, probs[4] # 默认返回5星
45
 
46
- def main():
47
- st.title("Amazon Review Rating System")
48
-
49
- user_input = st.text_area("Input Review Text:", height=150)
50
-
51
- if st.button("Analyze"):
52
- with st.spinner("Processing..."):
53
- try:
54
- # 获取预测
55
- prediction = sentiment_pipeline(user_input)
56
-
57
- # 计算优化后的评分
58
- rating, confidence = get_balanced_rating(prediction)
59
-
60
- # 显示结果
61
- st.subheader("Analysis Result")
62
- col1, col2 = st.columns(2)
63
- with col1:
64
- st.metric("Predicted Rating", f"{rating} stars")
65
- with col2:
66
- st.metric("Confidence", f"{confidence:.1%}")
67
-
68
- # 生成解释
69
- qa_result = qa_pipeline({
70
- 'question': f'What justifies this {rating}-star rating?',
71
- 'context': user_input
72
- })
73
- st.subheader("Key Factors")
74
- st.write(qa_result['answer'])
75
-
76
- # 调试信息(可选)
77
- if st.checkbox("Show debug info"):
78
- st.write("Raw probabilities:", [f"{p:.4f}" for p in prediction[0]])
79
-
80
- except Exception as e:
81
- st.error(f"Error: {str(e)}")
82
 
83
- if __name__ == "__main__":
84
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import part
2
  import streamlit as st
 
3
  from transformers import pipeline
4
+ import numpy as np # Import numpy
5
 
6
+ # Initialize sentiment analysis pipeline
7
+ sentiment_pipeline = pipeline(model="EmmaL1/CustomModel_amazon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Initialize question-answering pipeline
10
+ qa_pipeline = pipeline("question-answering", model="distilbert/distilbert-base-cased-distilled-squad")
11
 
12
+ # function part
13
+ # text classification
14
+ def textclassification():
15
+ st.title("Amazon Customer Sentiment Analysis:Ratings and Reasons")
16
+ st.write("Enter a sentence to analyze its rating and reason:")
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ user_input = st.text_input("Input your text:")
19
+ if user_input:
20
+ # Sentiment Analysis
21
+ sentiment_result = sentiment_pipeline(user_input)
22
+ sentiment = sentiment_result[0]["label"]
23
+ confidence = sentiment_result[0]["score"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ st.write(f"Sentiment: {sentiment}")
26
+ st.write(f"Confidence: {confidence:.2f}")
27
+
28
+ # Determine the rating based on confidence
29
+ if sentiment == "POSITIVE":
30
+ # Scale confidence to a rating of 1 to 5
31
+ rating = int(confidence * 4) + 1
32
+ else:
33
+ # For negative sentiment, invert the confidence score
34
+ rating = int((1 - confidence) * 4) + 1
35
+
36
+ # Display the rating
37
+ st.write(f"The rating is {rating} stars")
38
+
39
+ # Question Answering
40
+ qa_input = {
41
+ 'question': f'Why is the rating {rating} star?',
42
+ 'context': user_input # Use user input as context
43
+ }
44
+ qa_result = qa_pipeline(qa_input)
45
+ st.write(f"Reasons: {qa_result['answer']}")