"""Streamlit app: classify images or text with CLIP against user-defined prompts."""

import io
import shutil
import tempfile
from typing import List, Tuple

import numpy as np
import requests
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Hugging Face checkpoint used for both the model and its processor.
MODEL_NAME = "openai/clip-vit-base-patch32"
# Seconds to wait when downloading an image from a user-supplied URL.
URL_TIMEOUT = 10

# Configure page (st.set_page_config must be the first Streamlit call).
st.set_page_config(
    page_title="CLIP Custom Classifier",
    page_icon="🔍",
    layout="wide",
)


def _load_model_and_processor(cache_dir=None):
    """Load the CLIP model and processor, placing the model on a GPU when available.

    Args:
        cache_dir: Directory for downloaded weights; ``None`` uses the Hugging Face default.

    Returns:
        ``(model, processor, device)`` where ``device`` is ``"cuda"`` or ``"cpu"``.
    """
    model = CLIPModel.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
    processor = CLIPProcessor.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, processor, device


@st.cache_resource
def _load_clip_model_cached():
    """Load CLIP once per server process; raise ``RuntimeError`` if every attempt fails.

    Failures are raised rather than returned on purpose: ``st.cache_resource`` does not
    cache exceptions, so a later page refresh retries the download instead of being
    handed a cached failure for the lifetime of the server.
    """
    st.info("Loading CLIP model via Hugging Face Transformers...")
    cache_dir = None
    try:
        # A fresh temp dir is always writable, which sidesteps hosts where the default
        # Hugging Face cache location is read-only. It is passed explicitly via
        # ``cache_dir``; setting HF_* environment variables here would be too late
        # (they are read when transformers is imported) and would leak process-wide.
        cache_dir = tempfile.mkdtemp()
        return _load_model_and_processor(cache_dir)
    except Exception as e:
        if cache_dir is not None:
            # Don't leave a partially downloaded checkpoint behind.
            shutil.rmtree(cache_dir, ignore_errors=True)
        # Only a warning: the fallback below may still succeed.
        st.warning(f"Error loading CLIP model: {e}")

    # Fallback: try loading without custom cache
    st.info("Trying fallback loading method...")
    try:
        return _load_model_and_processor()
    except Exception as e2:
        raise RuntimeError(f"Fallback also failed: {e2}") from e2


def load_clip_model() -> Tuple:
    """Return ``(model, processor, device)``, or ``(None, None, None)`` if loading failed.

    Successful loads are cached for the whole server process; failures are not, so a
    page refresh triggers a fresh attempt.
    """
    try:
        return _load_clip_model_cached()
    except RuntimeError as e:
        st.error(str(e))
        st.info("This might be a temporary issue. Please try refreshing the page in a few minutes.")
        return None, None, None


def _load_image_from_url(url: str, timeout: float = URL_TIMEOUT) -> Image.Image:
    """Download ``url`` and decode it as a PIL image.

    Raises:
        requests.RequestException: on network errors or a non-2xx HTTP status.
        PIL.UnidentifiedImageError: if the payload is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    # Without this, a 404/500 HTML page reaches PIL and fails with a confusing
    # "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def _to_rgb_image(input_data) -> Image.Image:
    """Convert a URL string, PIL image, or file-like upload into an RGB PIL image."""
    if isinstance(input_data, str):
        image = _load_image_from_url(input_data)
    elif isinstance(input_data, Image.Image):
        image = input_data
    else:
        # Uploaded file: rewind in case an earlier preview already consumed the stream.
        if hasattr(input_data, "seek"):
            input_data.seek(0)
        image = Image.open(input_data)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def _prompt_probabilities(model, processor, device, input_data, prompts: List[str], input_type: str) -> np.ndarray:
    """Return one softmax probability per prompt for ``input_data``.

    Args:
        input_type: ``"image"`` (CLIP image-text logits) or ``"text"`` (cosine
            similarity between the input text embedding and each prompt embedding).

    Raises:
        ValueError: for any other ``input_type``.
    """
    if input_type == "image":
        image = _to_rgb_image(input_data)
        # truncation=True: CLIP's text encoder accepts at most 77 tokens, and longer
        # prompts would otherwise crash the forward pass.
        inputs = processor(
            text=prompts, images=image, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]

    if input_type == "text":
        # First text is the input, the rest are the prompts.
        inputs = processor(
            text=[input_data] + prompts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs)
            similarities = torch.cosine_similarity(text_features[0:1], text_features[1:], dim=1)
            # Scale by CLIP's learned temperature (about 100 for this checkpoint) so the
            # softmax is as sharp as the image branch's logits.
            logits = similarities * model.logit_scale.exp()
        return torch.softmax(logits, dim=0).cpu().numpy()

    raise ValueError(f"Unsupported input_type: {input_type!r} (expected 'image' or 'text')")


def classify_input(model, processor, device, input_data, positive_prompts: List[str],
                   negative_prompts: List[str], input_type="image"):
    """Classify input based on positive and negative prompts using CLIP.

    Probabilities are computed over all prompts jointly, then summed per category;
    the category with the larger total wins (ties go to ``'Negative'``).

    Args:
        input_data: For images, a URL string, a ``PIL.Image.Image``, or a file-like
            upload; for text, the string to classify.
        input_type: ``"image"`` or ``"text"``.

    Returns:
        A dict with ``classification``, ``confidence``, ``positive_score``,
        ``negative_score`` and per-prompt ``detailed_scores``; ``None`` on error
        (the error is shown in the UI).
    """
    try:
        probs = _prompt_probabilities(
            model, processor, device, input_data,
            positive_prompts + negative_prompts, input_type,
        )

        split = len(positive_prompts)
        positive_scores = probs[:split]
        negative_scores = probs[split:]
        positive_total = np.sum(positive_scores)
        negative_total = np.sum(negative_scores)

        is_positive = positive_total > negative_total
        confidence = max(positive_total, negative_total)

        return {
            'classification': 'Positive' if is_positive else 'Negative',
            'confidence': float(confidence),
            'positive_score': float(positive_total),
            'negative_score': float(negative_total),
            'detailed_scores': {
                'positive_prompts': [(prompt, float(score)) for prompt, score in zip(positive_prompts, positive_scores)],
                'negative_prompts': [(prompt, float(score)) for prompt, score in zip(negative_prompts, negative_scores)],
            },
        }
    except Exception as e:
        # UI boundary: report to the user instead of crashing the whole page.
        st.error(f"Error during classification: {e}")
        return None


def _parse_prompts(raw_text: str) -> List[str]:
    """Split a text area's contents into non-empty, stripped lines."""
    return [line.strip() for line in raw_text.split('\n') if line.strip()]


def _render_prompt_scores(scored_prompts) -> None:
    """Draw one progress bar per ``(prompt, score)`` pair."""
    st.write("**Individual prompt scores:**")
    for prompt, score in scored_prompts:
        # st.progress rejects values outside [0, 1]; guard against float round-off.
        st.progress(min(max(float(score), 0.0), 1.0), text=f"{prompt}: {score:.3f}")


def main():
    """Render the app: sidebar prompt configuration, input panel, and results panel."""
    # Header
    st.title("CLIP Custom Classifier")
    st.markdown("### Define your own positive and negative prompts to classify images or text!")
    st.info(
        "**How it works:** This app uses OpenAI's CLIP model to classify images or text "
        "based on your custom prompts. Define what you consider \"positive\" and "
        "\"negative\" examples, and let AI do the classification!"
    )

    # Load model
    model, processor, device = load_clip_model()
    if model is None:
        st.error("Failed to load CLIP model. Please refresh the page and try again.")
        st.stop()
    st.success(f"CLIP model loaded successfully on {device.upper()}")

    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")
        input_type = st.radio(
            "Select input type:",
            ["Image", "Text"],
            help="Choose what type of content you want to classify",
        )

        st.header("Define Classification Prompts")

        st.subheader("Positive Category")
        positive_prompts_text = st.text_area(
            "Enter positive prompts (one per line):",
            value="happy face\nsmiling person\njoyful expression\npositive emotion",
            height=120,
            help="These prompts define what should be classified as 'Positive'",
        )

        st.subheader("Negative Category")
        negative_prompts_text = st.text_area(
            "Enter negative prompts (one per line):",
            value="sad face\nangry person\nfrowning expression\nnegative emotion",
            height=120,
            help="These prompts define what should be classified as 'Negative'",
        )

        positive_prompts = _parse_prompts(positive_prompts_text)
        negative_prompts = _parse_prompts(negative_prompts_text)

        # Display prompt counts
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Positive", len(positive_prompts))
        with col2:
            st.metric("Negative", len(negative_prompts))

    # Main content area
    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("Input")
        input_data = None

        if input_type == "Image":
            image_option = st.radio("Choose image source:", ["Upload", "URL"], horizontal=True)

            if image_option == "Upload":
                uploaded_file = st.file_uploader(
                    "Choose an image file",
                    type=['png', 'jpg', 'jpeg', 'gif', 'bmp'],
                    help="Supported formats: PNG, JPG, JPEG, GIF, BMP",
                )
                if uploaded_file:
                    input_data = uploaded_file
                    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
            else:  # URL
                image_url = st.text_input(
                    "Enter image URL:",
                    placeholder="https://example.com/image.jpg",
                    help="Paste a direct link to an image",
                )
                if image_url:
                    try:
                        with st.spinner("Loading image..."):
                            image = _load_image_from_url(image_url)
                        # Hand the already-downloaded image to the classifier rather
                        # than the URL, so it is not fetched a second time.
                        input_data = image
                        st.image(image, caption="Image from URL", use_column_width=True)
                    except Exception as e:
                        st.error(f"Error loading image from URL: {e}")
        else:  # Text input
            text_input = st.text_area(
                "Enter text to classify:",
                height=200,
                placeholder="Type your text here...",
                help="Enter any text you want to classify",
            )
            if text_input.strip():
                input_data = text_input.strip()
                st.text_area("Text to classify:", value=text_input, height=100, disabled=True)

    with col2:
        st.header("Classification Results")

        # `is not None` rather than truthiness: input_data may be a PIL image or an
        # uploaded-file object, neither of which has a meaningful bool().
        if input_data is not None and positive_prompts and negative_prompts:
            if st.button("Classify Now", type="primary", use_container_width=True):
                with st.spinner("AI is analyzing..."):
                    result = classify_input(
                        model, processor, device, input_data,
                        positive_prompts, negative_prompts, input_type.lower(),
                    )

                if result:
                    classification = result['classification']
                    confidence = result['confidence']

                    if classification == "Positive":
                        st.markdown("### Classification: POSITIVE", unsafe_allow_html=True)
                    else:
                        st.markdown("### Classification: NEGATIVE", unsafe_allow_html=True)

                    col_conf, col_pos, col_neg = st.columns(3)
                    with col_conf:
                        st.metric("Confidence", f"{confidence:.1%}")
                    with col_pos:
                        st.metric("Positive Score", f"{result['positive_score']:.3f}")
                    with col_neg:
                        st.metric("Negative Score", f"{result['negative_score']:.3f}")

                    # Detailed breakdown
                    st.subheader("Detailed Breakdown")
                    tab1, tab2 = st.tabs(["Positive Scores", "Negative Scores"])
                    with tab1:
                        _render_prompt_scores(result['detailed_scores']['positive_prompts'])
                    with tab2:
                        _render_prompt_scores(result['detailed_scores']['negative_prompts'])

        elif not positive_prompts or not negative_prompts:
            st.warning("Please define both positive and negative prompts in the sidebar.")
        elif input_data is None:
            st.info("Please provide input data to classify in the left panel.")

    # Footer
    st.markdown("---")
    st.markdown(
        "Made with love using [OpenAI CLIP](https://openai.com/research/clip) and [Streamlit](https://streamlit.io) | "
        "Hosted on [Hugging Face Spaces](https://huggingface.co/spaces)"
    )


if __name__ == "__main__":
    main()