import os
import requests
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import logging

# Set up logging
# Module-level logger named after this module, per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 1) Load CLIP text encoder
# NOTE: from_pretrained downloads weights on first run (network side effect
# at import time). Only the text tower is used below (get_text_features).
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# Inference only: disable dropout/batch-norm training behavior.
model.eval()
def embed_text(text: str) -> list[float]:
    """Encode *text* into a single L2-normalized CLIP text embedding.

    Args:
        text: Input caption; leading/trailing whitespace is stripped.

    Returns:
        A flat list of floats (512 for ViT-B/32) forming a unit vector.

    Raises:
        ValueError: If the stripped text is empty.
    """
    try:
        # Clean and preprocess text
        text = text.strip()
        if not text:
            raise ValueError("Empty text input")

        # Tokenize with proper handling; CLIP accepts at most 77 tokens.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's max token length
        )

        with torch.no_grad():
            # Get text features from the text tower only.
            feats = model.get_text_features(**inputs)
            # L2-normalize so cosine similarity reduces to a dot product
            # on the FAISS side.
            feats = feats / feats.norm(p=2, dim=-1, keepdim=True)

        # squeeze() drops the batch dim (batch size is always 1 here).
        embedding = feats.squeeze().cpu().tolist()
        # Lazy %-style args: no string formatting unless INFO is enabled.
        logger.info("Generated embedding with shape: %d", len(embedding))
        return embedding
    except Exception:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Error in embed_text")
        raise
# 2) API configuration
# Base URL of the FastAPI retrieval service; override via the API_URL env var.
# rstrip("/") keeps f"{API_BASE}/search" from producing a double slash.
API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")
def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, return JSON (or error dict).

    Args:
        caption: Free-text query; blank input yields an error dict.
        k: Requested number of results; clamped to the range [1, 10].

    Returns:
        The API's JSON payload (with a `_metadata` key added when it is a
        dict), or a ``{"error": ...}`` dict describing what went wrong.
        Never raises: all failures are converted to error dicts for the UI.
    """
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}

        caption = caption.strip()
        k = max(1, min(int(k), 10))  # Clamp k between 1 and 10
        # Lazy %-args: formatted only if INFO logging is enabled.
        logger.info("Searching for: '%s' with k=%d", caption, k)

        # 1) Embed locally
        vec = embed_text(caption)
        # Guard against a model mismatch before shipping a bad query vector.
        if len(vec) != 512:
            return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}

        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption  # Include original text for debugging
        }

        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client"
        }
        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30  # Increased timeout
        )
        response.raise_for_status()

        result = response.json()
        logger.info("API response status: %s", response.status_code)

        # Add metadata to result (only when the API returned a JSON object).
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code
            }
        return result

    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError as e:
        # Use the response attached to the exception rather than relying on
        # the enclosing local still being in scope.
        resp = e.response
        error_msg = f"HTTP {resp.status_code}"
        try:
            error_detail = resp.json().get("detail", resp.text)
            error_msg += f": {error_detail}"
        except (ValueError, AttributeError):
            # Body is not JSON (ValueError) or not an object (AttributeError);
            # a bare `except:` here would also hide genuine bugs.
            error_msg += f": {resp.text}"
        return {"error": error_msg}
    except Exception as e:
        # logger.exception preserves the traceback for debugging.
        logger.exception("Unexpected error in call_search")
        return {"error": f"Unexpected error: {str(e)}"}
def validate_api_connection():
    """Probe the API's /health endpoint and describe reachability as text."""
    try:
        health = requests.get(f"{API_BASE}/health", timeout=10)
    except Exception as exc:
        # Any failure (DNS, refused connection, timeout) is reported inline.
        return f"API connection failed: {str(exc)}"
    return f"API is reachable (Status: {health.status_code})"
# 3) Gradio UI
# Layout: status row -> input row (caption + k slider) -> action row -> JSON
# output. Construction order matters: components must exist before the event
# handlers below reference them.
with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval (small dataset)\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )

    # API status indicator — probed once at startup, refreshable on demand.
    with gr.Row():
        api_status = gr.Textbox(
            value=validate_api_connection(),
            label="API Status",
            interactive=False
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")

    refresh_btn.click(fn=validate_api_connection, outputs=api_status)

    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                lines=3,
                placeholder="type something",
                label="Caption",
                info="Enter a descriptive text to search for similar images"
            )
        with gr.Column(scale=1):
            # k is clamped again server-side in call_search; the slider range
            # mirrors that [1, 10] contract.
            k_input = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-K Results"
            )

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    output = gr.JSON(label="Search Results")

    # Event handlers
    btn.click(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )

    # Reset caption, slider (back to default 3), and results in one click.
    clear_btn.click(
        fn=lambda: ("", 3, {}),
        outputs=[caption_input, k_input, output]
    )

    # Allow Enter key to submit — same handler as the Submit button.
    caption_input.submit(
        fn=call_search,
        inputs=[caption_input, k_input],
        outputs=output
    )
if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces (required for containerized Spaces);
    # 7860 is Gradio's conventional port. show_error surfaces handler
    # exceptions in the browser instead of a generic failure message.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )