Janish7 commited on
Commit
838736c
Β·
verified Β·
1 Parent(s): bdbfd2c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +77 -31
src/streamlit_app.py CHANGED
@@ -15,6 +15,7 @@ os.environ['HF_HOME'] = '/tmp/hf_cache'
15
  # Configure page
16
  st.set_page_config(
17
  page_title="CLIP Classifier",
 
18
  layout="wide"
19
  )
20
 
@@ -43,12 +44,26 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
43
  text_inputs = clip.tokenize(all_prompts).to(device)
44
 
45
  if input_type == "image":
46
- # Process image
47
  if isinstance(input_data, str): # URL
48
- response = requests.get(input_data)
 
49
  image = Image.open(io.BytesIO(response.content))
50
- else: # Uploaded file
51
- image = Image.open(input_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  image_input = preprocess(image).unsqueeze(0).to(device)
54
 
@@ -114,15 +129,15 @@ def main():
114
 
115
  # Sidebar for configuration
116
  with st.sidebar:
117
- st.header(" Configuration")
118
 
119
  # Input type selection
120
  input_type = st.radio("Select input type:", ["Image", "Text"])
121
 
122
- st.header(" Define Prompts")
123
 
124
  # Positive prompts
125
- st.subheader("Positive Prompts")
126
  positive_prompts_text = st.text_area(
127
  "Enter positive prompts (one per line):",
128
  value="happy face\nsmiling person\njoyful expression\npositive emotion",
@@ -131,7 +146,7 @@ def main():
131
  )
132
 
133
  # Negative prompts
134
- st.subheader("Negative Prompts")
135
  negative_prompts_text = st.text_area(
136
  "Enter negative prompts (one per line):",
137
  value="sad face\nangry person\nfrowning expression\nnegative emotion",
@@ -150,7 +165,7 @@ def main():
150
  col1, col2 = st.columns([1, 1])
151
 
152
  with col1:
153
- st.header(" Input")
154
 
155
  input_data = None
156
 
@@ -161,38 +176,59 @@ def main():
161
  if image_option == "Upload":
162
  uploaded_file = st.file_uploader(
163
  "Choose an image file",
164
- type=['png', 'jpg', 'jpeg', 'gif', 'bmp']
 
165
  )
166
- if uploaded_file:
167
  input_data = uploaded_file
168
- st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
 
 
 
 
 
 
169
 
170
  else: # URL
171
- image_url = st.text_input("Enter image URL:")
172
- if image_url:
173
  try:
174
- response = requests.get(image_url)
175
- image = Image.open(io.BytesIO(response.content))
176
- input_data = image_url
177
- st.image(image, caption="Image from URL", use_column_width=True)
178
- except Exception as e:
 
 
 
 
 
 
179
  st.error(f"Error loading image from URL: {e}")
 
 
180
 
181
  else: # Text input
182
  text_input = st.text_area(
183
  "Enter text to classify:",
184
  height=150,
185
- placeholder="Type your text here..."
 
186
  )
187
  if text_input.strip():
188
  input_data = text_input.strip()
189
  st.text_area("Text to classify:", value=text_input, height=100, disabled=True)
190
 
191
  with col2:
192
- st.header(" Results")
193
 
194
- if input_data and positive_prompts and negative_prompts:
195
- if st.button(" Classify", type="primary", use_container_width=True):
 
 
 
 
 
196
  with st.spinner("Classifying..."):
197
  result = classify_input(
198
  model, preprocess, device, input_data,
@@ -220,7 +256,7 @@ def main():
220
  st.metric("Negative Score", f"{result['negative_score']:.3f}")
221
 
222
  # Detailed breakdown
223
- st.subheader(" Detailed Scores")
224
 
225
  # Positive prompts scores
226
  st.write("**Positive Prompts:**")
@@ -231,15 +267,11 @@ def main():
231
  st.write("**Negative Prompts:**")
232
  for prompt, score in result['detailed_scores']['negative_prompts']:
233
  st.progress(float(score), text=f"{prompt}: {score:.3f}")
234
-
235
- elif not positive_prompts or not negative_prompts:
236
- st.warning(" Please define both positive and negative prompts in the sidebar.")
237
-
238
- elif not input_data:
239
- st.info(" Please provide input data to classify.")
240
 
241
  # Instructions
242
- with st.expander(" How to use this app"):
243
  st.markdown("""
244
  1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
245
  2. **Choose Input Type**: Select whether you want to classify images or text
@@ -252,7 +284,21 @@ def main():
252
  - **Image classification**: "happy dog, playful pet" vs "aggressive dog, angry animal"
253
  - **Text sentiment**: "positive review, good experience" vs "negative review, bad experience"
254
  - **Content moderation**: "safe content, family friendly" vs "inappropriate content, offensive material"
 
 
 
 
 
255
  """)
 
 
 
 
 
 
 
 
 
256
 
257
  if __name__ == "__main__":
258
  main()
 
15
  # Configure page
16
  st.set_page_config(
17
  page_title="CLIP Classifier",
18
+ page_icon="πŸ”",
19
  layout="wide"
20
  )
21
 
 
44
  text_inputs = clip.tokenize(all_prompts).to(device)
45
 
46
  if input_type == "image":
47
+ # Process image - Fixed handling of uploaded files
48
  if isinstance(input_data, str): # URL
49
+ response = requests.get(input_data, timeout=10)
50
+ response.raise_for_status() # Raise an error for bad status codes
51
  image = Image.open(io.BytesIO(response.content))
52
+ else: # Uploaded file - this is the key fix
53
+ # For uploaded files, we need to read the bytes and convert to PIL Image
54
+ if hasattr(input_data, 'read'):
55
+ # It's a file-like object (UploadedFile)
56
+ image_bytes = input_data.read()
57
+ image = Image.open(io.BytesIO(image_bytes))
58
+ # Reset file pointer for potential future reads
59
+ input_data.seek(0)
60
+ else:
61
+ # If it's already a PIL Image or other format
62
+ image = input_data
63
+
64
+ # Convert to RGB if necessary (handles different image modes)
65
+ if image.mode != 'RGB':
66
+ image = image.convert('RGB')
67
 
68
  image_input = preprocess(image).unsqueeze(0).to(device)
69
 
 
129
 
130
  # Sidebar for configuration
131
  with st.sidebar:
132
+ st.header("βš™οΈ Configuration")
133
 
134
  # Input type selection
135
  input_type = st.radio("Select input type:", ["Image", "Text"])
136
 
137
+ st.header("πŸ“ Define Prompts")
138
 
139
  # Positive prompts
140
+ st.subheader("βœ… Positive Prompts")
141
  positive_prompts_text = st.text_area(
142
  "Enter positive prompts (one per line):",
143
  value="happy face\nsmiling person\njoyful expression\npositive emotion",
 
146
  )
147
 
148
  # Negative prompts
149
+ st.subheader("❌ Negative Prompts")
150
  negative_prompts_text = st.text_area(
151
  "Enter negative prompts (one per line):",
152
  value="sad face\nangry person\nfrowning expression\nnegative emotion",
 
165
  col1, col2 = st.columns([1, 1])
166
 
167
  with col1:
168
+ st.header("πŸ“₯ Input")
169
 
170
  input_data = None
171
 
 
176
  if image_option == "Upload":
177
  uploaded_file = st.file_uploader(
178
  "Choose an image file",
179
+ type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'], # Added webp support
180
+ help="Upload an image file to classify"
181
  )
182
+ if uploaded_file is not None: # More explicit check
183
  input_data = uploaded_file
184
+ try:
185
+ # Display the uploaded image
186
+ st.image(uploaded_file, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
187
+ # Show file details
188
+ st.info(f"File size: {len(uploaded_file.getvalue())} bytes")
189
+ except Exception as e:
190
+ st.error(f"Error displaying uploaded image: {e}")
191
 
192
  else: # URL
193
+ image_url = st.text_input("Enter image URL:", placeholder="https://example.com/image.jpg")
194
+ if image_url.strip(): # Check for non-empty URL
195
  try:
196
+ # Add basic URL validation
197
+ if not image_url.startswith(('http://', 'https://')):
198
+ st.warning("Please enter a valid URL starting with http:// or https://")
199
+ else:
200
+ with st.spinner("Loading image..."):
201
+ response = requests.get(image_url, timeout=10)
202
+ response.raise_for_status()
203
+ image = Image.open(io.BytesIO(response.content))
204
+ input_data = image_url
205
+ st.image(image, caption="Image from URL", use_column_width=True)
206
+ except requests.exceptions.RequestException as e:
207
  st.error(f"Error loading image from URL: {e}")
208
+ except Exception as e:
209
+ st.error(f"Error processing image: {e}")
210
 
211
  else: # Text input
212
  text_input = st.text_area(
213
  "Enter text to classify:",
214
  height=150,
215
+ placeholder="Type your text here...",
216
+ help="Enter the text you want to classify"
217
  )
218
  if text_input.strip():
219
  input_data = text_input.strip()
220
  st.text_area("Text to classify:", value=text_input, height=100, disabled=True)
221
 
222
  with col2:
223
+ st.header("πŸ“Š Results")
224
 
225
+ # Check if we have all required inputs
226
+ if not positive_prompts or not negative_prompts:
227
+ st.warning("⚠️ Please define both positive and negative prompts in the sidebar.")
228
+ elif not input_data:
229
+ st.info("πŸ“ Please provide input data to classify.")
230
+ else:
231
+ if st.button("πŸš€ Classify", type="primary", use_container_width=True):
232
  with st.spinner("Classifying..."):
233
  result = classify_input(
234
  model, preprocess, device, input_data,
 
256
  st.metric("Negative Score", f"{result['negative_score']:.3f}")
257
 
258
  # Detailed breakdown
259
+ st.subheader("πŸ“ˆ Detailed Scores")
260
 
261
  # Positive prompts scores
262
  st.write("**Positive Prompts:**")
 
267
  st.write("**Negative Prompts:**")
268
  for prompt, score in result['detailed_scores']['negative_prompts']:
269
  st.progress(float(score), text=f"{prompt}: {score:.3f}")
270
+ else:
271
+ st.error("Classification failed. Please check your input and try again.")
 
 
 
 
272
 
273
  # Instructions
274
+ with st.expander("ℹ️ How to use this app"):
275
  st.markdown("""
276
  1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
277
  2. **Choose Input Type**: Select whether you want to classify images or text
 
284
  - **Image classification**: "happy dog, playful pet" vs "aggressive dog, angry animal"
285
  - **Text sentiment**: "positive review, good experience" vs "negative review, bad experience"
286
  - **Content moderation**: "safe content, family friendly" vs "inappropriate content, offensive material"
287
+
288
+ **Troubleshooting:**
289
+ - Make sure uploaded images are in supported formats (PNG, JPG, JPEG, GIF, BMP, WebP)
290
+ - For URLs, ensure they start with http:// or https://
291
+ - Check that both positive and negative prompts are defined
292
  """)
293
+
294
+ # Debug information (can be removed in production)
295
+ if st.checkbox("Show debug info", help="Check this to see debug information"):
296
+ st.subheader("Debug Information")
297
+ st.write(f"Device: {device}")
298
+ st.write(f"Input type: {input_type}")
299
+ st.write(f"Input data type: {type(input_data)}")
300
+ st.write(f"Positive prompts count: {len(positive_prompts) if positive_prompts else 0}")
301
+ st.write(f"Negative prompts count: {len(negative_prompts) if negative_prompts else 0}")
302
 
303
  if __name__ == "__main__":
304
  main()