stephenebert committed on
Commit
1d850a0
·
verified ·
1 Parent(s): fc4cd56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -36
app.py CHANGED
@@ -3,62 +3,196 @@ import requests
3
  import gradio as gr
4
  import torch
5
  from transformers import CLIPProcessor, CLIPModel
 
 
 
 
 
6
 
7
# 1) Load CLIP text encoder (weights are fetched once at import time).
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()  # inference only — disables dropout / train-mode layers


def embed_text(text: str) -> list[float]:
    """Encode *text* with CLIP and return it as a normalized 512-dim list."""
    tokens = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        embedding = model.get_text_features(**tokens)
    # Scale to unit length so dot products equal cosine similarity.
    embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True)
    return embedding.squeeze().cpu().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
# 2) Where’s your FastAPI service?
# In your HF Space settings → Variables, set:
# API_URL = https://capstone-retrieval-api.onrender.com
API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")


def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, return JSON (or error dict)."""
    if not caption:
        return {"error": "Please enter a caption to search."}
    # Encode the caption locally, then ship the vector to the service.
    payload = {"query_vec": embed_text(caption), "k": k}
    try:
        r = requests.post(f"{API_BASE}/search", json=payload, timeout=15)
        r.raise_for_status()
        return r.json()
    except requests.exceptions.HTTPError:
        # `r` is always bound here: post() returned, raise_for_status() failed.
        return {"error": f"HTTP {r.status_code}: {r.text}"}
    except Exception as e:
        # Timeouts, connection errors, bad JSON — surface as an error dict.
        return {"error": str(e)}
41
 
42
# 3) Gradio UI: caption + k in, raw JSON results out.
with gr.Blocks(title="Image ↔ Text Retrieval") as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval \n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )

    with gr.Row():
        caption_input = gr.Textbox(
            label="Caption",
            lines=2,
            placeholder="e.g. painting of King Henry VIII carrying an umbrella",
        )
        k_input = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-K")

    output = gr.JSON(label="Results")
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=call_search, inputs=[caption_input, k_input], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
# Entry point: start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
3
  import gradio as gr
4
  import torch
5
  from transformers import CLIPProcessor, CLIPModel
6
+ import logging
7
+
8
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 1) Load CLIP text encoder (weights are downloaded once at import time).
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()  # inference only

def embed_text(text: str) -> list[float]:
    """Turn a string into a normalized (unit-length) CLIP text embedding.

    Args:
        text: Free-form caption; leading/trailing whitespace is stripped.

    Returns:
        The embedding as a plain list of floats (512-dim for this checkpoint).

    Raises:
        ValueError: If `text` is empty after stripping.
    """
    try:
        # Clean and preprocess text
        text = text.strip()
        if not text:
            raise ValueError("Empty text input")

        # Tokenize; CLIP accepts at most 77 tokens per sequence.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,
        )

        with torch.no_grad():
            feats = model.get_text_features(**inputs)

        # L2-normalize so dot products equal cosine similarity.
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)

        # Convert to a plain Python list for JSON serialization.
        embedding = feats.squeeze().cpu().tolist()

        # Lazy %-style args: no string formatting when INFO is disabled.
        logger.info("Generated embedding with dim: %d", len(embedding))
        return embedding

    except Exception:
        # logger.exception records the full traceback, which
        # logger.error(f"...{str(e)}") would silently drop.
        logger.exception("Error in embed_text")
        raise
 
51
# 2) API configuration — override via the API_URL env var in Space settings.
API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")

def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, return JSON (or error dict).

    Args:
        caption: Query text; blank input yields an error dict, no request made.
        k: Requested number of results; clamped to the 1..10 slider range.

    Returns:
        The service's JSON payload (a "_metadata" key is added when it is a
        dict), or an {"error": ...} dict describing what went wrong.
    """
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}

        caption = caption.strip()
        k = max(1, min(int(k), 10))  # Clamp k between 1 and 10

        logger.info("Searching for: %r with k=%d", caption, k)

        # 1) Embed locally
        vec = embed_text(caption)

        # Verify the dimension matches what the FAISS index expects.
        if len(vec) != 512:
            return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}

        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption,  # Include original text for debugging
        }

        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client",
        }
        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30,  # Increased timeout
        )
        response.raise_for_status()
        result = response.json()

        logger.info("API response status: %d", response.status_code)

        # Attach query metadata so the UI shows what was actually sent.
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code,
            }

        return result

    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError:
        # `response` is always bound here: post() returned a response,
        # raise_for_status() rejected it.
        error_msg = f"HTTP {response.status_code}"
        try:
            error_detail = response.json().get("detail", response.text)
            error_msg += f": {error_detail}"
        except (ValueError, AttributeError):
            # Body is not JSON (ValueError) or not a dict (AttributeError).
            # A bare `except:` here would also trap KeyboardInterrupt.
            error_msg += f": {response.text}"
        return {"error": error_msg}
    except Exception as e:
        # logger.exception keeps the traceback; error(f"...") would not.
        logger.exception("Unexpected error in call_search")
        return {"error": f"Unexpected error: {str(e)}"}
123
+
124
def validate_api_connection():
    """Ping the API's /health endpoint and report reachability as a string."""
    try:
        resp = requests.get(f"{API_BASE}/health", timeout=10)
    except Exception as exc:
        # Best-effort status line — any failure is reported, never raised.
        return f"API connection failed: {exc}"
    return f"API is reachable (Status: {resp.status_code})"
131
 
132
# 3) Gradio UI — layout and event wiring.
with gr.Blocks(title="Image ↔ Text Retrieval", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )

    # Row 1: live API reachability indicator with a manual refresh button.
    with gr.Row():
        api_status = gr.Textbox(
            label="API Status",
            value=validate_api_connection(),
            interactive=False,
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")
        refresh_btn.click(fn=validate_api_connection, outputs=api_status)

    # Row 2: query inputs (caption column is twice the slider's width).
    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                label="Caption",
                lines=3,
                placeholder="e.g. painting of King Henry VIII carrying an umbrella",
                info="Enter a descriptive text to search for similar images",
            )

        with gr.Column(scale=1):
            k_input = gr.Slider(
                label="Top-K Results",
                minimum=1,
                maximum=10,
                value=3,
                step=1,
            )

    # Row 3: action buttons.
    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    output = gr.JSON(label="Search Results")

    # Wiring: Submit button and Enter-in-textbox both run the search;
    # Clear resets all three widgets to their defaults.
    search_kwargs = dict(fn=call_search, inputs=[caption_input, k_input], outputs=output)
    btn.click(**search_kwargs)
    clear_btn.click(fn=lambda: ("", 3, {}), outputs=[caption_input, k_input, output])
    caption_input.submit(**search_kwargs)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (needed in HF Spaces / containers)
        server_port=7860,
        show_error=True,
    )