"""Streamlit app that labels an image Positive/Negative using CLIP text prompts.

The user supplies two lists of free-text prompts (positive and negative) and an
image (upload or URL). CLIP scores the image against every prompt; the summed
probability mass of each list decides the label.
"""

import io
import os
import tempfile
from typing import List, Optional, Tuple

import clip
import numpy as np
import requests
import streamlit as st
import torch
from PIL import Image

# Configure page. Streamlit requires this to be the first Streamlit command.
st.set_page_config(
    page_title="CLIP Classifier",
    page_icon="🔍",
    layout="wide"
)

# Timeout (seconds) for downloading an image from a user-supplied URL.
REQUEST_TIMEOUT_SECONDS = 10


@st.cache_resource
def load_clip_model():
    """Load the CLIP ViT-B/32 model once per server process.

    Returns:
        A ``(model, preprocess, device)`` tuple, or ``(None, None, None)`` when
        loading fails (the error is also shown in the UI via ``st.error``).
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device)
        return model, preprocess, device
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error loading CLIP model: {e}")
        return None, None, None


@st.cache_data(show_spinner=False, max_entries=16)
def _fetch_image_bytes(url: str) -> bytes:
    """Download *url* and return the raw response body.

    Cached so that Streamlit reruns (every widget interaction) and the
    subsequent classification do not download the same image again.

    Raises:
        requests.RequestException: on connection errors, timeouts or a
            non-2xx HTTP status.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
    response.raise_for_status()
    return response.content


def _open_image(image_data):
    """Turn a URL string, a file-like object or a PIL image into an RGB PIL image."""
    if isinstance(image_data, str):
        # URL
        image = Image.open(io.BytesIO(_fetch_image_bytes(image_data)))
    elif hasattr(image_data, 'read'):
        # File-like object (e.g. Streamlit UploadedFile). Rewind first: an
        # earlier consumer may have left the read cursor at end-of-file, in
        # which case read() would return b"".
        if hasattr(image_data, 'seek'):
            image_data.seek(0)
        image = Image.open(io.BytesIO(image_data.read()))
    else:
        # Already a PIL image
        image = image_data

    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def classify_input(model, preprocess, device, image_data, positive_prompts, negative_prompts):
    """Classify an image as Positive or Negative from two prompt lists using CLIP.

    Args:
        model: CLIP model as returned by ``clip.load``.
        preprocess: Image preprocessing callable as returned by ``clip.load``.
        device: Torch device string the model lives on.
        image_data: Image URL (``str``), a file-like object with ``read()``,
            or a PIL image.
        positive_prompts: List of prompts that count towards "Positive".
        negative_prompts: List of prompts that count towards "Negative".

    Returns:
        A dict with keys ``classification``, ``confidence``,
        ``positive_score``, ``negative_score`` and ``detailed_scores``
        (per-prompt ``(prompt, score)`` pairs), or ``None`` if anything fails
        (the error is shown in the UI via ``st.error``).
    """
    try:
        # Prepare text prompts
        all_prompts = positive_prompts + negative_prompts
        if not all_prompts:
            raise ValueError("At least one prompt is required.")
        text_inputs = clip.tokenize(all_prompts).to(device)

        # Process image
        image = _open_image(image_data)
        image_input = preprocess(image).unsqueeze(0).to(device)

        # Get features
        with torch.no_grad():
            image_features = model.encode_image(image_input)
            text_features = model.encode_text(text_inputs)

            # L2-normalise so the dot product below is a cosine similarity.
            # Without this the 100x-scaled logits are dominated by embedding
            # magnitude and the softmax collapses to (almost) one-hot.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Softmax in float32: the model runs in half precision on CUDA.
            logits = 100.0 * image_features @ text_features.T
            similarities = logits.float().softmax(dim=-1)

        similarities = similarities[0].cpu().numpy()

        # Calculate scores for positive and negative categories
        split = len(positive_prompts)
        positive_scores = similarities[:split]
        negative_scores = similarities[split:]
        positive_total = np.sum(positive_scores)
        negative_total = np.sum(negative_scores)

        # Determine classification (ties resolve to Negative)
        is_positive = positive_total > negative_total
        confidence = max(positive_total, negative_total)

        return {
            'classification': 'Positive' if is_positive else 'Negative',
            'confidence': float(confidence),
            'positive_score': float(positive_total),
            'negative_score': float(negative_total),
            'detailed_scores': {
                'positive_prompts': [
                    (prompt, float(score))
                    for prompt, score in zip(positive_prompts, positive_scores)
                ],
                'negative_prompts': [
                    (prompt, float(score))
                    for prompt, score in zip(negative_prompts, negative_scores)
                ]
            }
        }
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error during classification: {e}")
        return None


def _parse_prompts(text: str) -> List[str]:
    """Split a text-area value into non-empty, stripped lines."""
    return [line.strip() for line in text.split('\n') if line.strip()]


def _render_prompt_scores(title: str, scored_prompts: List[Tuple[str, float]]) -> None:
    """Render one expander with a progress bar per ``(prompt, score)`` pair."""
    with st.expander(title, expanded=True):
        for prompt, score in scored_prompts:
            # st.progress rejects values outside [0.0, 1.0]; clamp to be safe
            # against floating-point rounding.
            bar_value = min(max(float(score), 0.0), 1.0)
            st.progress(bar_value, text=f"{prompt}: {score:.3f}")


def main():
    """Build the Streamlit page: prompt sidebar, image input and results."""
    st.title("CLIP-Based Custom Classifier")
    st.markdown("### Define your own positive and negative prompts to classify images!")

    # Load model
    with st.spinner("Loading CLIP model..."):
        model, preprocess, device = load_clip_model()

    if model is None:
        st.error("Failed to load CLIP model. Please check your installation.")
        st.stop()

    st.success(f"CLIP model loaded successfully on {device}")

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")
        st.header("Define Prompts")

        # Positive prompts
        st.subheader("Positive Prompts")
        positive_prompts_text = st.text_area(
            "Enter positive prompts (one per line):",
            value="happy face\nsmiling person\njoyful expression\npositive emotion",
            height=100,
            help="These prompts define what should be classified as 'Positive'"
        )

        # Negative prompts
        st.subheader("Negative Prompts")
        negative_prompts_text = st.text_area(
            "Enter negative prompts (one per line):",
            value="sad face\nangry person\nfrowning expression\nnegative emotion",
            height=100,
            help="These prompts define what should be classified as 'Negative'"
        )

        # Process prompts
        positive_prompts = _parse_prompts(positive_prompts_text)
        negative_prompts = _parse_prompts(negative_prompts_text)

        st.info(f"Positive prompts: {len(positive_prompts)}")
        st.info(f"Negative prompts: {len(negative_prompts)}")

    # Main content area
    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("Input Image")

        # Tabs for different input methods
        tab1, tab2 = st.tabs(["Upload Image", "Image URL"])

        image_data = None

        with tab1:
            # File uploader - simplified for HF Spaces
            uploaded_file = st.file_uploader(
                "Choose an image file",
                type=['png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'],
                help="Upload an image file to classify",
                key="image_uploader"  # Add explicit key
            )

            if uploaded_file is not None:
                image_data = uploaded_file
                # Display image
                st.image(uploaded_file, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
                st.success("Image uploaded successfully!")

        with tab2:
            # URL input; strip once so stray whitespace does not fail the
            # scheme check below.
            image_url = st.text_input(
                "Enter image URL:",
                placeholder="https://example.com/image.jpg",
                help="Enter a direct link to an image"
            ).strip()

            if image_url:
                if not image_url.startswith(('http://', 'https://')):
                    st.warning("Please enter a valid URL starting with http:// or https://")
                else:
                    try:
                        with st.spinner("Loading image..."):
                            image = Image.open(io.BytesIO(_fetch_image_bytes(image_url)))
                        # Hand the decoded image (not the URL) to the
                        # classifier so it is not downloaded a second time.
                        image_data = image
                        st.image(image, caption="Image from URL", use_column_width=True)
                        st.success("Image loaded successfully!")
                    except Exception as e:
                        st.error(f"Error loading image: {e}")

    with col2:
        st.header("Classification Results")

        # Status check
        has_prompts = bool(positive_prompts) and bool(negative_prompts)
        ready_to_classify = has_prompts and image_data is not None

        if not has_prompts:
            st.warning("Please define both positive and negative prompts in the sidebar.")
        elif image_data is None:
            st.info("Please provide an image to classify.")
        else:
            st.success("Ready to classify!")

        if ready_to_classify and st.button("Classify Image", type="primary", use_container_width=True):
            with st.spinner("Classifying..."):
                result = classify_input(
                    model, preprocess, device,
                    image_data, positive_prompts, negative_prompts
                )

            if result:
                # Main classification result
                classification = result['classification']
                confidence = result['confidence']

                # Display result with color coding. `classification` is one of
                # two fixed strings, so inlining it into HTML is safe.
                color = "green" if classification == "Positive" else "red"
                st.markdown(
                    f"### Classification: <span style='color: {color}'>{classification}</span>",
                    unsafe_allow_html=True
                )

                # Metrics
                col_conf, col_pos, col_neg = st.columns(3)
                with col_conf:
                    st.metric("Confidence", f"{confidence:.3f}")
                with col_pos:
                    st.metric("Positive Score", f"{result['positive_score']:.3f}")
                with col_neg:
                    st.metric("Negative Score", f"{result['negative_score']:.3f}")

                # Detailed breakdown
                st.subheader("Detailed Scores")
                detailed = result['detailed_scores']
                _render_prompt_scores("Positive Prompts Scores", detailed['positive_prompts'])
                _render_prompt_scores("Negative Prompts Scores", detailed['negative_prompts'])
            else:
                st.error("Classification failed. Please try again.")

    # Instructions
    with st.expander("How to use this app"):
        st.markdown("""
        **Instructions:**
        1. **Define Prompts**: In the sidebar, enter your positive and negative prompts (one per line)
        2. **Upload Image**: Use either the file uploader or paste an image URL
        3. **Classify**: Click the "Classify Image" button to see results

        **Example prompts:**
        - **Emotion detection**: "happy, smiling, joy" vs "sad, crying, anger"
        - **Object detection**: "dog, puppy, canine" vs "cat, kitten, feline"
        - **Content type**: "food, meal, cooking" vs "vehicle, car, transportation"

        **Tips for Hugging Face Spaces:**
        - Use common image formats (JPG, PNG, WebP)
        - For URLs, make sure they're publicly accessible
        - Keep image sizes reasonable for faster processing
        """)


if __name__ == "__main__":
    main()