Stellajin916 committed on
Commit
602715a
·
verified ·
1 Parent(s): 0987ff5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -46
app.py CHANGED
@@ -1,26 +1,22 @@
1
- import os
2
- import zipfile
3
- import gradio as gr
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
6
 
7
- # 情感模型(京东)
8
  sentiment_model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
9
  sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
10
  sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
11
  sentiment_model.eval()
12
 
13
- # 解压你自己的多标签模型
14
- if not os.path.exists("result"):
15
- with zipfile.ZipFile("model_output.zip", "r") as zip_ref:
16
- zip_ref.extractall(".")
 
17
 
18
- # 加载你的多标签分类模型
19
- custom_tokenizer = AutoTokenizer.from_pretrained("result")
20
- custom_model = AutoModelForSequenceClassification.from_pretrained("result", use_safetensors=True)
21
- custom_model.eval()
22
-
23
- # 多标签类别
24
  label_map = {
25
  0: "Landscape & Culture",
26
  1: "Service & Facilities",
@@ -29,45 +25,71 @@ label_map = {
29
  4: "Interactive Activities",
30
  5: "Price & Consumption"
31
  }
 
 
 
 
 
 
32
 
33
- # 推理函数
34
- def analyze(text, threshold=0.5):
35
- # 情感分析
36
- inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
37
  with torch.no_grad():
38
- outputs = sentiment_model(**inputs)
39
- probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
40
- sentiment = "积极 (Positive)" if torch.argmax(outputs.logits) == 1 else "消极 (Negative)"
41
- sentiment_result = f"{sentiment}\nPositive: {probs[1]:.2f}, Negative: {probs[0]:.2f}"
42
 
43
- # 多标签分类
44
- inputs = custom_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  with torch.no_grad():
46
- logits = custom_model(**inputs).logits
 
47
  probs = torch.sigmoid(logits).squeeze().tolist()
48
- if isinstance(probs, float): # 单个标签时
 
49
  probs = [probs]
50
- results = [
51
- f"{label_map[i]} ({probs[i]:.2f})"
52
- for i in range(len(probs)) if probs[i] >= threshold
53
- ]
54
- if results:
55
- label_result = "\n".join(results)
56
  else:
57
- label_result = "The model was unable to identify the correct labels."
 
 
 
 
 
 
 
 
 
 
58
 
59
- return f"【Sentiment analysis】\n{sentiment_result}\n\n【Category of topic】\n{label_result}"
 
 
 
 
 
 
60
 
61
- # 创建 Gradio 页面
62
- demo = gr.Interface(
63
- fn=analyze,
64
- inputs=[
65
- gr.Textbox(lines=3, label="请输入评论内容"),
66
- gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.5, label="分类标签阈值")
67
- ],
68
- outputs="text",
69
- title="中文评论分析器",
70
- description="使用京东情感模型 + 自定义多标签模型,对评论内容进行双重分析"
71
- )
72
 
73
- demo.launch()
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ from transformers import BertTokenizer, BertForSequenceClassification
4
+ import gradio as gr
5
+ import os
6
 
7
# --------- Sentiment Model (Binary, expanded to 3 classes) ---------
# Two-class checkpoint; analyze() turns its two probabilities into a
# Positive / Neutral / Negative verdict.
sentiment_model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
sentiment_model.eval()  # evaluation mode: the app only runs inference

# --------- Multi-label Classification Model (Your model) ---------
label_dir = "./result"
# Fail early with an actionable message: when the directory is absent,
# from_pretrained() would otherwise try to interpret the path as a Hub
# repository id and report a much less obvious error.
if not os.path.isdir(label_dir):
    raise FileNotFoundError(
        f"Multi-label model directory {label_dir!r} was not found; "
        "place the exported tokenizer and model files there before starting the app."
    )
label_tokenizer = BertTokenizer.from_pretrained(label_dir)
label_model = BertForSequenceClassification.from_pretrained(label_dir)
label_model.eval()  # evaluation mode: the app only runs inference
18
 
19
+ # Label categories
 
 
 
 
 
20
  label_map = {
21
  0: "Landscape & Culture",
22
  1: "Service & Facilities",
 
25
  4: "Interactive Activities",
26
  5: "Price & Consumption"
27
  }
28
+ threshold = 0.5
29
+
30
+ # --------- Inference Function ---------
31
def analyze(text):
    """Run sentiment analysis and multi-label tagging on a single comment.

    Args:
        text: The comment to analyze. ``None``, empty, or whitespace-only
            input is rejected with a prompt message and never reaches the
            models.

    Returns:
        A ``(sentiment_result, label_result)`` tuple of display strings.
        ``sentiment_result`` holds the predicted polarity and the scores;
        ``label_result`` lists every tag whose sigmoid probability is at
        least the module-level ``threshold``, or a fallback message when
        no tag qualifies.
    """
    # `not text` also covers None, on which .strip() would raise.
    if not text or not text.strip():
        prompt = "Please enter a valid comment."
        return prompt, prompt

    # --- Sentiment Analysis ---
    sent_inputs = sentiment_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():
        sent_logits = sentiment_model(**sent_inputs).logits
    # One text in, so one batch row out; take it explicitly.
    sent_probs = torch.softmax(sent_logits, dim=1)[0].tolist()

    pos_prob, neg_prob = sent_probs[1], sent_probs[0]
    # The checkpoint is binary; "Neutral" is inferred when the two class
    # probabilities are closer than 0.2 to each other.
    margin = abs(pos_prob - neg_prob)
    if margin < 0.2:
        sentiment_label = "Neutral"
    elif pos_prob > neg_prob:
        sentiment_label = "Positive"
    else:
        sentiment_label = "Negative"

    # The "Neutral" figure is a closeness heuristic, not a model output, so
    # the three numbers are not expected to sum to 1.
    sentiment_result = (
        f"Prediction: {sentiment_label}\n\n"
        f"Sentiment Scores:\n"
        f"Positive: {pos_prob:.2f}\n"
        f"Neutral: {1 - margin:.2f} (estimated)\n"
        f"Negative: {neg_prob:.2f}"
    )

    # --- Label Prediction ---
    label_inputs = label_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():
        label_logits = label_model(**label_inputs).logits
    # Indexing the batch row always yields a list, even for a one-label
    # model, so no scalar special case is needed.
    label_probs = torch.sigmoid(label_logits)[0].tolist()

    # Single pass over the scores; unknown indices get a generic name
    # instead of raising KeyError if the model has more outputs than
    # label_map has entries.
    tag_lines = []
    for index, prob in enumerate(label_probs):
        if prob >= threshold:
            tag_name = label_map.get(index, f"Label {index}")
            tag_lines.append(f"{tag_name} ({prob:.2f})")

    if tag_lines:
        label_result = "Predicted Tags:\n" + "\n".join(tag_lines)
    else:
        label_result = "No confident labels identified by the model."

    return sentiment_result, label_result
74
+
75
+ # --------- Gradio Web UI ---------
76
# Page layout: comment input and trigger button in the left column, the two
# result boxes in the right column.
demo = gr.Blocks(title="Sentiment + Tag Analysis System")
with demo:
    gr.Markdown("## 🌟 Comment Analyzer")
    gr.Markdown(
        "This tool analyzes **Chinese product reviews** using deep learning models. "
        "It predicts both **sentiment polarity** (Positive / Neutral / Negative) and **semantic category tags** (6 themes)."
    )

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="Enter a JD.com review",
                placeholder="e.g., The park is peaceful and the staff are friendly...",
                lines=4,
            )
            submit_btn = gr.Button("🔍 Analyze")
        with gr.Column():
            sentiment_output = gr.Textbox(label="Sentiment Result", lines=6)
            label_output = gr.Textbox(label="Tag Classification Result", lines=6)

    # analyze() returns (sentiment text, tag text), matching the output order.
    submit_btn.click(
        fn=analyze,
        inputs=input_box,
        outputs=[sentiment_output, label_output],
    )

# --------- Run App ---------
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()