File size: 6,271 Bytes
824e2d6
 
 
bdb7dad
 
 
6194281
bdb7dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d850a0
bdb7dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6194281
bdb7dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37c3b30
bdb7dad
 
 
8e5b0a0
6194281
bdb7dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6491af2
bdb7dad
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import requests
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger, per logging convention

# 1) Load CLIP text encoder
# NOTE(review): from_pretrained downloads/loads weights at import time on the
# first run — assumes network access (or a warm HF cache) when the app starts.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()  # inference-only: disables dropout/training-mode behavior

def embed_text(text: str) -> list[float]:
    """Turn a string into a normalized CLIP embedding.

    Args:
        text: Free-form caption; leading/trailing whitespace is stripped.

    Returns:
        An L2-normalized text embedding as a flat list of floats
        (512-d for the ViT-B/32 checkpoint loaded above).

    Raises:
        ValueError: If ``text`` is empty after stripping.
    """
    try:
        # Clean and preprocess text
        text = text.strip()
        if not text:
            raise ValueError("Empty text input")

        # Tokenize; truncate to CLIP's fixed 77-token context window.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's max token length
        )

        with torch.no_grad():
            # Get text features
            feats = model.get_text_features(**inputs)

            # L2-normalize so downstream cosine similarity reduces to a dot product.
            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)

            # Drop the batch dim (single text) and move to CPU for JSON transport.
            embedding = feats.squeeze().cpu().tolist()

            # Lazy %-args: no formatting cost when INFO is disabled.
            logger.info("Generated embedding with shape: %d", len(embedding))
            return embedding

    except Exception:
        # logger.exception records the full traceback; re-raise for the caller.
        logger.exception("Error in embed_text")
        raise

# 2) API configuration
# Base URL of the FastAPI+FAISS retrieval service; override with the API_URL
# env var. Trailing "/" is stripped so "{API_BASE}/search" never doubles up.
API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")

def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, return JSON (or error dict).

    Args:
        caption: Free-form query text; must be non-empty after stripping.
        k: Requested number of results; clamped to the range [1, 10].

    Returns:
        The API's decoded JSON response (with a ``_metadata`` key added when
        the response is a dict), or a ``{"error": ...}`` dict describing the
        failure — this function never raises, so Gradio always gets JSON.
    """
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}

        caption = caption.strip()
        k = max(1, min(int(k), 10))  # Clamp k between 1 and 10

        logger.info("Searching for: '%s' with k=%d", caption, k)

        # 1) Embed locally
        vec = embed_text(caption)

        # Verify embedding dimensions (ViT-B/32 text features are 512-d).
        if len(vec) != 512:
            return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}

        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption  # Include original text for debugging
        }

        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client"
        }

        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30  # generous: free-tier hosts may need a cold start
        )

        response.raise_for_status()
        result = response.json()

        logger.info("API response status: %s", response.status_code)

        # Add metadata to result
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code
            }

        return result

    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError as e:
        # Use the response attached to the exception rather than the closure
        # variable `response`, which only exists if requests.post() returned.
        resp = e.response
        error_msg = f"HTTP {resp.status_code}"
        try:
            error_detail = resp.json().get("detail", resp.text)
            error_msg += f": {error_detail}"
        except (ValueError, AttributeError):
            # Body was not JSON (ValueError) or was JSON but not a dict
            # (AttributeError) — fall back to the raw text.
            error_msg += f": {resp.text}"
        return {"error": error_msg}
    except Exception as e:
        # Boundary handler: log with traceback, surface a friendly message.
        logger.exception("Unexpected error in call_search")
        return {"error": f"Unexpected error: {str(e)}"}

def validate_api_connection():
    """Probe the API's /health endpoint and return a human-readable status."""
    try:
        health_resp = requests.get(f"{API_BASE}/health", timeout=10)
    except Exception as exc:
        # Best-effort probe: any failure (DNS, refused, timeout) is reported
        # as text rather than raised, since this feeds a status textbox.
        return f"API connection failed: {str(exc)}"
    return f"API is reachable (Status: {health_resp.status_code})"

# 3) Gradio UI
# Layout: header -> API status row -> (caption | top-k slider) -> buttons -> JSON output.
with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval (small dataset)\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )
    
    # API status indicator
    # NOTE: validate_api_connection() runs once here at app build time; the
    # refresh button re-checks on demand without restarting the app.
    with gr.Row():
        api_status = gr.Textbox(
            value=validate_api_connection(),
            label="API Status",
            interactive=False
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")
        refresh_btn.click(fn=validate_api_connection, outputs=api_status)
    
    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                lines=3,
                placeholder="type something",
                label="Caption",
                info="Enter a descriptive text to search for similar images"
            )
        
        with gr.Column(scale=1):
            # Slider bounds mirror the server-side clamp in call_search (1..10).
            k_input = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-K Results"
            )
    
    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")
    
    output = gr.JSON(label="Search Results")
    
    # Event handlers
    btn.click(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )
    
    # Reset caption, slider (back to default 3), and output in one shot.
    clear_btn.click(
        fn=lambda: ("", 3, {}),
        outputs=[caption_input, k_input, output]
    )
    
    # Allow Enter key to submit
    caption_input.submit(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )

if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces (required inside containers / HF Spaces);
    # 7860 is Gradio's conventional port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )