# retrieval-demo / app.py
# Hugging Face Space file (uploaded by stephenebert, commit bdb7dad "Update app.py", verified).
import os
import requests
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 1) Load CLIP text encoder
# Loaded once at import time (module-level side effect: downloads/reads the
# checkpoint). ViT-B/32 text embeddings are 512-dim — call_search() relies on this.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()  # inference mode: disables dropout/batch-norm training behavior
def embed_text(text: str) -> list[float]:
    """Encode a string into a unit-length CLIP text embedding.

    Args:
        text: Free-form caption; surrounding whitespace is ignored.

    Returns:
        The L2-normalized embedding as a flat list of floats.

    Raises:
        ValueError: if the text is empty after stripping.
    """
    try:
        cleaned = text.strip()
        if not cleaned:
            raise ValueError("Empty text input")
        # Tokenize; CLIP accepts at most 77 tokens, so truncate longer inputs.
        inputs = processor(
            text=[cleaned],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's max token length
        )
        with torch.no_grad():
            features = model.get_text_features(**inputs)
            # L2-normalize so downstream cosine similarity is a plain dot product.
            features = features / features.norm(p=2, dim=-1, keepdim=True)
        embedding = features.squeeze().cpu().tolist()
        logger.info(f"Generated embedding with shape: {len(embedding)}")
        return embedding
    except Exception as e:
        logger.error(f"Error in embed_text: {str(e)}")
        raise
# 2) API configuration
# Base URL of the FastAPI+FAISS retrieval service; override via the API_URL
# env var. Trailing slash stripped so f"{API_BASE}/search" can't double-slash.
API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")
def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, return JSON (or error dict).

    Args:
        caption: Free-text query; must be non-empty after stripping.
        k: Requested number of results; clamped to the range [1, 10].

    Returns:
        The API's JSON payload — with a "_metadata" key merged in when the
        payload is a dict — or an {"error": ...} dict describing the failure.
        Never raises: every failure mode is converted to an error dict so the
        Gradio JSON widget always has something to render.
    """
    response = None  # kept in function scope for the HTTPError fallback below
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}
        caption = caption.strip()
        k = max(1, min(int(k), 10))  # Clamp k between 1 and 10
        logger.info(f"Searching for: '{caption}' with k={k}")
        # 1) Embed locally
        vec = embed_text(caption)
        # Verify embedding dimensions (ViT-B/32 text encoder emits 512 floats)
        if len(vec) != 512:
            return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}
        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption  # Include original text for debugging
        }
        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client"
        }
        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30  # generous timeout: the Render-hosted API may cold-start
        )
        response.raise_for_status()
        result = response.json()
        logger.info(f"API response status: {response.status_code}")
        # Add metadata to result (only when the API returned a JSON object)
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code
            }
        return result
    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError as e:
        # Prefer the response attached to the exception; fall back to the local var.
        resp = e.response if e.response is not None else response
        error_msg = f"HTTP {resp.status_code}"
        try:
            error_detail = resp.json().get("detail", resp.text)
            error_msg += f": {error_detail}"
        except (ValueError, AttributeError):
            # ValueError: body isn't JSON; AttributeError: JSON body isn't a dict.
            # (Was a bare `except:`, which would also swallow SystemExit et al.)
            error_msg += f": {resp.text}"
        return {"error": error_msg}
    except Exception as e:
        logger.error(f"Unexpected error in call_search: {str(e)}")
        return {"error": f"Unexpected error: {str(e)}"}
def validate_api_connection():
    """Probe the retrieval API's /health endpoint and report reachability.

    Returns a human-readable status string; never raises (any failure is
    folded into the returned message for display in the status textbox).
    """
    try:
        resp = requests.get(f"{API_BASE}/health", timeout=10)
    except Exception as e:
        return f"API connection failed: {str(e)}"
    return f"API is reachable (Status: {resp.status_code})"
# 3) Gradio UI — layout and event wiring only; the work happens in call_search().
with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval (small dataset)\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )
    # API status indicator
    with gr.Row():
        api_status = gr.Textbox(
            value=validate_api_connection(),  # probed once at app startup
            label="API Status",
            interactive=False
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")
        # Re-probe on demand; the returned string replaces the textbox value.
        refresh_btn.click(fn=validate_api_connection, outputs=api_status)
    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                lines=3,
                placeholder="type something",
                label="Caption",
                info="Enter a descriptive text to search for similar images"
            )
        with gr.Column(scale=1):
            # Bounds mirror the server-side clamp in call_search (1..10).
            k_input = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-K Results"
            )
    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")
    output = gr.JSON(label="Search Results")
    # Event handlers
    btn.click(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )
    # Reset caption, slider, and results to their initial values.
    clear_btn.click(
        fn=lambda: ("", 3, {}),
        outputs=[caption_input, k_input, output]
    )
    # Allow Enter key to submit
    caption_input.submit(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (the conventional port for hosted
    # Gradio apps); show_error surfaces handler tracebacks in the browser UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )