File size: 13,465 Bytes
b38c8ff
cc4be5a
 
 
 
 
 
eabdac2
b38c8ff
cc4be5a
 
b8ebd69
cc4be5a
 
 
b38c8ff
b8ebd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc4be5a
 
eabdac2
cc4be5a
eabdac2
3fbd534
a045d33
 
 
 
 
 
 
 
 
 
 
eabdac2
a045d33
 
 
 
 
3fbd534
eabdac2
 
3fe1064
eabdac2
a045d33
cc4be5a
 
a045d33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78c270b
a045d33
b38c8ff
eabdac2
cc4be5a
 
 
 
 
 
 
 
 
 
b8ebd69
cc4be5a
 
 
 
b8ebd69
 
 
cc4be5a
eabdac2
 
 
 
cc4be5a
eabdac2
 
 
cc4be5a
 
eabdac2
 
 
 
 
cc4be5a
 
eabdac2
cc4be5a
eabdac2
 
 
 
 
 
 
cc4be5a
 
eabdac2
 
cc4be5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b38c8ff
cc4be5a
b8ebd69
78c270b
cc4be5a
 
b8ebd69
 
 
78c270b
b8ebd69
 
 
 
cc4be5a
eabdac2
cc4be5a
 
b8ebd69
cc4be5a
 
78c270b
cc4be5a
 
 
78c270b
cc4be5a
 
b8ebd69
cc4be5a
78c270b
cc4be5a
 
78c270b
cc4be5a
 
 
b8ebd69
cc4be5a
 
 
 
78c270b
cc4be5a
 
 
b8ebd69
cc4be5a
 
 
 
 
 
 
78c270b
b8ebd69
 
78c270b
b8ebd69
78c270b
cc4be5a
 
 
 
 
78c270b
cc4be5a
 
 
 
 
b8ebd69
cc4be5a
 
 
 
b8ebd69
 
cc4be5a
 
 
78c270b
cc4be5a
 
b8ebd69
 
 
 
 
cc4be5a
 
b8ebd69
 
 
 
78c270b
cc4be5a
78c270b
cc4be5a
 
 
 
b8ebd69
 
 
cc4be5a
 
 
78c270b
cc4be5a
 
78c270b
cc4be5a
 
78c270b
 
cc4be5a
eabdac2
cc4be5a
 
 
 
 
 
 
 
 
78c270b
b8ebd69
78c270b
b8ebd69
 
78c270b
b8ebd69
cc4be5a
b8ebd69
 
cc4be5a
b8ebd69
78c270b
cc4be5a
78c270b
cc4be5a
78c270b
cc4be5a
 
78c270b
cc4be5a
b8ebd69
78c270b
cc4be5a
b8ebd69
 
 
 
 
 
 
 
 
cc4be5a
 
78c270b
cc4be5a
 
78c270b
b8ebd69
 
 
 
78c270b
 
b8ebd69
b38c8ff
cc4be5a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import streamlit as st
import torch
from PIL import Image
import numpy as np
import io
import requests
from typing import List, Tuple
from transformers import CLIPProcessor, CLIPModel

# Configure page — must be the first Streamlit call in the script.
st.set_page_config(
    page_title="CLIP Custom Classifier",
    page_icon="🔍",
    layout="wide"
)

# Add custom CSS for better appearance.
# `.main-header` renders the title as gradient text (WebKit-only effect);
# `.info-box` styles the explanatory callout used in main().
st.markdown("""
<style>
    .main-header {
        text-align: center;
        padding: 1rem 0;
        background: linear-gradient(90deg, #ff6b6b, #4ecdc4);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-size: 2.5rem;
        font-weight: bold;
        margin-bottom: 1rem;
    }
    .info-box {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        border-left: 4px solid #4ecdc4;
        margin: 1rem 0;
    }
</style>
""", unsafe_allow_html=True)

@st.cache_resource
def load_clip_model():
    """Load the CLIP model and processor, cached across Streamlit reruns.

    Returns:
        tuple: ``(model, processor, device)`` on success, or
        ``(None, None, None)`` if both the primary and fallback loading
        paths fail (caller checks for ``model is None``).
    """
    model_name = "openai/clip-vit-base-patch32"
    device = "cpu"  # inference stays on CPU (target host has no GPU)

    try:
        st.info("Loading CLIP model via Hugging Face Transformers...")

        # Use a throwaway cache directory to sidestep permission issues on
        # shared/read-only hosting filesystems (e.g. HF Spaces home dir).
        import tempfile
        import os
        temp_cache_dir = tempfile.mkdtemp()

        # NOTE(review): these env vars are set after `transformers` has been
        # imported, so only lazily-read paths honor them; the explicit
        # `cache_dir=` arguments below are what actually pin the cache.
        os.environ['HF_HOME'] = temp_cache_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_cache_dir
        os.environ['HF_DATASETS_CACHE'] = temp_cache_dir

        # CLIPModel / CLIPProcessor are imported at module top; no need to
        # re-import them here.
        model = CLIPModel.from_pretrained(model_name, cache_dir=temp_cache_dir)
        processor = CLIPProcessor.from_pretrained(model_name, cache_dir=temp_cache_dir)

        model.to(device)
        model.eval()  # inference only: disable dropout/training-mode layers

        return model, processor, device

    except Exception as e:
        st.error(f"Error loading CLIP model: {e}")

        # Fallback: let transformers choose its default cache location.
        try:
            st.info("Trying fallback loading method...")

            model = CLIPModel.from_pretrained(model_name)
            processor = CLIPProcessor.from_pretrained(model_name)

            model.to(device)
            model.eval()

            return model, processor, device

        except Exception as e2:
            st.error(f"Fallback also failed: {e2}")
            st.info("This might be a temporary issue. Please try refreshing the page in a few minutes.")
            return None, None, None

def classify_input(model, processor, device, input_data, positive_prompts, negative_prompts, input_type="image"):
    """Classify an image or text against user-defined prompt categories with CLIP.

    Args:
        model: Loaded ``CLIPModel``.
        processor: Matching ``CLIPProcessor``.
        device: Device string (currently unused for tensor placement; model
            runs wherever ``load_clip_model`` put it).
        input_data: Image URL (``str``), an uploaded file object, or — for
            ``input_type="text"`` — the text to classify.
        positive_prompts: Prompts defining the "Positive" category.
        negative_prompts: Prompts defining the "Negative" category.
        input_type: ``"image"`` or ``"text"``.

    Returns:
        dict with ``classification``, ``confidence``, ``positive_score``,
        ``negative_score`` and per-prompt ``detailed_scores``, or ``None``
        on error (the error is shown in the UI via ``st.error``).
    """
    try:
        # Positive prompts first — score slicing below relies on this order.
        all_prompts = positive_prompts + negative_prompts

        if input_type == "image":
            if isinstance(input_data, str):  # URL
                response = requests.get(input_data, timeout=10)
                # Fail fast on HTTP errors (e.g. 404 HTML pages) instead of
                # handing non-image bytes to PIL.
                response.raise_for_status()
                image = Image.open(io.BytesIO(response.content))
            else:  # Uploaded file-like object
                image = Image.open(input_data)

            # CLIP expects 3-channel input; normalize palettes/alpha to RGB.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            inputs = processor(text=all_prompts, images=image, return_tensors="pt", padding=True)

            with torch.no_grad():
                outputs = model(**inputs)
                logits_per_image = outputs.logits_per_image
                probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]

        elif input_type == "text":
            # Text-to-text: embed the input together with all prompts and
            # compare via cosine similarity of text features.
            all_texts = [input_data] + all_prompts

            inputs = processor(text=all_texts, return_tensors="pt", padding=True)

            with torch.no_grad():
                text_features = model.get_text_features(**inputs)

                input_features = text_features[0:1]  # first text is the input
                prompt_features = text_features[1:]  # rest are prompts

                similarities = torch.cosine_similarity(input_features, prompt_features, dim=1)
                # Scale by 100 (CLIP's logit-scale convention) so softmax
                # produces a peaked distribution over prompts.
                probs = torch.softmax(similarities * 100, dim=0).cpu().numpy()

        else:
            # Previously an unknown type fell through to a NameError on
            # `probs`; raise a clear, specific error instead.
            raise ValueError(f"Unsupported input_type: {input_type!r}")

        # Aggregate per-prompt probabilities into category totals.
        positive_scores = probs[:len(positive_prompts)]
        negative_scores = probs[len(positive_prompts):]

        positive_total = np.sum(positive_scores)
        negative_total = np.sum(negative_scores)

        is_positive = positive_total > negative_total
        confidence = max(positive_total, negative_total)

        return {
            'classification': 'Positive' if is_positive else 'Negative',
            'confidence': float(confidence),
            'positive_score': float(positive_total),
            'negative_score': float(negative_total),
            'detailed_scores': {
                'positive_prompts': [(prompt, float(score)) for prompt, score in zip(positive_prompts, positive_scores)],
                'negative_prompts': [(prompt, float(score)) for prompt, score in zip(negative_prompts, negative_scores)]
            }
        }

    except Exception as e:
        st.error(f"Error during classification: {e}")
        return None

def main():
    """Render the full Streamlit UI: model load, sidebar prompt config,
    input panel (image/text) and classification results panel."""
    # Header
    st.markdown('<h1 class="main-header">CLIP Custom Classifier</h1>', unsafe_allow_html=True)
    st.markdown("### Define your own positive and negative prompts to classify images or text!")
    
    # Add info box
    st.markdown("""
    <div class="info-box">
        <strong>How it works:</strong> This app uses OpenAI's CLIP model to classify images or text based on your custom prompts. 
        Define what you consider "positive" and "negative" examples, and let AI do the classification!
    </div>
    """, unsafe_allow_html=True)
    
    # Load model (cached via @st.cache_resource; only slow on first run)
    model, processor, device = load_clip_model()
    
    if model is None:
        st.error("Failed to load CLIP model. Please refresh the page and try again.")
        st.stop()
    
    st.success(f"CLIP model loaded successfully on {device.upper()}")
    
    # Sidebar for configuration
    with st.sidebar:
        st.header("Configuration")
        
        # Input type selection
        input_type = st.radio("Select input type:", ["Image", "Text"], help="Choose what type of content you want to classify")
        
        st.header("Define Classification Prompts")
        
        # Positive prompts
        st.subheader("Positive Category")
        positive_prompts_text = st.text_area(
            "Enter positive prompts (one per line):",
            value="happy face\nsmiling person\njoyful expression\npositive emotion",
            height=120,
            help="These prompts define what should be classified as 'Positive'"
        )
        
        # Negative prompts
        st.subheader("Negative Category")
        negative_prompts_text = st.text_area(
            "Enter negative prompts (one per line):",
            value="sad face\nangry person\nfrowning expression\nnegative emotion",
            height=120,
            help="These prompts define what should be classified as 'Negative'"
        )
        
        # Process prompts: one prompt per non-empty line, whitespace stripped
        positive_prompts = [p.strip() for p in positive_prompts_text.split('\n') if p.strip()]
        negative_prompts = [p.strip() for p in negative_prompts_text.split('\n') if p.strip()]
        
        # Display prompt counts
        col1, col2 = st.columns(2)
        with col1:
            st.metric("Positive", len(positive_prompts))
        with col2:
            st.metric("Negative", len(negative_prompts))
    
    # Main content area: input on the left, results on the right
    col1, col2 = st.columns([1, 1])
    
    with col1:
        st.header("Input")
        
        # Stays None until a valid image/URL/text is provided; gates the
        # classify button in the results column.
        input_data = None
        
        if input_type == "Image":
            # Image input options
            image_option = st.radio("Choose image source:", ["Upload", "URL"], horizontal=True)
            
            if image_option == "Upload":
                uploaded_file = st.file_uploader(
                    "Choose an image file",
                    type=['png', 'jpg', 'jpeg', 'gif', 'bmp'],
                    help="Supported formats: PNG, JPG, JPEG, GIF, BMP"
                )
                if uploaded_file:
                    input_data = uploaded_file
                    # NOTE(review): `use_column_width` is deprecated in newer
                    # Streamlit releases in favor of `use_container_width` —
                    # confirm against the pinned Streamlit version.
                    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
            
            else:  # URL
                image_url = st.text_input(
                    "Enter image URL:",
                    placeholder="https://example.com/image.jpg",
                    help="Paste a direct link to an image"
                )
                if image_url:
                    try:
                        # Fetch once here for preview; classify_input fetches
                        # again when the button is pressed (URL is passed, not
                        # the decoded image).
                        with st.spinner("Loading image..."):
                            response = requests.get(image_url, timeout=10)
                            image = Image.open(io.BytesIO(response.content))
                            input_data = image_url
                            st.image(image, caption="Image from URL", use_column_width=True)
                    except Exception as e:
                        st.error(f"Error loading image from URL: {e}")
        
        else:  # Text input
            text_input = st.text_area(
                "Enter text to classify:",
                height=200,
                placeholder="Type your text here...",
                help="Enter any text you want to classify"
            )
            if text_input.strip():
                input_data = text_input.strip()
                st.text_area("Text to classify:", value=text_input, height=100, disabled=True)
    
    with col2:
        st.header("Classification Results")
        
        # Only offer classification once input AND both prompt lists exist
        if input_data and positive_prompts and negative_prompts:
            if st.button("Classify Now", type="primary", use_container_width=True):
                with st.spinner("AI is analyzing..."):
                    result = classify_input(
                        model, processor, device, input_data, 
                        positive_prompts, negative_prompts,
                        input_type.lower()
                    )
                
                if result:
                    # Main classification result
                    classification = result['classification']
                    confidence = result['confidence']
                    
                    # Display result with color coding (green/red)
                    if classification == "Positive":
                        st.markdown("### Classification: <span style='color: #28a745; font-weight: bold;'>POSITIVE</span>", 
                                  unsafe_allow_html=True)
                    else:
                        st.markdown("### Classification: <span style='color: #dc3545; font-weight: bold;'>NEGATIVE</span>", 
                                  unsafe_allow_html=True)
                    
                    # Confidence and scores in columns
                    col_conf, col_pos, col_neg = st.columns(3)
                    
                    with col_conf:
                        st.metric("Confidence", f"{confidence:.1%}")
                    with col_pos:
                        st.metric("Positive Score", f"{result['positive_score']:.3f}")
                    with col_neg:
                        st.metric("Negative Score", f"{result['negative_score']:.3f}")
                    
                    # Detailed breakdown
                    st.subheader("Detailed Breakdown")
                    
                    # Create tabs for better organization
                    tab1, tab2 = st.tabs(["Positive Scores", "Negative Scores"])
                    
                    with tab1:
                        st.write("**Individual prompt scores:**")
                        for prompt, score in result['detailed_scores']['positive_prompts']:
                            st.progress(float(score), text=f"{prompt}: {score:.3f}")
                    
                    with tab2:
                        st.write("**Individual prompt scores:**")
                        for prompt, score in result['detailed_scores']['negative_prompts']:
                            st.progress(float(score), text=f"{prompt}: {score:.3f}")
        
        elif not positive_prompts or not negative_prompts:
            st.warning("Please define both positive and negative prompts in the sidebar.")
        
        elif not input_data:
            st.info("Please provide input data to classify in the left panel.")
    
    # Footer
    st.markdown("---")
    st.markdown(
        "Made with love using [OpenAI CLIP](https://openai.com/research/clip) and [Streamlit](https://streamlit.io) | "
        "Hosted on [Hugging Face Spaces](https://huggingface.co/spaces)"
    )

# Script entry point (Streamlit re-executes this module on every interaction).
if __name__ == "__main__":
    main()