EmmaL1 commited on
Commit
b4dc2a8
·
verified ·
1 Parent(s): 02170d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -34
app.py CHANGED
@@ -1,67 +1,84 @@
1
  import streamlit as st
2
  import numpy as np
3
- try:
4
- from transformers import pipeline
5
- except ImportError:
6
- st.error("Transformers pipeline not available. Installing required version...")
7
- import subprocess
8
- import sys
9
- subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.30.0"])
10
- from transformers import pipeline
11
 
12
- # Initialize pipelines with error handling
13
  @st.cache_resource
14
  def load_pipelines():
15
  try:
16
  sentiment_pipe = pipeline(
17
- "text-classification",
18
  model="EmmaL1/CustomModel_amazon",
19
- return_all_scores=True
 
 
20
  )
21
  qa_pipe = pipeline(
22
- "question-answering",
23
- model="distilbert/distilbert-base-cased-distilled-squad"
 
24
  )
25
  return sentiment_pipe, qa_pipe
26
  except Exception as e:
27
- st.error(f"Pipeline initialization failed: {str(e)}")
28
  st.stop()
29
 
30
  sentiment_pipeline, qa_pipeline = load_pipelines()
31
 
32
- def get_rating(prediction):
33
- """Get rating (1-5 stars) from model prediction"""
34
- scores = [score['score'] for score in prediction[0]]
35
- return np.argmax(scores) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def main():
38
- st.title("Amazon Review Analysis")
39
- st.write("Analyze review sentiment and predict star rating")
40
 
41
- user_input = st.text_area("Enter review text:", height=150)
42
-
43
- if st.button("Analyze") and user_input:
44
  with st.spinner("Processing..."):
45
  try:
46
- # Get prediction
47
- sentiment_pred = sentiment_pipeline(user_input)
48
- rating = get_rating(sentiment_pred)
 
 
49
 
50
- # Display results
51
- st.subheader("Results")
52
- st.metric("Predicted Rating", f"{rating}/5 stars")
53
- st.metric("Sentiment", "Positive" if rating >= 3 else "Negative")
 
 
 
54
 
55
- # Generate justification
56
  qa_result = qa_pipeline({
57
- 'question': f'Why would this get {rating} stars?',
58
  'context': user_input
59
  })
60
- st.subheader("Key Reasons")
61
  st.write(qa_result['answer'])
62
 
 
 
 
 
63
  except Exception as e:
64
- st.error(f"Analysis failed: {str(e)}")
65
 
66
  if __name__ == "__main__":
67
  main()
 
1
  import streamlit as st
2
  import numpy as np
3
+ from transformers import pipeline
 
 
 
 
 
 
 
4
 
5
+ # 环境稳定化配置
6
  @st.cache_resource
7
  def load_pipelines():
8
  try:
9
  sentiment_pipe = pipeline(
10
+ "text-classification",
11
  model="EmmaL1/CustomModel_amazon",
12
+ return_all_scores=True,
13
+ device="cpu",
14
+ torch_dtype="float32"
15
  )
16
  qa_pipe = pipeline(
17
+ "question-answering",
18
+ model="distilbert/distilbert-base-cased-distilled-squad",
19
+ device="cpu"
20
  )
21
  return sentiment_pipe, qa_pipe
22
  except Exception as e:
23
+ st.error(f"Initialization Error: {str(e)}")
24
  st.stop()
25
 
26
  sentiment_pipeline, qa_pipeline = load_pipelines()
27
 
28
+ def get_balanced_rating(prediction): probs = np.array([score['score'] for score in prediction[0][:5]])
29
+ probs = probs / (probs.sum() + 1e-8) # 安全标准化
30
+
31
+ # 优化后的阈值设置(可调整)
32
+ rating_rules = [
33
+ (0.95, 1), # 概率>95%才判为1星
34
+ (0.85, 2),
35
+ (0.70, 3),
36
+ (0.50, 4),
37
+ (0.00, 5) # 默认5星
38
+ ]
39
+
40
+ for threshold, rating in rating_rules:
41
+ if probs[rating-1] > threshold:
42
+ return rating, probs[rating-1]
43
+
44
+ return 5, probs[4] # 默认返回5星
45
 
46
  def main():
47
+ st.title("Amazon Review Rating System")
 
48
 
49
+ user_input = st.text_area("Input Review Text:", height=150)
50
+
51
+ if st.button("Analyze"):
52
  with st.spinner("Processing..."):
53
  try:
54
+ # 获取预测
55
+ prediction = sentiment_pipeline(user_input)
56
+
57
+ # 计算优化后的评分
58
+ rating, confidence = get_balanced_rating(prediction)
59
 
60
+ # 显示结果
61
+ st.subheader("Analysis Result")
62
+ col1, col2 = st.columns(2)
63
+ with col1:
64
+ st.metric("Predicted Rating", f"{rating} stars")
65
+ with col2:
66
+ st.metric("Confidence", f"{confidence:.1%}")
67
 
68
+ # 生成解释
69
  qa_result = qa_pipeline({
70
+ 'question': f'What justifies this {rating}-star rating?',
71
  'context': user_input
72
  })
73
+ st.subheader("Key Factors")
74
  st.write(qa_result['answer'])
75
 
76
+ # 调试信息(可选)
77
+ if st.checkbox("Show debug info"):
78
+ st.write("Raw probabilities:", [f"{p:.4f}" for p in prediction[0]])
79
+
80
  except Exception as e:
81
+ st.error(f"Error: {str(e)}")
82
 
83
  if __name__ == "__main__":
84
  main()