Janish7 commited on
Commit
fa19b8a
·
verified ·
1 Parent(s): 48280a9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +86 -64
src/streamlit_app.py CHANGED
@@ -39,34 +39,52 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
39
  Classify input based on positive and negative prompts using CLIP
40
  """
41
  try:
 
 
 
 
42
  # Prepare text prompts
43
  all_prompts = positive_prompts + negative_prompts
44
  text_inputs = clip.tokenize(all_prompts).to(device)
45
 
46
  if input_type == "image":
47
- # Process image - Fixed handling of uploaded files
48
  if isinstance(input_data, str): # URL
 
49
  response = requests.get(input_data, timeout=10)
50
- response.raise_for_status() # Raise an error for bad status codes
51
  image = Image.open(io.BytesIO(response.content))
52
- elif isinstance(input_data, bytes): # Uploaded file as bytes
53
- # Convert bytes directly to PIL Image
54
  image = Image.open(io.BytesIO(input_data))
55
- else:
56
- # Fallback for other formats
57
- if hasattr(input_data, 'read'):
58
- # It's a file-like object (UploadedFile)
59
- image_bytes = input_data.read()
60
- image = Image.open(io.BytesIO(image_bytes))
61
- # Reset file pointer for potential future reads
62
- input_data.seek(0)
63
- else:
64
- # If it's already a PIL Image or other format
65
- image = input_data
 
 
 
 
 
 
 
 
 
 
66
 
67
- # Convert to RGB if necessary (handles different image modes)
68
  if image.mode != 'RGB':
69
  image = image.convert('RGB')
 
 
 
70
 
71
  image_input = preprocess(image).unsqueeze(0).to(device)
72
 
@@ -81,6 +99,7 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
81
 
82
  elif input_type == "text":
83
  # Process text input
 
84
  input_text = clip.tokenize([input_data]).to(device)
85
 
86
  with torch.no_grad():
@@ -102,6 +121,8 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
102
  is_positive = positive_total > negative_total
103
  confidence = max(positive_total, negative_total)
104
 
 
 
105
  return {
106
  'classification': 'Positive' if is_positive else 'Negative',
107
  'confidence': float(confidence),
@@ -115,10 +136,12 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
115
 
116
  except Exception as e:
117
  st.error(f"Error during classification: {e}")
 
 
118
  return None
119
 
120
  def main():
121
- st.title("🔍 CLIP-Based Custom Classifier")
122
  st.markdown("### Define your own positive and negative prompts to classify images or text!")
123
 
124
  # Load model
@@ -132,15 +155,15 @@ def main():
132
 
133
  # Sidebar for configuration
134
  with st.sidebar:
135
- st.header("⚙️ Configuration")
136
 
137
  # Input type selection
138
  input_type = st.radio("Select input type:", ["Image", "Text"])
139
 
140
- st.header("📝 Define Prompts")
141
 
142
  # Positive prompts
143
- st.subheader("Positive Prompts")
144
  positive_prompts_text = st.text_area(
145
  "Enter positive prompts (one per line):",
146
  value="happy face\nsmiling person\njoyful expression\npositive emotion",
@@ -149,7 +172,7 @@ def main():
149
  )
150
 
151
  # Negative prompts
152
- st.subheader("Negative Prompts")
153
  negative_prompts_text = st.text_area(
154
  "Enter negative prompts (one per line):",
155
  value="sad face\nangry person\nfrowning expression\nnegative emotion",
@@ -168,7 +191,7 @@ def main():
168
  col1, col2 = st.columns([1, 1])
169
 
170
  with col1:
171
- st.header("📥 Input")
172
 
173
  input_data = None
174
 
@@ -179,42 +202,43 @@ def main():
179
  if image_option == "Upload":
180
  uploaded_file = st.file_uploader(
181
  "Choose an image file",
182
- type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'], # Added webp support
183
  help="Upload an image file to classify"
184
  )
185
- if uploaded_file is not None: # More explicit check
 
 
 
 
 
 
 
 
186
  try:
187
- # Read the file data once and store it
188
- file_bytes = uploaded_file.getvalue() # This gets the bytes without moving file pointer
189
- input_data = file_bytes # Store bytes instead of file object
190
-
191
- # Display the uploaded image
192
- st.image(file_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
193
- # Show file details
194
- st.info(f"File size: {len(file_bytes)} bytes")
195
- st.success("✅ Image uploaded successfully!")
196
  except Exception as e:
197
- st.error(f"Error processing uploaded image: {e}")
198
- input_data = None
199
 
200
  else: # URL
201
  image_url = st.text_input("Enter image URL:", placeholder="https://example.com/image.jpg")
202
- if image_url.strip(): # Check for non-empty URL
203
- try:
204
- # Add basic URL validation
205
- if not image_url.startswith(('http://', 'https://')):
206
- st.warning("Please enter a valid URL starting with http:// or https://")
207
- else:
208
  with st.spinner("Loading image..."):
209
  response = requests.get(image_url, timeout=10)
210
  response.raise_for_status()
211
  image = Image.open(io.BytesIO(response.content))
212
  input_data = image_url
213
  st.image(image, caption="Image from URL", use_column_width=True)
214
- except requests.exceptions.RequestException as e:
215
- st.error(f"Error loading image from URL: {e}")
216
- except Exception as e:
217
- st.error(f"Error processing image: {e}")
218
 
219
  else: # Text input
220
  text_input = st.text_area(
@@ -228,16 +252,23 @@ def main():
228
  st.text_area("Text to classify:", value=text_input, height=100, disabled=True)
229
 
230
  with col2:
231
- st.header("📊 Results")
 
 
 
 
 
 
232
 
233
  # Check if we have all required inputs
234
  if not positive_prompts or not negative_prompts:
235
- st.warning("⚠️ Please define both positive and negative prompts in the sidebar.")
236
  elif not input_data:
237
- st.info("📝 Please provide input data to classify.")
238
  else:
239
- if st.button("🚀 Classify", type="primary", use_container_width=True):
240
  with st.spinner("Classifying..."):
 
241
  result = classify_input(
242
  model, preprocess, device, input_data,
243
  positive_prompts, negative_prompts,
@@ -245,6 +276,8 @@ def main():
245
  )
246
 
247
  if result:
 
 
248
  # Main classification result
249
  classification = result['classification']
250
  confidence = result['confidence']
@@ -264,7 +297,7 @@ def main():
264
  st.metric("Negative Score", f"{result['negative_score']:.3f}")
265
 
266
  # Detailed breakdown
267
- st.subheader("📈 Detailed Scores")
268
 
269
  # Positive prompts scores
270
  st.write("**Positive Prompts:**")
@@ -276,10 +309,10 @@ def main():
276
  for prompt, score in result['detailed_scores']['negative_prompts']:
277
  st.progress(float(score), text=f"{prompt}: {score:.3f}")
278
  else:
279
- st.error("Classification failed. Please check your input and try again.")
280
 
281
  # Instructions
282
- with st.expander("ℹ️ How to use this app"):
283
  st.markdown("""
284
  1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
285
  2. **Choose Input Type**: Select whether you want to classify images or text
@@ -297,19 +330,8 @@ def main():
297
  - Make sure uploaded images are in supported formats (PNG, JPG, JPEG, GIF, BMP, WebP)
298
  - For URLs, ensure they start with http:// or https://
299
  - Check that both positive and negative prompts are defined
 
300
  """)
301
-
302
- # Debug information (can be removed in production)
303
- if st.checkbox("Show debug info", help="Check this to see debug information"):
304
- st.subheader("Debug Information")
305
- st.write(f"Device: {device}")
306
- st.write(f"Input type: {input_type}")
307
- st.write(f"Input data: {input_data is not None}")
308
- if input_data and hasattr(input_data, '__len__'):
309
- st.write(f"Input data length: {len(input_data)}")
310
- st.write(f"Input data type: {type(input_data)}")
311
- st.write(f"Positive prompts count: {len(positive_prompts) if positive_prompts else 0}")
312
- st.write(f"Negative prompts count: {len(negative_prompts) if negative_prompts else 0}")
313
 
314
  if __name__ == "__main__":
315
  main()
 
39
  Classify input based on positive and negative prompts using CLIP
40
  """
41
  try:
42
+ # Debug information
43
+ st.write(f"DEBUG: Input data type: {type(input_data)}")
44
+ st.write(f"DEBUG: Input type: {input_type}")
45
+
46
  # Prepare text prompts
47
  all_prompts = positive_prompts + negative_prompts
48
  text_inputs = clip.tokenize(all_prompts).to(device)
49
 
50
  if input_type == "image":
51
+ # Process image
52
  if isinstance(input_data, str): # URL
53
+ st.write("DEBUG: Processing URL image")
54
  response = requests.get(input_data, timeout=10)
55
+ response.raise_for_status()
56
  image = Image.open(io.BytesIO(response.content))
57
+ elif isinstance(input_data, bytes): # Raw bytes
58
+ st.write("DEBUG: Processing bytes image")
59
  image = Image.open(io.BytesIO(input_data))
60
+ else: # UploadedFile object
61
+ st.write("DEBUG: Processing UploadedFile object")
62
+ # Try multiple methods to read the file
63
+ try:
64
+ # Method 1: Use getvalue()
65
+ if hasattr(input_data, 'getvalue'):
66
+ image_bytes = input_data.getvalue()
67
+ image = Image.open(io.BytesIO(image_bytes))
68
+ st.write("DEBUG: Successfully read using getvalue()")
69
+ # Method 2: Use read()
70
+ elif hasattr(input_data, 'read'):
71
+ input_data.seek(0) # Reset to beginning
72
+ image_bytes = input_data.read()
73
+ image = Image.open(io.BytesIO(image_bytes))
74
+ st.write("DEBUG: Successfully read using read()")
75
+ else:
76
+ st.error("DEBUG: Cannot read uploaded file")
77
+ return None
78
+ except Exception as read_error:
79
+ st.error(f"DEBUG: Error reading file: {read_error}")
80
+ return None
81
 
82
+ # Convert to RGB if necessary
83
  if image.mode != 'RGB':
84
  image = image.convert('RGB')
85
+ st.write(f"DEBUG: Converted image from {image.mode} to RGB")
86
+
87
+ st.write(f"DEBUG: Image size: {image.size}")
88
 
89
  image_input = preprocess(image).unsqueeze(0).to(device)
90
 
 
99
 
100
  elif input_type == "text":
101
  # Process text input
102
+ st.write("DEBUG: Processing text input")
103
  input_text = clip.tokenize([input_data]).to(device)
104
 
105
  with torch.no_grad():
 
121
  is_positive = positive_total > negative_total
122
  confidence = max(positive_total, negative_total)
123
 
124
+ st.write("DEBUG: Classification completed successfully")
125
+
126
  return {
127
  'classification': 'Positive' if is_positive else 'Negative',
128
  'confidence': float(confidence),
 
136
 
137
  except Exception as e:
138
  st.error(f"Error during classification: {e}")
139
+ import traceback
140
+ st.error(f"Traceback: {traceback.format_exc()}")
141
  return None
142
 
143
  def main():
144
+ st.title("CLIP-Based Custom Classifier")
145
  st.markdown("### Define your own positive and negative prompts to classify images or text!")
146
 
147
  # Load model
 
155
 
156
  # Sidebar for configuration
157
  with st.sidebar:
158
+ st.header("Configuration")
159
 
160
  # Input type selection
161
  input_type = st.radio("Select input type:", ["Image", "Text"])
162
 
163
+ st.header("Define Prompts")
164
 
165
  # Positive prompts
166
+ st.subheader("Positive Prompts")
167
  positive_prompts_text = st.text_area(
168
  "Enter positive prompts (one per line):",
169
  value="happy face\nsmiling person\njoyful expression\npositive emotion",
 
172
  )
173
 
174
  # Negative prompts
175
+ st.subheader("Negative Prompts")
176
  negative_prompts_text = st.text_area(
177
  "Enter negative prompts (one per line):",
178
  value="sad face\nangry person\nfrowning expression\nnegative emotion",
 
191
  col1, col2 = st.columns([1, 1])
192
 
193
  with col1:
194
+ st.header("Input")
195
 
196
  input_data = None
197
 
 
202
  if image_option == "Upload":
203
  uploaded_file = st.file_uploader(
204
  "Choose an image file",
205
+ type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'],
206
  help="Upload an image file to classify"
207
  )
208
+
209
+ if uploaded_file is not None:
210
+ st.write(f"File name: {uploaded_file.name}")
211
+ st.write(f"File type: {uploaded_file.type}")
212
+ st.write(f"File size: {uploaded_file.size} bytes")
213
+
214
+ # Store the uploaded file directly
215
+ input_data = uploaded_file
216
+
217
  try:
218
+ # Display the uploaded image using the file object
219
+ st.image(uploaded_file, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
220
+ st.success("Image uploaded successfully!")
 
 
 
 
 
 
221
  except Exception as e:
222
+ st.error(f"Error displaying uploaded image: {e}")
223
+ st.write(f"Error details: {str(e)}")
224
 
225
  else: # URL
226
  image_url = st.text_input("Enter image URL:", placeholder="https://example.com/image.jpg")
227
+ if image_url.strip():
228
+ if not image_url.startswith(('http://', 'https://')):
229
+ st.warning("Please enter a valid URL starting with http:// or https://")
230
+ else:
231
+ try:
 
232
  with st.spinner("Loading image..."):
233
  response = requests.get(image_url, timeout=10)
234
  response.raise_for_status()
235
  image = Image.open(io.BytesIO(response.content))
236
  input_data = image_url
237
  st.image(image, caption="Image from URL", use_column_width=True)
238
+ except requests.exceptions.RequestException as e:
239
+ st.error(f"Error loading image from URL: {e}")
240
+ except Exception as e:
241
+ st.error(f"Error processing image: {e}")
242
 
243
  else: # Text input
244
  text_input = st.text_area(
 
252
  st.text_area("Text to classify:", value=text_input, height=100, disabled=True)
253
 
254
  with col2:
255
+ st.header("Results")
256
+
257
+ # Show current status
258
+ st.write("Status Check:")
259
+ st.write(f"- Input data available: {input_data is not None}")
260
+ st.write(f"- Positive prompts: {len(positive_prompts) if positive_prompts else 0}")
261
+ st.write(f"- Negative prompts: {len(negative_prompts) if negative_prompts else 0}")
262
 
263
  # Check if we have all required inputs
264
  if not positive_prompts or not negative_prompts:
265
+ st.warning("Please define both positive and negative prompts in the sidebar.")
266
  elif not input_data:
267
+ st.info("Please provide input data to classify.")
268
  else:
269
+ if st.button("Classify", type="primary", use_container_width=True):
270
  with st.spinner("Classifying..."):
271
+ st.write("Starting classification...")
272
  result = classify_input(
273
  model, preprocess, device, input_data,
274
  positive_prompts, negative_prompts,
 
276
  )
277
 
278
  if result:
279
+ st.write("Classification successful!")
280
+
281
  # Main classification result
282
  classification = result['classification']
283
  confidence = result['confidence']
 
297
  st.metric("Negative Score", f"{result['negative_score']:.3f}")
298
 
299
  # Detailed breakdown
300
+ st.subheader("Detailed Scores")
301
 
302
  # Positive prompts scores
303
  st.write("**Positive Prompts:**")
 
309
  for prompt, score in result['detailed_scores']['negative_prompts']:
310
  st.progress(float(score), text=f"{prompt}: {score:.3f}")
311
  else:
312
+ st.error("Classification failed. Check the debug messages above.")
313
 
314
  # Instructions
315
+ with st.expander("How to use this app"):
316
  st.markdown("""
317
  1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
318
  2. **Choose Input Type**: Select whether you want to classify images or text
 
330
  - Make sure uploaded images are in supported formats (PNG, JPG, JPEG, GIF, BMP, WebP)
331
  - For URLs, ensure they start with http:// or https://
332
  - Check that both positive and negative prompts are defined
333
+ - Look at the debug messages for detailed error information
334
  """)
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
  if __name__ == "__main__":
337
  main()