Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import clip | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
| import requests | |
| import tempfile | |
| import os | |
| from typing import List, Tuple | |
# Page-level Streamlit settings. Applied at import time, ahead of every other
# Streamlit call made by this script.
_PAGE_SETTINGS = {
    "page_title": "CLIP Classifier",
    "page_icon": "🔍",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
@st.cache_resource(show_spinner=False)
def _load_clip_cached(model_name, device):
    """Load a CLIP checkpoint once per process and share it across reruns.

    Exceptions propagate to the caller; Streamlit does not cache a call that
    raised, so a later rerun retries the load.
    """
    return clip.load(model_name, device=device)


def load_clip_model():
    """Load the CLIP ``ViT-B/32`` model and its preprocessing transform.

    The device is ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
    otherwise. The actual load goes through a ``st.cache_resource`` helper so
    that the weights are not reloaded every time the script reruns (``main``
    calls this function on each run).

    Returns:
        ``(model, preprocess, device)`` on success. On any failure the error
        is shown with ``st.error`` and ``(None, None, None)`` is returned.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = _load_clip_cached("ViT-B/32", device)
        return model, preprocess, device
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error loading CLIP model: {e}")
        return None, None, None
def _open_rgb_image(image_data):
    """Turn a URL string, a file-like object, or a PIL image into an RGB PIL image."""
    if isinstance(image_data, str):  # URL
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting allowed hosts if this app is exposed publicly.
        response = requests.get(image_data, timeout=10)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
    elif hasattr(image_data, 'read'):  # e.g. Streamlit UploadedFile
        # Rewind first: an earlier read of the same object would otherwise
        # leave the cursor at EOF and yield zero bytes here.
        if hasattr(image_data, 'seek'):
            image_data.seek(0)
        image = Image.open(io.BytesIO(image_data.read()))
    else:  # already a PIL image
        image = image_data

    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def classify_input(model, preprocess, device, image_data, positive_prompts, negative_prompts):
    """Classify an image as 'Positive' or 'Negative' from CLIP prompt similarities.

    The image and every prompt are embedded with CLIP, both embeddings are
    L2-normalised, and a softmax over ``100 * cosine similarity`` yields one
    probability per prompt. The probabilities of the positive prompts and of
    the negative prompts are summed; the larger sum decides the class.

    Args:
        model: CLIP model exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable mapping a PIL image to a tensor for ``model``.
        device: Torch device string that the inputs are moved to.
        image_data: Image URL (``str``), file-like object with ``read``, or a
            PIL image.
        positive_prompts: Prompts describing the 'Positive' class.
        negative_prompts: Prompts describing the 'Negative' class.

    Returns:
        A dict with keys ``classification``, ``confidence``,
        ``positive_score``, ``negative_score`` and ``detailed_scores``
        (per-prompt ``(prompt, score)`` pairs), or ``None`` if anything fails;
        in that case the error is shown with ``st.error``.
    """
    try:
        positive_prompts = list(positive_prompts)
        negative_prompts = list(negative_prompts)
        all_prompts = positive_prompts + negative_prompts
        if not all_prompts:
            # A softmax over zero prompts is meaningless; fail with a clear message.
            raise ValueError("At least one prompt is required for classification.")
        text_inputs = clip.tokenize(all_prompts).to(device)

        image = _open_rgb_image(image_data)
        image_input = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_inputs)
            # Normalise so the dot product is a cosine similarity. Without this
            # the 100x-scaled logits depend on embedding magnitude and the
            # softmax collapses onto a single prompt.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            logits = 100.0 * image_features @ text_features.T
            # Softmax in float32 even if the model produced half-precision features.
            similarities = logits.float().softmax(dim=-1)
        similarities = similarities[0].cpu().numpy()

        # Positive prompts come first in all_prompts, negatives after them.
        split_at = len(positive_prompts)
        positive_scores = similarities[:split_at]
        negative_scores = similarities[split_at:]
        positive_total = np.sum(positive_scores)
        negative_total = np.sum(negative_scores)

        is_positive = positive_total > negative_total
        confidence = max(positive_total, negative_total)
        return {
            'classification': 'Positive' if is_positive else 'Negative',
            'confidence': float(confidence),
            'positive_score': float(positive_total),
            'negative_score': float(negative_total),
            'detailed_scores': {
                'positive_prompts': [
                    (prompt, float(score))
                    for prompt, score in zip(positive_prompts, positive_scores)
                ],
                'negative_prompts': [
                    (prompt, float(score))
                    for prompt, score in zip(negative_prompts, negative_scores)
                ],
            },
        }
    except Exception as e:  # UI boundary: surface the problem, let the caller see None
        st.error(f"Error during classification: {e}")
        return None
_USAGE_MARKDOWN = """
**Instructions:**
1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
2. **Upload Image**: Use either the file uploader or paste an image URL
3. **Classify**: Click the "Classify Image" button to see results
**Example prompts:**
- **Emotion detection**: "happy, smiling, joy" vs "sad, crying, anger"
- **Object detection**: "dog, puppy, canine" vs "cat, kitten, feline"
- **Content type**: "food, meal, cooking" vs "vehicle, car, transportation"
**Tips for Hugging Face Spaces:**
- Use common image formats (JPG, PNG, WebP)
- For URLs, make sure they're publicly accessible
- Keep image sizes reasonable for faster processing
"""


def _parse_prompts(text):
    """Split a text-area value into non-empty, whitespace-trimmed prompts."""
    return [line.strip() for line in text.split('\n') if line.strip()]


def _render_prompt_sidebar():
    """Draw the prompt editors in the sidebar and return (positive, negative) prompt lists."""
    with st.sidebar:
        st.header("Configuration")
        st.header("Define Prompts")

        st.subheader("Positive Prompts")
        positive_prompts_text = st.text_area(
            "Enter positive prompts (one per line):",
            value="happy face\nsmiling person\njoyful expression\npositive emotion",
            height=100,
            help="These prompts define what should be classified as 'Positive'",
        )

        st.subheader("Negative Prompts")
        negative_prompts_text = st.text_area(
            "Enter negative prompts (one per line):",
            value="sad face\nangry person\nfrowning expression\nnegative emotion",
            height=100,
            help="These prompts define what should be classified as 'Negative'",
        )

        positive_prompts = _parse_prompts(positive_prompts_text)
        negative_prompts = _parse_prompts(negative_prompts_text)
        st.info(f"Positive prompts: {len(positive_prompts)}")
        st.info(f"Negative prompts: {len(negative_prompts)}")
    return positive_prompts, negative_prompts


def _render_image_input():
    """Draw the upload / URL tabs and return the chosen image, or None.

    Returns the uploaded file object, or the decoded PIL image when a URL was
    loaded successfully. A successfully loaded URL takes precedence over an
    upload because its tab is processed last.
    """
    st.header("Input Image")
    upload_tab, url_tab = st.tabs(["Upload Image", "Image URL"])
    image_data = None

    # NOTE(review): newer Streamlit releases deprecate `use_column_width` on
    # st.image; it is kept here so older installs keep working.
    with upload_tab:
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'],
            help="Upload an image file to classify",
            key="image_uploader",
        )
        if uploaded_file is not None:
            image_data = uploaded_file
            st.image(uploaded_file, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
            st.success("Image uploaded successfully!")

    with url_tab:
        image_url = st.text_input(
            "Enter image URL:",
            placeholder="https://example.com/image.jpg",
            help="Enter a direct link to an image",
        )
        # Validate and fetch the stripped value so stray whitespace around a
        # pasted URL does not trigger the "invalid URL" warning.
        url = image_url.strip()
        if url:
            if not url.startswith(('http://', 'https://')):
                st.warning("Please enter a valid URL starting with http:// or https://")
            else:
                try:
                    with st.spinner("Loading image..."):
                        response = requests.get(url, timeout=10)
                        response.raise_for_status()
                        image = Image.open(io.BytesIO(response.content))
                    # Hand the decoded image to the classifier instead of the
                    # URL, so classification does not download it a second time.
                    image_data = image
                    st.image(image, caption="Image from URL", use_column_width=True)
                    st.success("Image loaded successfully!")
                except Exception as e:  # UI boundary: show the failure, keep the app alive
                    st.error(f"Error loading image: {e}")
    return image_data


def _render_score_bars(title, scored_prompts):
    """Draw one expander with a progress bar per (prompt, score) pair."""
    with st.expander(title, expanded=True):
        for prompt, score in scored_prompts:
            # st.progress rejects floats outside [0.0, 1.0]; clamp against rounding drift.
            bar_value = min(max(float(score), 0.0), 1.0)
            st.progress(bar_value, text=f"{prompt}: {score:.3f}")


def _render_results(model, preprocess, device, image_data, positive_prompts, negative_prompts):
    """Draw the readiness status, the classify button and, once clicked, the scores."""
    st.header("Classification Results")

    if not positive_prompts or not negative_prompts:
        st.warning("Please define both positive and negative prompts in the sidebar.")
        return
    if image_data is None:
        st.info("Please provide an image to classify.")
        return

    st.success("Ready to classify!")
    if not st.button("Classify Image", type="primary", use_container_width=True):
        return

    with st.spinner("Classifying..."):
        result = classify_input(
            model, preprocess, device, image_data,
            positive_prompts, negative_prompts,
        )
    if not result:
        st.error("Classification failed. Please try again.")
        return

    classification = result['classification']
    confidence = result['confidence']
    # `classification` is one of two fixed labels, so inlining it in HTML is safe.
    color = "green" if classification == "Positive" else "red"
    st.markdown(
        f"### Classification: <span style='color: {color}'>{classification}</span>",
        unsafe_allow_html=True,
    )

    col_conf, col_pos, col_neg = st.columns(3)
    with col_conf:
        st.metric("Confidence", f"{confidence:.3f}")
    with col_pos:
        st.metric("Positive Score", f"{result['positive_score']:.3f}")
    with col_neg:
        st.metric("Negative Score", f"{result['negative_score']:.3f}")

    st.subheader("Detailed Scores")
    detailed = result['detailed_scores']
    _render_score_bars("Positive Prompts Scores", detailed['positive_prompts'])
    _render_score_bars("Negative Prompts Scores", detailed['negative_prompts'])


def main():
    """Run the Streamlit app: load CLIP, collect prompts and an image, show the result.

    Stops the script with ``st.stop()`` when the model cannot be loaded.
    """
    st.title("CLIP-Based Custom Classifier")
    st.markdown("### Define your own positive and negative prompts to classify images!")

    with st.spinner("Loading CLIP model..."):
        model, preprocess, device = load_clip_model()
    if model is None:
        st.error("Failed to load CLIP model. Please check your installation.")
        st.stop()
    st.success(f"CLIP model loaded successfully on {device}")

    positive_prompts, negative_prompts = _render_prompt_sidebar()

    col1, col2 = st.columns([1, 1])
    with col1:
        image_data = _render_image_input()
    with col2:
        _render_results(
            model, preprocess, device, image_data,
            positive_prompts, negative_prompts,
        )

    with st.expander("How to use this app"):
        st.markdown(_USAGE_MARKDOWN)
# Script entry point: build the UI only when this file is executed as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()