# Hugging Face Space: Image <-> Text Retrieval demo (CLIP text encoder + FastAPI/FAISS backend).
import os
import requests
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 1) Load CLIP text encoder (downloads/caches weights on first run).
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()  # inference mode: disables dropout/batch-norm updates
def embed_text(text: str) -> list[float]:
    """Turn a string into a normalized CLIP embedding.

    Args:
        text: Input caption; leading/trailing whitespace is stripped.

    Returns:
        The L2-normalized text embedding as a flat list of floats
        (512 values for clip-vit-base-patch32).

    Raises:
        ValueError: If ``text`` is empty after stripping.
    """
    try:
        # Clean and preprocess text
        text = text.strip()
        if not text:
            raise ValueError("Empty text input")
        # Tokenize; CLIP's text encoder accepts at most 77 tokens.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's max token length
        )
        with torch.no_grad():
            feats = model.get_text_features(**inputs)
            # Normalize to a unit vector (L2) so cosine similarity
            # reduces to a plain dot product on the server side.
            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        # squeeze() drops the batch dim for the single-text case.
        embedding = feats.squeeze().cpu().tolist()
        # Lazy %-style args: no formatting cost when INFO is disabled.
        logger.info("Generated embedding with shape: %d", len(embedding))
        return embedding
    except Exception:
        # logger.exception records the full traceback, which
        # logger.error(f"...{e}") would discard.
        logger.exception("Error in embed_text")
        raise
# 2) API configuration — base URL of the retrieval service; any trailing
# slash is stripped so joins like f"{API_BASE}/search" never yield "//".
API_BASE = os.environ.get("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")
def call_search(caption: str, k: int):
    """Embed ``caption``, POST to /search, return JSON (or error dict).

    Args:
        caption: Free-text query; must be non-empty after stripping.
        k: Number of results to request; clamped to [1, 10].

    Returns:
        The API's decoded JSON (with a ``_metadata`` key added when the
        response is a dict), or ``{"error": ...}`` on any failure —
        this function never raises, so the UI always gets JSON to show.
    """
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}
        caption = caption.strip()
        k = max(1, min(int(k), 10))  # Clamp k between 1 and 10
        logger.info("Searching for: '%s' with k=%d", caption, k)
        # 1) Embed locally
        vec = embed_text(caption)
        # Guard against a model/index dimension mismatch before hitting the API.
        if len(vec) != 512:
            return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}
        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption  # Include original text for debugging
        }
        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client"
        }
        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30  # generous timeout: the free-tier API may cold-start
        )
        response.raise_for_status()
        result = response.json()
        logger.info("API response status: %s", response.status_code)
        # Add metadata to result
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code
            }
        return result
    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError as e:
        # Use the response attached to the exception rather than the local
        # variable from the enclosing try block.
        resp = e.response
        error_msg = f"HTTP {resp.status_code}"
        try:
            error_detail = resp.json().get("detail", resp.text)
            error_msg += f": {error_detail}"
        except (ValueError, AttributeError):
            # Body is not JSON, or the JSON is not a dict (was a bare `except:`).
            error_msg += f": {resp.text}"
        return {"error": error_msg}
    except Exception as e:
        logger.exception("Unexpected error in call_search")
        return {"error": f"Unexpected error: {str(e)}"}
def validate_api_connection():
    """Probe the API's /health endpoint and describe the outcome as text."""
    try:
        resp = requests.get(f"{API_BASE}/health", timeout=10)
    except Exception as e:
        # Any failure (DNS, refused connection, timeout) is reported, not raised.
        return f"API connection failed: {str(e)}"
    else:
        return f"API is reachable (Status: {resp.status_code})"
# 3) Gradio UI — declarative layout; event wiring connects widgets to the
# handlers above. `demo` is the Blocks app launched at the bottom of the file.
with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval (small dataset)\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )
    # API status indicator — probed once at build time, refreshable on demand.
    with gr.Row():
        api_status = gr.Textbox(
            value=validate_api_connection(),
            label="API Status",
            interactive=False
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")
        refresh_btn.click(fn=validate_api_connection, outputs=api_status)
    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                lines=3,
                placeholder="type something",
                label="Caption",
                info="Enter a descriptive text to search for similar images"
            )
        with gr.Column(scale=1):
            k_input = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-K Results"
            )
    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")
    output = gr.JSON(label="Search Results")
    # Event handlers
    btn.click(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )
    # Reset all three widgets to their initial state (empty caption, k=3, no JSON).
    clear_btn.click(
        fn=lambda: ("", 3, {}),
        outputs=[caption_input, k_input, output]
    )
    # Allow Enter key to submit
    caption_input.submit(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )
if __name__ == "__main__":
    # Bind to all interfaces on 7860 — the standard Hugging Face Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True  # surface server-side tracebacks in the UI
    )