jmprathab commited on
Commit
49e3fdb
·
verified ·
1 Parent(s): e27a584

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +224 -36
src/streamlit_app.py CHANGED
@@ -1,40 +1,228 @@
1
- import altair as alt
 
2
  import numpy as np
 
3
  import pandas as pd
4
  import streamlit as st
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import kagglehub
3
  import numpy as np
4
+ import os
5
  import pandas as pd
6
  import streamlit as st
7
+ import torch
8
+ import torch.nn as nn
9
 
10
+ MODEL_HANDLE = "prathabmurugan/dlgenai-emotion-classification/pyTorch/1a"
11
+ EMOTION_LABELS = ['anger', 'fear', 'joy', 'sadness', 'surprise']
12
+ THRESHOLDS = np.array([0.85, 0.43, 0.21, 0.7, 0.36])
13
+ MAX_LEN = 100
14
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+
17
+ class RobertaClassifier(nn.Module):
18
+ def __init__(self, model_name: str, num_labels: int, dropout: float = 0.3):
19
+ super().__init__()
20
+ self.roberta = AutoModel.from_pretrained(model_name)
21
+ hidden_size = self.roberta.config.hidden_size
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.classifier = nn.Linear(hidden_size, num_labels)
24
+
25
+ def forward(self, input_ids, attention_mask):
26
+ outputs = self.roberta(
27
+ input_ids=input_ids, attention_mask=attention_mask
28
+ )
29
+ pooled = outputs.pooler_output
30
+ pooled = self.dropout(pooled)
31
+ logits = self.classifier(pooled)
32
+ return logits
33
+
34
+
35
+ def standardize_space(text):
36
+ """Normalize whitespace in text."""
37
+ return " ".join(str(text).split())
38
+
39
+
40
+ @st.cache_resource
41
+ def load_resources():
42
+ status_container = st.empty()
43
+
44
+ # 1. Download Model Weights
45
+ status_container.info(
46
+ f"Downloading model weights from KaggleHub [{MODEL_HANDLE}]")
47
+ try:
48
+ model_dir = kagglehub.model_download(MODEL_HANDLE)
49
+ model_path = os.path.join(model_dir, "roberta_best_model.pth")
50
+
51
+ if not os.path.exists(model_path):
52
+ files = [f for f in os.listdir(model_dir) if f.endswith('.pth')]
53
+ if files:
54
+ model_path = os.path.join(model_dir, files[0])
55
+ else:
56
+ raise FileNotFoundError(
57
+ f"Could not find .pth file in [{model_dir}]")
58
+
59
+ except Exception as e:
60
+ status_container.error(f"Failed to download model [{e}]")
61
+ st.stop()
62
+
63
+ # 2. Initialize Architecture
64
+ status_container.info("Initializing RoBERTa architecture")
65
+ tokenizer = AutoTokenizer.from_pretrained("roberta-base")
66
+ model = RobertaClassifier("roberta-base", num_labels=5)
67
+
68
+ # 3. Load Weights
69
+ try:
70
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
71
+ model.to(DEVICE)
72
+ model.eval()
73
+ except Exception as e:
74
+ status_container.error(f"Error loading state dict [{e}]")
75
+ st.stop()
76
+
77
+ status_container.empty() # Clear the status messages
78
+ return model, tokenizer
79
+
80
+
81
+ def predict(texts, model, tokenizer):
82
+ # Preprocessing
83
+ processed_texts = [standardize_space(t) for t in texts]
84
+
85
+ # Tokenization
86
+ encodings = tokenizer(
87
+ processed_texts,
88
+ truncation=True,
89
+ max_length=MAX_LEN,
90
+ padding='max_length',
91
+ return_tensors='pt'
92
+ )
93
+
94
+ input_ids = encodings['input_ids'].to(DEVICE)
95
+ attention_mask = encodings['attention_mask'].to(DEVICE)
96
+
97
+ # Inference
98
+ with torch.no_grad():
99
+ logits = model(input_ids, attention_mask)
100
+ probs = torch.sigmoid(logits).cpu().numpy()
101
+
102
+ # Apply specific thresholds
103
+ preds = (probs > THRESHOLDS).astype(int)
104
+
105
+ return preds, probs
106
+
107
+
108
+ # Streamlit UI
109
+ st.set_page_config(page_title="Emotion Classifier", layout="centered")
110
+
111
+ st.title("Emotion Classification")
112
+ st.markdown(
113
+ "This app pulls a custom fine-tuned **RoBERTa** model from Kaggle to classify text into 5 emotions.")
114
+
115
+ # Load model
116
+ model, tokenizer = load_resources()
117
+
118
+ # Tabs for different input modes
119
+ tab1, tab2 = st.tabs(["Single Text Inference", "Batch CSV Inference"])
120
+
121
+ with tab1:
122
+ st.header("Test a single sentence")
123
+ user_input = st.text_area(
124
+ "Enter text here:", "Hello World!")
125
+
126
+ if st.button("Analyze Text", type="primary"):
127
+ if user_input.strip():
128
+ with st.spinner("Analyzing..."):
129
+ preds, probs = predict([user_input], model, tokenizer)
130
+
131
+ st.subheader("Results:")
132
+
133
+ # Display nicely
134
+ col1, col2 = st.columns(2)
135
+
136
+ with col1:
137
+ st.write("**Detected Emotions:**")
138
+ detected = []
139
+ for idx, is_present in enumerate(preds[0]):
140
+ if is_present:
141
+ detected.append(EMOTION_LABELS[idx].capitalize())
142
+
143
+ if detected:
144
+ for d in detected:
145
+ st.markdown(f"### ✅ {d}")
146
+ else:
147
+ st.markdown(
148
+ "*No specific emotion detected above thresholds.*")
149
+
150
+ with col2:
151
+ st.write("**Confidence Scores:**")
152
+ scores_df = pd.DataFrame({
153
+ "Emotion": EMOTION_LABELS,
154
+ "Score": probs[0],
155
+ "Threshold": THRESHOLDS,
156
+ "Detected": preds[0].astype(bool)
157
+ })
158
+ # Formatting the dataframe for visual appeal
159
+ st.dataframe(
160
+ scores_df.style.format(
161
+ {"Score": "{:.2%}", "Threshold": "{:.2f}"})
162
+ .background_gradient(subset=["Score"], cmap="Greens"),
163
+ hide_index=True,
164
+ use_container_width=True
165
+ )
166
+ else:
167
+ st.warning("Please enter some text.")
168
+
169
+ with tab2:
170
+ st.header("Batch Process (CSV)")
171
+ st.markdown("Upload a CSV file with a `text` and `id` column.")
172
+
173
+ uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
174
+
175
+ if uploaded_file is not None:
176
+ try:
177
+ input_df = pd.read_csv(uploaded_file)
178
+ if 'text' not in input_df.columns:
179
+ st.error("CSV must have a 'text' column.")
180
+ else:
181
+ st.info(
182
+ f"Loaded [{len(input_df)}] rows. Click below to start.")
183
+
184
+ if st.button("Generate Predictions"):
185
+ progress_bar = st.progress(0)
186
+ status_text = st.empty()
187
+
188
+ # Process in batches
189
+ batch_size = 16
190
+ all_preds = []
191
+ texts = input_df['text'].tolist()
192
+
193
+ for i in range(0, len(texts), batch_size):
194
+ batch_texts = texts[i:i + batch_size]
195
+ batch_preds, _ = predict(batch_texts, model, tokenizer)
196
+ all_preds.append(batch_preds)
197
+
198
+ # Update progress
199
+ progress = min((i + batch_size) / len(texts), 1.0)
200
+ progress_bar.progress(progress)
201
+ status_text.text(
202
+ f"Processed {i + len(batch_texts)}/{len(texts)} rows")
203
+
204
+ # Aggregate results
205
+ predictions_np = np.vstack(all_preds)
206
+ submission_df = pd.DataFrame(
207
+ predictions_np, columns=EMOTION_LABELS, dtype=int)
208
+
209
+ # Combine with original IDs
210
+ if 'id' in input_df.columns:
211
+ final_df = pd.concat(
212
+ [input_df[['id']], submission_df], axis=1)
213
+ else:
214
+ final_df = submission_df
215
+
216
+ st.success("Processing complete!")
217
+ st.dataframe(final_df.head(), use_container_width=True)
218
+
219
+ # Download button
220
+ csv = final_df.to_csv(index=False).encode('utf-8')
221
+ st.download_button(
222
+ label="Download Predictions CSV",
223
+ data=csv,
224
+ file_name="submission.csv",
225
+ mime="text/csv"
226
+ )
227
+ except Exception as e:
228
+ st.error(f"Error reading CSV: {e}")