ChatBotsTA commited on
Commit
5bf556d
·
verified ·
1 Parent(s): e8c50a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+
5
+ # Set page config
6
+ st.set_page_config(
7
+ page_title="Tweet Classifier",
8
+ page_icon="🐦",
9
+ layout="wide"
10
+ )
11
+
12
+ # Custom CSS for better styling
13
+ st.markdown("""
14
+ <style>
15
+ .main-header {
16
+ font-size: 3rem;
17
+ color: #1DA1F2;
18
+ text-align: center;
19
+ margin-bottom: 2rem;
20
+ }
21
+ .result-box {
22
+ background-color: #f0f2f6;
23
+ padding: 2rem;
24
+ border-radius: 10px;
25
+ margin-top: 2rem;
26
+ }
27
+ .confidence-bar {
28
+ height: 20px;
29
+ background: linear-gradient(90deg, #ff4b4b 0%, #ffa500 50%, #00cc00 100%);
30
+ border-radius: 10px;
31
+ margin: 10px 0;
32
+ }
33
+ .label-badge {
34
+ padding: 0.5rem 1rem;
35
+ border-radius: 20px;
36
+ font-weight: bold;
37
+ margin: 0.2rem;
38
+ }
39
+ </style>
40
+ """, unsafe_allow_html=True)
41
+
42
+ # App title
43
+ st.markdown('<h1 class="main-header">🐦 Tweet Sentiment Classifier</h1>', unsafe_allow_html=True)
44
+ st.markdown("### Powered by your fine-tuned DistilBERT model (96.4% accuracy)")
45
+
46
+ # Initialize model (with caching)
47
+ @st.cache_resource
48
+ def load_model():
49
+ try:
50
+ # Load your fine-tuned model
51
+ model_name = "ChatBotsTA/distilbert-tweet-classifier"
52
+ classifier = pipeline(
53
+ "text-classification",
54
+ model=model_name,
55
+ tokenizer=model_name,
56
+ device=0 if torch.cuda.is_available() else -1
57
+ )
58
+ return classifier
59
+ except Exception as e:
60
+ st.error(f"Error loading model: {e}")
61
+ return None
62
+
63
+ # Load model
64
+ with st.spinner("πŸš€ Loading your fine-tuned model from Hugging Face..."):
65
+ classifier = load_model()
66
+
67
+ if classifier is None:
68
+ st.error("Could not load the model. Please check if the model exists on Hugging Face.")
69
+ st.stop()
70
+
71
+ # Label colors
72
+ label_colors = {
73
+ "positive": "🟒",
74
+ "negative": "πŸ”΄",
75
+ "litigious": "πŸ”΅",
76
+ "uncertainty": "🟑"
77
+ }
78
+
79
+ label_descriptions = {
80
+ "positive": "Positive sentiment/content",
81
+ "negative": "Negative sentiment",
82
+ "litigious": "Legal/contractual content",
83
+ "uncertainty": "Uncertain/ambiguous content"
84
+ }
85
+
86
+ # Input section
87
+ st.markdown("---")
88
+ st.markdown("## πŸ“ Enter Tweet Text to Analyze")
89
+
90
+ input_text = st.text_area(
91
+ "Paste tweet text here:",
92
+ height=150,
93
+ placeholder="Enter text to classify (e.g., 'This product is amazing!', 'I hate this service', 'The court case was dismissed')"
94
+ )
95
+
96
+ # Examples
97
+ with st.expander("πŸ’‘ Click for example texts"):
98
+ st.write("**Examples to try:**")
99
+ examples = [
100
+ "This is an amazing product! I love it!",
101
+ "I'm so frustrated with this service, terrible experience",
102
+ "The court case was dismissed due to lack of evidence",
103
+ "I'm not sure how I feel about this situation",
104
+ "This company's financial results exceeded all expectations",
105
+ "I might consider this option, but I need more information"
106
+ ]
107
+ for example in examples:
108
+ if st.button(example, key=example):
109
+ input_text = example
110
+
111
+ # Analyze button
112
+ if st.button("πŸ” Analyze Tweet", type="primary", use_container_width=True):
113
+ if input_text.strip():
114
+ with st.spinner("Analyzing..."):
115
+ try:
116
+ # Get prediction
117
+ result = classifier(input_text)[0]
118
+ label = result['label']
119
+ confidence = result['score']
120
+
121
+ # Display results
122
+ st.markdown("---")
123
+ st.markdown("## πŸ“Š Analysis Results")
124
+
125
+ # Result box
126
+ st.markdown(f'<div class="result-box">', unsafe_allow_html=True)
127
+
128
+ # Label and confidence
129
+ col1, col2 = st.columns([1, 2])
130
+ with col1:
131
+ st.markdown(f"### {label_colors.get(label, 'βšͺ')} **Prediction:**")
132
+ st.markdown(f'<span class="label-badge" style="background-color: {{
133
+ "positive": "#4CAF50",
134
+ "negative": "#F44336",
135
+ "litigious": "#2196F3",
136
+ "uncertainty": "#FFC107"
137
+ }.get(label, "#9E9E9E")}}; color: white;">{label.upper()}</span>', unsafe_allow_html=True)
138
+
139
+ with col2:
140
+ st.markdown(f"### πŸ“ˆ **Confidence:** {confidence:.1%}")
141
+ # Confidence bar
142
+ st.markdown(f'<div class="confidence-bar" style="width: {confidence*100}%;"></div>', unsafe_allow_html=True)
143
+
144
+ # Description
145
+ st.markdown(f"**Description:** {label_descriptions.get(label, '')}")
146
+
147
+ st.markdown('</div>', unsafe_allow_html=True)
148
+
149
+ # Raw scores (optional)
150
+ with st.expander("πŸ” View detailed scores"):
151
+ # Get all label scores
152
+ tokenizer = AutoTokenizer.from_pretrained("ChatBotsTA/distilbert-tweet-classifier")
153
+ model = AutoModelForSequenceClassification.from_pretrained("ChatBotsTA/distilbert-tweet-classifier")
154
+
155
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
156
+ with torch.no_grad():
157
+ outputs = model(**inputs)
158
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
159
+
160
+ scores = {model.config.id2label[i]: float(prob)
161
+ for i, prob in enumerate(probabilities[0])}
162
+
163
+ for lbl, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
164
+ st.write(f"{label_colors.get(lbl, 'βšͺ')} {lbl}: {score:.3f}")
165
+
166
+ except Exception as e:
167
+ st.error(f"Error during prediction: {e}")
168
+ else:
169
+ st.warning("Please enter some text to analyze!")
170
+
171
+ # Model info section
172
+ st.markdown("---")
173
+ st.markdown("## ℹ️ About This Model")
174
+
175
+ st.info("""
176
+ **Model Details:**
177
+ - **Base Model**: DistilBERT-base-uncased
178
+ - **Training**: Fine-tuned on 50,000 tweets
179
+ - **Accuracy**: 96.4% on validation set
180
+ - **Labels**: Positive, Negative, Litigious, Uncertainty
181
+ - **Created By**: You! 🎯
182
+
183
+ **How to use programmatically:**
184
+ ```python
185
+ from transformers import pipeline
186
+ classifier = pipeline("text-classification",
187
+ model="ChatBotsTA/distilbert-tweet-classifier")
188
+ result = classifier("Your text here")