import streamlit as st
import torch
from PIL import Image
import numpy as np
import io
import requests
from typing import List, Tuple
from transformers import CLIPProcessor, CLIPModel
# ---------------------------------------------------------------------------
# Page-level Streamlit configuration.
# st.set_page_config must be the first Streamlit call in the script.
# ---------------------------------------------------------------------------
st.set_page_config(
    page_title="CLIP Custom Classifier",
    page_icon="🔍",
    layout="wide",
)

# Placeholder for custom CSS injected as raw HTML. The style body is currently
# empty; add rules between the triple quotes to restyle the app.
st.markdown("""
""", unsafe_allow_html=True)
@st.cache_resource
def load_clip_model():
    """Load the CLIP model and processor once, cached across Streamlit reruns.

    Downloads ``openai/clip-vit-base-patch32`` into a fresh temporary cache
    directory (useful on hosts where the default HF cache path is not
    writable, e.g. some hosted Spaces). If that fails, falls back to loading
    with the default cache location.

    Returns:
        tuple: ``(model, processor, device)`` on success, or
        ``(None, None, None)`` if both loading attempts fail.
    """
    model_name = "openai/clip-vit-base-patch32"
    device = "cpu"  # CPU-only; presumably no GPU on the hosting tier — TODO confirm
    try:
        st.info("Loading CLIP model via Hugging Face Transformers...")
        import os
        import tempfile

        # Throwaway cache dir so downloads succeed even when the default
        # HF cache path is read-only.
        temp_cache_dir = tempfile.mkdtemp()
        # HF_HOME is the current umbrella variable; TRANSFORMERS_CACHE and
        # HF_DATASETS_CACHE are kept for compatibility with older
        # transformers versions.
        os.environ['HF_HOME'] = temp_cache_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_cache_dir
        os.environ['HF_DATASETS_CACHE'] = temp_cache_dir

        # CLIPModel / CLIPProcessor are already imported at module top;
        # no local re-import needed.
        model = CLIPModel.from_pretrained(model_name, cache_dir=temp_cache_dir)
        processor = CLIPProcessor.from_pretrained(model_name, cache_dir=temp_cache_dir)
        model.to(device)
        return model, processor, device
    except Exception as e:
        st.error(f"Error loading CLIP model: {e}")
        # Fallback: retry with the default HF cache location.
        try:
            st.info("Trying fallback loading method...")
            model = CLIPModel.from_pretrained(model_name)
            processor = CLIPProcessor.from_pretrained(model_name)
            model.to(device)
            return model, processor, device
        except Exception as e2:
            st.error(f"Fallback also failed: {e2}")
            st.info("This might be a temporary issue. Please try refreshing the page in a few minutes.")
            return None, None, None
def classify_input(model, processor, device, input_data, positive_prompts, negative_prompts, input_type="image"):
    """Classify an image or a text snippet against custom prompt sets using CLIP.

    Args:
        model: A loaded ``CLIPModel``.
        processor: The matching ``CLIPProcessor``.
        device: Torch device string the model lives on (e.g. ``"cpu"``).
        input_data: For images, either a URL string or an uploaded file-like
            object; for text, the string to classify.
        positive_prompts: Prompts describing the "Positive" category.
        negative_prompts: Prompts describing the "Negative" category.
        input_type: ``"image"`` or ``"text"``.

    Returns:
        dict with keys ``classification``, ``confidence``, ``positive_score``,
        ``negative_score`` and ``detailed_scores`` (per-prompt probabilities),
        or ``None`` if classification failed (error shown via ``st.error``).
    """
    try:
        all_prompts = positive_prompts + negative_prompts
        if input_type == "image":
            if isinstance(input_data, str):  # URL
                response = requests.get(input_data, timeout=10)
                # Fail fast on HTTP errors instead of handing an error page
                # to PIL (which would raise a confusing UnidentifiedImageError).
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            else:  # uploaded file-like object
                image = Image.open(input_data)
            # CLIP expects 3-channel RGB input.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            inputs = processor(text=all_prompts, images=image, return_tensors="pt", padding=True)
            # Keep input tensors on the same device as the model.
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model(**inputs)
                # Softmax over prompts -> probability per prompt for this image.
                probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
        elif input_type == "text":
            # Text-to-text comparison via CLIP text embeddings: first entry
            # is the input, the rest are the prompts.
            all_texts = [input_data] + all_prompts
            inputs = processor(text=all_texts, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                text_features = model.get_text_features(**inputs)
            input_features = text_features[0:1]   # the input text
            prompt_features = text_features[1:]   # the prompts
            similarities = torch.cosine_similarity(input_features, prompt_features, dim=1)
            # Temperature of 100 mirrors CLIP's learned logit scale, sharpening
            # the softmax over cosine similarities.
            probs = torch.softmax(similarities * 100, dim=0).cpu().numpy()
        else:
            # Original code silently hit a NameError here; make the contract explicit.
            raise ValueError(f"Unsupported input_type: {input_type!r}")

        # Aggregate per-prompt probabilities into the two categories.
        positive_scores = probs[:len(positive_prompts)]
        negative_scores = probs[len(positive_prompts):]
        positive_total = np.sum(positive_scores)
        negative_total = np.sum(negative_scores)

        is_positive = positive_total > negative_total
        confidence = max(positive_total, negative_total)
        return {
            'classification': 'Positive' if is_positive else 'Negative',
            'confidence': float(confidence),
            'positive_score': float(positive_total),
            'negative_score': float(negative_total),
            'detailed_scores': {
                'positive_prompts': [(prompt, float(score)) for prompt, score in zip(positive_prompts, positive_scores)],
                'negative_prompts': [(prompt, float(score)) for prompt, score in zip(negative_prompts, negative_scores)]
            }
        }
    except Exception as e:
        st.error(f"Error during classification: {e}")
        return None
def main():
    """Render the Streamlit UI: header, prompt-config sidebar, input panel,
    and classification-results panel."""
    # Header. NOTE(review): the original inline-HTML heading was a malformed
    # (unterminated) string literal spanning several lines — reconstructed
    # here as valid HTML with the same visible title text.
    st.markdown('<h1>CLIP Custom Classifier</h1>', unsafe_allow_html=True)
    st.markdown("### Define your own positive and negative prompts to classify images or text!")
    # Info box.
    st.markdown("""
How it works: This app uses OpenAI's CLIP model to classify images or text based on your custom prompts.
Define what you consider "positive" and "negative" examples, and let AI do the classification!
""", unsafe_allow_html=True)

    # Load model (cached via st.cache_resource); abort the run if it failed.
    model, processor, device = load_clip_model()
    if model is None:
        st.error("Failed to load CLIP model. Please refresh the page and try again.")
        st.stop()
    st.success(f"CLIP model loaded successfully on {device.upper()}")

    # ----- Sidebar: input type and prompt configuration -----
    with st.sidebar:
        st.header("Configuration")
        input_type = st.radio("Select input type:", ["Image", "Text"], help="Choose what type of content you want to classify")
        st.header("Define Classification Prompts")
        st.subheader("Positive Category")
        positive_prompts_text = st.text_area(
            "Enter positive prompts (one per line):",
            value="happy face\nsmiling person\njoyful expression\npositive emotion",
            height=120,
            help="These prompts define what should be classified as 'Positive'"
        )
        st.subheader("Negative Category")
        negative_prompts_text = st.text_area(
            "Enter negative prompts (one per line):",
            value="sad face\nangry person\nfrowning expression\nnegative emotion",
            height=120,
            help="These prompts define what should be classified as 'Negative'"
        )
        # One prompt per non-empty line.
        positive_prompts = [p.strip() for p in positive_prompts_text.split('\n') if p.strip()]
        negative_prompts = [p.strip() for p in negative_prompts_text.split('\n') if p.strip()]
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Positive", len(positive_prompts))
        with col2:
            st.metric("Negative", len(negative_prompts))

    # ----- Main area: input (left) and results (right) -----
    col1, col2 = st.columns([1, 1])
    with col1:
        st.header("Input")
        input_data = None
        if input_type == "Image":
            image_option = st.radio("Choose image source:", ["Upload", "URL"], horizontal=True)
            if image_option == "Upload":
                uploaded_file = st.file_uploader(
                    "Choose an image file",
                    type=['png', 'jpg', 'jpeg', 'gif', 'bmp'],
                    help="Supported formats: PNG, JPG, JPEG, GIF, BMP"
                )
                if uploaded_file:
                    input_data = uploaded_file
                    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
            else:  # URL
                image_url = st.text_input(
                    "Enter image URL:",
                    placeholder="https://example.com/image.jpg",
                    help="Paste a direct link to an image"
                )
                if image_url:
                    try:
                        with st.spinner("Loading image..."):
                            # Preview fetch only; classify_input re-downloads
                            # from the URL when the button is pressed.
                            response = requests.get(image_url, timeout=10)
                            image = Image.open(io.BytesIO(response.content))
                            input_data = image_url
                            st.image(image, caption="Image from URL", use_column_width=True)
                    except Exception as e:
                        st.error(f"Error loading image from URL: {e}")
        else:  # Text input
            text_input = st.text_area(
                "Enter text to classify:",
                height=200,
                placeholder="Type your text here...",
                help="Enter any text you want to classify"
            )
            if text_input.strip():
                input_data = text_input.strip()
                st.text_area("Text to classify:", value=text_input, height=100, disabled=True)

    with col2:
        st.header("Classification Results")
        if input_data and positive_prompts and negative_prompts:
            if st.button("Classify Now", type="primary", use_container_width=True):
                with st.spinner("AI is analyzing..."):
                    result = classify_input(
                        model, processor, device, input_data,
                        positive_prompts, negative_prompts,
                        input_type.lower()
                    )
                if result:
                    classification = result['classification']
                    confidence = result['confidence']
                    if classification == "Positive":
                        st.markdown("### Classification: POSITIVE",
                                    unsafe_allow_html=True)
                    else:
                        st.markdown("### Classification: NEGATIVE",
                                    unsafe_allow_html=True)
                    # Headline metrics.
                    col_conf, col_pos, col_neg = st.columns(3)
                    with col_conf:
                        st.metric("Confidence", f"{confidence:.1%}")
                    with col_pos:
                        st.metric("Positive Score", f"{result['positive_score']:.3f}")
                    with col_neg:
                        st.metric("Negative Score", f"{result['negative_score']:.3f}")
                    # Per-prompt probability breakdown.
                    st.subheader("Detailed Breakdown")
                    tab1, tab2 = st.tabs(["Positive Scores", "Negative Scores"])
                    with tab1:
                        st.write("**Individual prompt scores:**")
                        for prompt, score in result['detailed_scores']['positive_prompts']:
                            st.progress(float(score), text=f"{prompt}: {score:.3f}")
                    with tab2:
                        st.write("**Individual prompt scores:**")
                        for prompt, score in result['detailed_scores']['negative_prompts']:
                            st.progress(float(score), text=f"{prompt}: {score:.3f}")
        elif not positive_prompts or not negative_prompts:
            st.warning("Please define both positive and negative prompts in the sidebar.")
        elif not input_data:
            st.info("Please provide input data to classify in the left panel.")

    # Footer
    st.markdown("---")
    st.markdown(
        "Made with love using [OpenAI CLIP](https://openai.com/research/clip) and [Streamlit](https://streamlit.io) | "
        "Hosted on [Hugging Face Spaces](https://huggingface.co/spaces)"
    )


if __name__ == "__main__":
    main()