VenujaDeSilva committed on
Commit
87f549a
·
verified ·
1 Parent(s): a0bfe60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -20
app.py CHANGED
@@ -3,10 +3,54 @@ import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import joblib
5
  import numpy as np
 
 
6
 
7
- # -----------------------
8
- # Load model + tokenizer
9
- # -----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  @st.cache_resource
11
  def load_model():
12
  model = AutoModelForSequenceClassification.from_pretrained(".")
@@ -14,19 +58,16 @@ def load_model():
14
  return model, tokenizer
15
 
16
  model, tokenizer = load_model()
17
-
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model = model.to(device)
20
 
21
- # -----------------------
22
  # Load MultiLabelBinarizer
23
- # -----------------------
24
  mlb = joblib.load("mlb.joblib")
25
  labels = mlb.classes_
26
 
27
- # -----------------------
28
  # Prediction function
29
- # -----------------------
30
  def predict_tags(text, threshold=0.3):
31
  encoded = tokenizer(
32
  text,
@@ -46,24 +87,89 @@ def predict_tags(text, threshold=0.3):
46
 
47
  return predicted_tags, probs
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # -----------------------
51
- # Streamlit UI
52
- # -----------------------
53
- st.title("🔮 StackOverflow Tag Predictor (BERT)")
54
 
55
- user_text = st.text_area("Enter a StackOverflow question:", height=150)
56
- threshold = st.slider("Prediction threshold", 0.0, 1.0, 0.30)
57
 
58
- if st.button("Predict Tags"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if not user_text.strip():
60
- st.warning("Please enter a question")
61
  else:
62
- predicted_tags, probs = predict_tags(user_text, threshold)
 
63
 
 
 
64
  if len(predicted_tags) == 0:
65
- st.error("No tags predicted (try lowering the threshold)")
66
  else:
67
- st.subheader("🏷️ Predicted Tags:")
68
  for t in predicted_tags:
69
- st.success(t)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import joblib
5
  import numpy as np
6
+ import pandas as pd
7
+ import altair as alt
8
 
9
+ # ---------------------------------------------------------
10
+ # Custom CSS for Fun, Colorful UI
11
+ # ---------------------------------------------------------
12
+ st.markdown("""
13
+ <style>
14
+ /* Animated gradient title */
15
+ .title-gradient {
16
+ font-size: 40px;
17
+ font-weight: 900;
18
+ text-align: center;
19
+ background: linear-gradient(90deg, #ff0080, #ff8c00, #40e0d0, #8a2be2);
20
+ -webkit-background-clip: text;
21
+ -webkit-text-fill-color: transparent;
22
+ animation: glow 4s ease-in-out infinite;
23
+ }
24
+
25
+ @keyframes glow {
26
+ 0% { filter: drop-shadow(0 0 2px #ff0080); }
27
+ 50% { filter: drop-shadow(0 0 8px #40e0d0); }
28
+ 100% { filter: drop-shadow(0 0 2px #ff0080); }
29
+ }
30
+
31
+ /* Tag pill styling */
32
+ .tag-pill {
33
+ display: inline-block;
34
+ padding: 8px 14px;
35
+ margin: 4px;
36
+ background-color: #ff6ec7;
37
+ color: white;
38
+ border-radius: 20px;
39
+ font-size: 14px;
40
+ font-weight: 600;
41
+ }
42
+
43
+ /* Centered subtle text */
44
+ .center {
45
+ text-align: center;
46
+ color: #666;
47
+ }
48
+ </style>
49
+ """, unsafe_allow_html=True)
50
+
51
+ # ---------------------------------------------------------
52
+ # Load Model + Tokenizer
53
+ # ---------------------------------------------------------
54
  @st.cache_resource
55
  def load_model():
56
  model = AutoModelForSequenceClassification.from_pretrained(".")
 
58
  return model, tokenizer
59
 
60
  model, tokenizer = load_model()
 
61
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
  model = model.to(device)
63
 
 
64
  # Load MultiLabelBinarizer
 
65
  mlb = joblib.load("mlb.joblib")
66
  labels = mlb.classes_
67
 
68
+ # ---------------------------------------------------------
69
  # Prediction function
70
+ # ---------------------------------------------------------
71
  def predict_tags(text, threshold=0.3):
72
  encoded = tokenizer(
73
  text,
 
87
 
88
  return predicted_tags, probs
89
 
90
+ # ---------------------------------------------------------
91
+ # 🎨 Sidebar
92
+ # ---------------------------------------------------------
93
+ st.sidebar.header("⚙️ Settings")
94
+
95
+ threshold = st.sidebar.slider(
96
+ "Prediction Threshold",
97
+ 0.0, 1.0, 0.30,
98
+ help="Lower = more tags, Higher = fewer but more confident"
99
+ )
100
+
101
+ st.sidebar.markdown("""
102
+ ### 🤖 Model Info
103
+ - BERT-based tag predictor
104
+ - Multi-label classification
105
+ - Trained on StackOverflow dataset
106
+ """)
107
+
108
+ st.sidebar.markdown("---")
109
+ st.sidebar.markdown("Made with ❤️ using Streamlit + Transformers")
110
 
111
+ # ---------------------------------------------------------
112
+ # 🎉 Title + Description
113
+ # ---------------------------------------------------------
 
114
 
115
+ st.markdown("<h1 class='title-gradient'>✨ StackOverflow Tag Predictor ✨</h1>", unsafe_allow_html=True)
116
+ st.markdown("<p class='center'>Ask any technical question and watch the magic happen! 🪄</p>", unsafe_allow_html=True)
117
 
118
+ # ---------------------------------------------------------
119
+ # Example Questions
120
+ # ---------------------------------------------------------
121
+ st.markdown("### 🎯 Try an example:")
122
+ examples = [
123
+ "How do I fix a TypeError in Python when concatenating lists?",
124
+ "What is the recommended way to deploy a React application?",
125
+ "Why does my SQL JOIN return duplicate rows?"
126
+ ]
127
+
128
+ cols = st.columns(len(examples))
129
+ for i, ex in enumerate(examples):
130
+ if cols[i].button(f"Example {i+1}"):
131
+ st.session_state["example_text"] = ex
132
+
133
+ user_text = st.text_area(
134
+ "✍️ Enter your StackOverflow question:",
135
+ value=st.session_state.get("example_text", ""),
136
+ height=150
137
+ )
138
+
139
+ # ---------------------------------------------------------
140
+ # Predict Button
141
+ # ---------------------------------------------------------
142
+ if st.button("🔮 Predict Tags!"):
143
  if not user_text.strip():
144
+ st.warning("Please enter a question first ✏️")
145
  else:
146
+ with st.spinner("✨ Analyzing your question… summoning the tag spirits… 🔮"):
147
+ predicted_tags, probs = predict_tags(user_text, threshold)
148
 
149
+ # Display tags
150
+ st.markdown("## 🏷️ Predicted Tags:")
151
  if len(predicted_tags) == 0:
152
+ st.error("😕 No tags predicted try lowering the threshold!")
153
  else:
 
154
  for t in predicted_tags:
155
+ st.markdown(f"<span class='tag-pill'>#{t}</span>", unsafe_allow_html=True)
156
+
157
+ # Probability Chart
158
+ st.markdown("### 📊 Tag Probability Chart")
159
+
160
+ df = pd.DataFrame({
161
+ "Tag": labels,
162
+ "Probability": probs
163
+ })
164
+
165
+ chart = alt.Chart(df).mark_bar(color="#ff6ec7").encode(
166
+ x="Probability:Q",
167
+ y=alt.Y("Tag:N", sort="-x")
168
+ ).properties(height=350)
169
+
170
+ st.altair_chart(chart, use_container_width=True)
171
+
172
+ # ---------------------------------------------------------
173
+ # Footer
174
+ # ---------------------------------------------------------
175
+ st.markdown("<p class='center'>✨ Powered by BERT • Hugging Face • Streamlit</p>", unsafe_allow_html=True)