Janish7 committed on
Commit
0cc8c93
·
verified ·
1 Parent(s): df7d3ef

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +107 -179
src/streamlit_app.py CHANGED
@@ -5,13 +5,10 @@ from PIL import Image
5
  import numpy as np
6
  import io
7
  import requests
 
8
  import os
9
  from typing import List, Tuple
10
 
11
- # Set cache directories to writable locations
12
- os.environ['TORCH_HOME'] = '/tmp/torch_cache'
13
- os.environ['HF_HOME'] = '/tmp/hf_cache'
14
-
15
  # Configure page
16
  st.set_page_config(
17
  page_title="CLIP Classifier",
@@ -23,92 +20,49 @@ st.set_page_config(
23
  def load_clip_model():
24
  """Load CLIP model and preprocessing function"""
25
  try:
26
- # Ensure cache directories exist
27
- os.makedirs('/tmp/torch_cache', exist_ok=True)
28
- os.makedirs('/tmp/clip_models', exist_ok=True)
29
-
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
- model, preprocess = clip.load("ViT-B/32", device=device, download_root="/tmp/clip_models")
32
  return model, preprocess, device
33
  except Exception as e:
34
  st.error(f"Error loading CLIP model: {e}")
35
  return None, None, None
36
 
37
- def classify_input(model, preprocess, device, input_data, positive_prompts, negative_prompts, input_type="image"):
38
  """
39
- Classify input based on positive and negative prompts using CLIP
40
  """
41
  try:
42
- # Debug information
43
- st.write(f"DEBUG: Input data type: {type(input_data)}")
44
- st.write(f"DEBUG: Input type: {input_type}")
45
-
46
  # Prepare text prompts
47
  all_prompts = positive_prompts + negative_prompts
48
  text_inputs = clip.tokenize(all_prompts).to(device)
49
 
50
- if input_type == "image":
51
- # Process image
52
- if isinstance(input_data, str): # URL
53
- st.write("DEBUG: Processing URL image")
54
- response = requests.get(input_data, timeout=10)
55
- response.raise_for_status()
56
- image = Image.open(io.BytesIO(response.content))
57
- elif isinstance(input_data, bytes): # Raw bytes
58
- st.write("DEBUG: Processing bytes image")
59
- image = Image.open(io.BytesIO(input_data))
60
- else: # UploadedFile object
61
- st.write("DEBUG: Processing UploadedFile object")
62
- # Try multiple methods to read the file
63
- try:
64
- # Method 1: Use getvalue()
65
- if hasattr(input_data, 'getvalue'):
66
- image_bytes = input_data.getvalue()
67
- image = Image.open(io.BytesIO(image_bytes))
68
- st.write("DEBUG: Successfully read using getvalue()")
69
- # Method 2: Use read()
70
- elif hasattr(input_data, 'read'):
71
- input_data.seek(0) # Reset to beginning
72
- image_bytes = input_data.read()
73
- image = Image.open(io.BytesIO(image_bytes))
74
- st.write("DEBUG: Successfully read using read()")
75
- else:
76
- st.error("DEBUG: Cannot read uploaded file")
77
- return None
78
- except Exception as read_error:
79
- st.error(f"DEBUG: Error reading file: {read_error}")
80
- return None
81
-
82
- # Convert to RGB if necessary
83
- if image.mode != 'RGB':
84
- image = image.convert('RGB')
85
- st.write(f"DEBUG: Converted image from {image.mode} to RGB")
86
-
87
- st.write(f"DEBUG: Image size: {image.size}")
88
-
89
- image_input = preprocess(image).unsqueeze(0).to(device)
90
-
91
- # Get features
92
- with torch.no_grad():
93
- image_features = model.encode_image(image_input)
94
- text_features = model.encode_text(text_inputs)
95
-
96
- # Calculate similarities
97
- similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)
98
- similarities = similarities[0].cpu().numpy()
99
 
100
- elif input_type == "text":
101
- # Process text input
102
- st.write("DEBUG: Processing text input")
103
- input_text = clip.tokenize([input_data]).to(device)
 
 
 
 
 
 
104
 
105
- with torch.no_grad():
106
- input_features = model.encode_text(input_text)
107
- text_features = model.encode_text(text_inputs)
108
-
109
- # Calculate similarities
110
- similarities = (100.0 * input_features @ text_features.T).softmax(dim=-1)
111
- similarities = similarities[0].cpu().numpy()
112
 
113
  # Calculate scores for positive and negative categories
114
  positive_scores = similarities[:len(positive_prompts)]
@@ -121,8 +75,6 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
121
  is_positive = positive_total > negative_total
122
  confidence = max(positive_total, negative_total)
123
 
124
- st.write("DEBUG: Classification completed successfully")
125
-
126
  return {
127
  'classification': 'Positive' if is_positive else 'Negative',
128
  'confidence': float(confidence),
@@ -136,16 +88,15 @@ def classify_input(model, preprocess, device, input_data, positive_prompts, nega
136
 
137
  except Exception as e:
138
  st.error(f"Error during classification: {e}")
139
- import traceback
140
- st.error(f"Traceback: {traceback.format_exc()}")
141
  return None
142
 
143
  def main():
144
  st.title("CLIP-Based Custom Classifier")
145
- st.markdown("### Define your own positive and negative prompts to classify images or text!")
146
 
147
  # Load model
148
- model, preprocess, device = load_clip_model()
 
149
 
150
  if model is None:
151
  st.error("Failed to load CLIP model. Please check your installation.")
@@ -157,9 +108,6 @@ def main():
157
  with st.sidebar:
158
  st.header("Configuration")
159
 
160
- # Input type selection
161
- input_type = st.radio("Select input type:", ["Image", "Text"])
162
-
163
  st.header("Define Prompts")
164
 
165
  # Positive prompts
@@ -191,93 +139,76 @@ def main():
191
  col1, col2 = st.columns([1, 1])
192
 
193
  with col1:
194
- st.header("Input")
195
 
196
- input_data = None
 
197
 
198
- if input_type == "Image":
199
- # Image input options
200
- image_option = st.radio("Choose image source:", ["Upload", "URL"])
201
-
202
- if image_option == "Upload":
203
- uploaded_file = st.file_uploader(
204
- "Choose an image file",
205
- type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'],
206
- help="Upload an image file to classify"
207
- )
208
-
209
- if uploaded_file is not None:
210
- st.write(f"File name: {uploaded_file.name}")
211
- st.write(f"File type: {uploaded_file.type}")
212
- st.write(f"File size: {uploaded_file.size} bytes")
213
-
214
- # Store the uploaded file directly
215
- input_data = uploaded_file
216
-
217
- try:
218
- # Display the uploaded image using the file object
219
- st.image(uploaded_file, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
220
- st.success("Image uploaded successfully!")
221
- except Exception as e:
222
- st.error(f"Error displaying uploaded image: {e}")
223
- st.write(f"Error details: {str(e)}")
224
 
225
- else: # URL
226
- image_url = st.text_input("Enter image URL:", placeholder="https://example.com/image.jpg")
227
- if image_url.strip():
228
- if not image_url.startswith(('http://', 'https://')):
229
- st.warning("Please enter a valid URL starting with http:// or https://")
230
- else:
231
- try:
232
- with st.spinner("Loading image..."):
233
- response = requests.get(image_url, timeout=10)
234
- response.raise_for_status()
235
- image = Image.open(io.BytesIO(response.content))
236
- input_data = image_url
237
- st.image(image, caption="Image from URL", use_column_width=True)
238
- except requests.exceptions.RequestException as e:
239
- st.error(f"Error loading image from URL: {e}")
240
- except Exception as e:
241
- st.error(f"Error processing image: {e}")
242
 
243
- else: # Text input
244
- text_input = st.text_area(
245
- "Enter text to classify:",
246
- height=150,
247
- placeholder="Type your text here...",
248
- help="Enter the text you want to classify"
249
  )
250
- if text_input.strip():
251
- input_data = text_input.strip()
252
- st.text_area("Text to classify:", value=text_input, height=100, disabled=True)
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  with col2:
255
- st.header("Results")
256
 
257
- # Show current status
258
- st.write("Status Check:")
259
- st.write(f"- Input data available: {input_data is not None}")
260
- st.write(f"- Positive prompts: {len(positive_prompts) if positive_prompts else 0}")
261
- st.write(f"- Negative prompts: {len(negative_prompts) if negative_prompts else 0}")
 
262
 
263
- # Check if we have all required inputs
264
  if not positive_prompts or not negative_prompts:
265
  st.warning("Please define both positive and negative prompts in the sidebar.")
266
- elif not input_data:
267
- st.info("Please provide input data to classify.")
268
  else:
269
- if st.button("Classify", type="primary", use_container_width=True):
 
 
270
  with st.spinner("Classifying..."):
271
- st.write("Starting classification...")
272
  result = classify_input(
273
- model, preprocess, device, input_data,
274
- positive_prompts, negative_prompts,
275
- input_type.lower()
276
  )
277
 
278
  if result:
279
- st.write("Classification successful!")
280
-
281
  # Main classification result
282
  classification = result['classification']
283
  confidence = result['confidence']
@@ -287,10 +218,10 @@ def main():
287
  st.markdown(f"### Classification: <span style='color: {color}'>{classification}</span>",
288
  unsafe_allow_html=True)
289
 
290
- # Confidence and scores
291
- st.metric("Confidence", f"{confidence:.3f}")
292
-
293
- col_pos, col_neg = st.columns(2)
294
  with col_pos:
295
  st.metric("Positive Score", f"{result['positive_score']:.3f}")
296
  with col_neg:
@@ -300,37 +231,34 @@ def main():
300
  st.subheader("Detailed Scores")
301
 
302
  # Positive prompts scores
303
- st.write("**Positive Prompts:**")
304
- for prompt, score in result['detailed_scores']['positive_prompts']:
305
- st.progress(float(score), text=f"{prompt}: {score:.3f}")
306
 
307
  # Negative prompts scores
308
- st.write("**Negative Prompts:**")
309
- for prompt, score in result['detailed_scores']['negative_prompts']:
310
- st.progress(float(score), text=f"{prompt}: {score:.3f}")
311
  else:
312
- st.error("Classification failed. Check the debug messages above.")
313
 
314
  # Instructions
315
  with st.expander("How to use this app"):
316
  st.markdown("""
 
317
  1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
318
- 2. **Choose Input Type**: Select whether you want to classify images or text
319
- 3. **Provide Input**:
320
- - For images: Upload a file or provide a URL
321
- - For text: Type or paste your text
322
- 4. **Classify**: Click the "Classify" button to see results
323
 
324
- **Examples of prompts:**
325
- - **Image classification**: "happy dog, playful pet" vs "aggressive dog, angry animal"
326
- - **Text sentiment**: "positive review, good experience" vs "negative review, bad experience"
327
- - **Content moderation**: "safe content, family friendly" vs "inappropriate content, offensive material"
328
 
329
- **Troubleshooting:**
330
- - Make sure uploaded images are in supported formats (PNG, JPG, JPEG, GIF, BMP, WebP)
331
- - For URLs, ensure they start with http:// or https://
332
- - Check that both positive and negative prompts are defined
333
- - Look at the debug messages for detailed error information
334
  """)
335
 
336
  if __name__ == "__main__":
 
5
  import numpy as np
6
  import io
7
  import requests
8
+ import tempfile
9
  import os
10
  from typing import List, Tuple
11
 
 
 
 
 
12
  # Configure page
13
  st.set_page_config(
14
  page_title="CLIP Classifier",
 
20
  def load_clip_model():
21
  """Load CLIP model and preprocessing function"""
22
  try:
 
 
 
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ model, preprocess = clip.load("ViT-B/32", device=device)
25
  return model, preprocess, device
26
  except Exception as e:
27
  st.error(f"Error loading CLIP model: {e}")
28
  return None, None, None
29
 
30
+ def classify_input(model, preprocess, device, image_data, positive_prompts, negative_prompts):
31
  """
32
+ Classify image based on positive and negative prompts using CLIP
33
  """
34
  try:
 
 
 
 
35
  # Prepare text prompts
36
  all_prompts = positive_prompts + negative_prompts
37
  text_inputs = clip.tokenize(all_prompts).to(device)
38
 
39
+ # Process image
40
+ if isinstance(image_data, str): # URL
41
+ response = requests.get(image_data, timeout=10)
42
+ response.raise_for_status()
43
+ image = Image.open(io.BytesIO(response.content))
44
+ else: # PIL Image or uploaded file
45
+ if hasattr(image_data, 'read'):
46
+ # Handle Streamlit UploadedFile
47
+ image_bytes = image_data.read()
48
+ image = Image.open(io.BytesIO(image_bytes))
49
+ else:
50
+ image = image_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # Convert to RGB if necessary
53
+ if image.mode != 'RGB':
54
+ image = image.convert('RGB')
55
+
56
+ image_input = preprocess(image).unsqueeze(0).to(device)
57
+
58
+ # Get features
59
+ with torch.no_grad():
60
+ image_features = model.encode_image(image_input)
61
+ text_features = model.encode_text(text_inputs)
62
 
63
+ # Calculate similarities
64
+ similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)
65
+ similarities = similarities[0].cpu().numpy()
 
 
 
 
66
 
67
  # Calculate scores for positive and negative categories
68
  positive_scores = similarities[:len(positive_prompts)]
 
75
  is_positive = positive_total > negative_total
76
  confidence = max(positive_total, negative_total)
77
 
 
 
78
  return {
79
  'classification': 'Positive' if is_positive else 'Negative',
80
  'confidence': float(confidence),
 
88
 
89
  except Exception as e:
90
  st.error(f"Error during classification: {e}")
 
 
91
  return None
92
 
93
  def main():
94
  st.title("CLIP-Based Custom Classifier")
95
+ st.markdown("### Define your own positive and negative prompts to classify images!")
96
 
97
  # Load model
98
+ with st.spinner("Loading CLIP model..."):
99
+ model, preprocess, device = load_clip_model()
100
 
101
  if model is None:
102
  st.error("Failed to load CLIP model. Please check your installation.")
 
108
  with st.sidebar:
109
  st.header("Configuration")
110
 
 
 
 
111
  st.header("Define Prompts")
112
 
113
  # Positive prompts
 
139
  col1, col2 = st.columns([1, 1])
140
 
141
  with col1:
142
+ st.header("Input Image")
143
 
144
+ # Tabs for different input methods
145
+ tab1, tab2 = st.tabs(["Upload Image", "Image URL"])
146
 
147
+ image_data = None
148
+
149
+ with tab1:
150
+ # File uploader - simplified for HF Spaces
151
+ uploaded_file = st.file_uploader(
152
+ "Choose an image file",
153
+ type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'],
154
+ help="Upload an image file to classify",
155
+ key="image_uploader" # Add explicit key
156
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ if uploaded_file is not None:
159
+ image_data = uploaded_file
160
+ # Display image
161
+ st.image(uploaded_file, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
162
+ st.success("Image uploaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ with tab2:
165
+ # URL input
166
+ image_url = st.text_input(
167
+ "Enter image URL:",
168
+ placeholder="https://example.com/image.jpg",
169
+ help="Enter a direct link to an image"
170
  )
171
+
172
+ if image_url.strip():
173
+ if not image_url.startswith(('http://', 'https://')):
174
+ st.warning("Please enter a valid URL starting with http:// or https://")
175
+ else:
176
+ try:
177
+ with st.spinner("Loading image..."):
178
+ response = requests.get(image_url, timeout=10)
179
+ response.raise_for_status()
180
+ image = Image.open(io.BytesIO(response.content))
181
+ image_data = image_url
182
+ st.image(image, caption="Image from URL", use_column_width=True)
183
+ st.success("Image loaded successfully!")
184
+ except Exception as e:
185
+ st.error(f"Error loading image: {e}")
186
 
187
  with col2:
188
+ st.header("Classification Results")
189
 
190
+ # Status check
191
+ ready_to_classify = (
192
+ image_data is not None and
193
+ len(positive_prompts) > 0 and
194
+ len(negative_prompts) > 0
195
+ )
196
 
 
197
  if not positive_prompts or not negative_prompts:
198
  st.warning("Please define both positive and negative prompts in the sidebar.")
199
+ elif image_data is None:
200
+ st.info("Please provide an image to classify.")
201
  else:
202
+ st.success("Ready to classify!")
203
+
204
+ if st.button("Classify Image", type="primary", use_container_width=True):
205
  with st.spinner("Classifying..."):
 
206
  result = classify_input(
207
+ model, preprocess, device, image_data,
208
+ positive_prompts, negative_prompts
 
209
  )
210
 
211
  if result:
 
 
212
  # Main classification result
213
  classification = result['classification']
214
  confidence = result['confidence']
 
218
  st.markdown(f"### Classification: <span style='color: {color}'>{classification}</span>",
219
  unsafe_allow_html=True)
220
 
221
+ # Metrics
222
+ col_conf, col_pos, col_neg = st.columns(3)
223
+ with col_conf:
224
+ st.metric("Confidence", f"{confidence:.3f}")
225
  with col_pos:
226
  st.metric("Positive Score", f"{result['positive_score']:.3f}")
227
  with col_neg:
 
231
  st.subheader("Detailed Scores")
232
 
233
  # Positive prompts scores
234
+ with st.expander("Positive Prompts Scores", expanded=True):
235
+ for prompt, score in result['detailed_scores']['positive_prompts']:
236
+ st.progress(float(score), text=f"{prompt}: {score:.3f}")
237
 
238
  # Negative prompts scores
239
+ with st.expander("Negative Prompts Scores", expanded=True):
240
+ for prompt, score in result['detailed_scores']['negative_prompts']:
241
+ st.progress(float(score), text=f"{prompt}: {score:.3f}")
242
  else:
243
+ st.error("Classification failed. Please try again.")
244
 
245
  # Instructions
246
  with st.expander("How to use this app"):
247
  st.markdown("""
248
+ **Instructions:**
249
  1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
250
+ 2. **Upload Image**: Use either the file uploader or paste an image URL
251
+ 3. **Classify**: Click the "Classify Image" button to see results
 
 
 
252
 
253
+ **Example prompts:**
254
+ - **Emotion detection**: "happy, smiling, joy" vs "sad, crying, anger"
255
+ - **Object detection**: "dog, puppy, canine" vs "cat, kitten, feline"
256
+ - **Content type**: "food, meal, cooking" vs "vehicle, car, transportation"
257
 
258
+ **Tips for Hugging Face Spaces:**
259
+ - Use common image formats (JPG, PNG, WebP)
260
+ - For URLs, make sure they're publicly accessible
261
+ - Keep image sizes reasonable for faster processing
 
262
  """)
263
 
264
  if __name__ == "__main__":