"""Gradio front end for a CLIP-based text retrieval API.

A caption is encoded locally with the CLIP text encoder and the resulting
vector is POSTed to ``{API_BASE}/search``; the JSON reply is shown in the UI.
"""
import logging
import os

import gradio as gr
import requests
import torch
from transformers import CLIPModel, CLIPProcessor

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Limits shared by the slider, the clamp in call_search and the Clear button,
# kept in one place so the three cannot drift apart.
MIN_K = 1
MAX_K = 10
DEFAULT_K = 3
EMBEDDING_DIM = 512  # vector length call_search requires before sending
MAX_TEXT_TOKENS = 77  # CLIP's max token length

# 1) Load CLIP text encoder
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()


def embed_text(text: str) -> list[float]:
    """Turn a string into a normalized CLIP embedding.

    Args:
        text: Caption to encode; surrounding whitespace is stripped first.

    Returns:
        The L2-normalized text features as a flat list of floats.

    Raises:
        ValueError: If ``text`` is empty after stripping.
        Exception: Any failure in the processor or model is logged and
            re-raised unchanged.
    """
    try:
        # Clean and preprocess text
        text = text.strip()
        if not text:
            raise ValueError("Empty text input")

        # Tokenize with proper handling
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_TEXT_TOKENS,
        )

        with torch.no_grad():
            # Get text features
            feats = model.get_text_features(**inputs)
            # Normalize to unit vector (L2 normalization)
            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)

        # Drop the singleton batch dimension and convert to a plain list
        embedding = feats.squeeze().cpu().tolist()
        logger.info("Generated embedding with shape: %d", len(embedding))
        return embedding
    except Exception as e:
        # Log with traceback, then let the caller decide how to report it.
        logger.exception("Error in embed_text: %s", e)
        raise


# 2) API configuration
API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")


def _http_error_message(exc: requests.exceptions.HTTPError) -> str:
    """Build the "HTTP status: detail" text reported for a failed request."""
    response = exc.response
    if response is None:
        # No response attached to the exception; report the exception itself.
        return f"HTTP error: {exc}"
    try:
        detail = response.json().get("detail", response.text)
    except (ValueError, AttributeError):
        # Body is not JSON (ValueError) or is JSON but not an object
        # (AttributeError on .get): fall back to the raw body text.
        detail = response.text
    return f"HTTP {response.status_code}: {detail}"


def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, return JSON (or error dict).

    Args:
        caption: Query text; blank or whitespace-only input is rejected.
        k: Requested number of results; coerced to int and clamped to
            the range MIN_K..MAX_K.

    Returns:
        The decoded JSON reply from the API. When the reply is a dict, a
        ``"_metadata"`` entry describing the query is added to it. On any
        failure a dict of the form ``{"error": message}`` is returned
        instead of raising.
    """
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}

        caption = caption.strip()
        k = max(MIN_K, min(int(k), MAX_K))  # Clamp k between MIN_K and MAX_K
        logger.info("Searching for: '%s' with k=%d", caption, k)

        # 1) Embed locally
        vec = embed_text(caption)

        # Verify embedding dimensions
        if len(vec) != EMBEDDING_DIM:
            return {
                "error": f"Unexpected embedding dimension: {len(vec)} (expected {EMBEDDING_DIM})"
            }

        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption,  # Include original text for debugging
        }

        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client",
        }
        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30,  # Increased timeout
        )
        response.raise_for_status()

        result = response.json()
        logger.info("API response status: %s", response.status_code)

        # Add metadata to result
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code,
            }
        return result

    # Timeout is checked before ConnectionError so a connect timeout
    # (which is both) is reported as a timeout.
    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError as e:
        return {"error": _http_error_message(e)}
    except Exception as e:
        # UI boundary: report the failure in the JSON panel rather than raise.
        logger.exception("Unexpected error in call_search: %s", e)
        return {"error": f"Unexpected error: {e}"}


def validate_api_connection() -> str:
    """Test API connection and return status.

    Sends a GET to ``{API_BASE}/health`` with a 10 second timeout.

    Returns:
        A human-readable status line. Any HTTP reply, whatever its status
        code, is reported as reachable; any exception is reported as a
        failed connection.
    """
    try:
        response = requests.get(f"{API_BASE}/health", timeout=10)
        return f"API is reachable (Status: {response.status_code})"
    except Exception as e:
        # Best-effort probe: never let a status check break the UI.
        return f"API connection failed: {str(e)}"


# 3) Gradio UI
with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval (small dataset)\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )

    # API status indicator
    with gr.Row():
        api_status = gr.Textbox(
            # Evaluated once while the UI is built, so startup waits for the
            # health check (up to its 10 second timeout).
            value=validate_api_connection(),
            label="API Status",
            interactive=False,
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")
        refresh_btn.click(fn=validate_api_connection, outputs=api_status)

    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                lines=3,
                placeholder="type something",
                label="Caption",
                info="Enter a descriptive text to search for similar images",
            )
        with gr.Column(scale=1):
            k_input = gr.Slider(
                minimum=MIN_K,
                maximum=MAX_K,
                value=DEFAULT_K,
                step=1,
                label="Top-K Results",
            )

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    output = gr.JSON(label="Search Results")

    # Event handlers
    btn.click(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output,
    )

    # Reset caption, slider and results to their initial values.
    clear_btn.click(
        fn=lambda: ("", DEFAULT_K, {}),
        outputs=[caption_input, k_input, output],
    )

    # Allow Enter key to submit
    caption_input.submit(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )