stephenebert committed on
Commit
1e33c60
·
verified ·
1 Parent(s): 30bc636

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -171
app.py CHANGED
@@ -1,198 +1,110 @@
 
 
1
  import os
2
  import requests
3
  import gradio as gr
4
  import torch
5
  from transformers import CLIPProcessor, CLIPModel
6
- import logging
7
-
8
- # Set up logging
9
- logging.basicConfig(level=logging.INFO)
10
- logger = logging.getLogger(__name__)
11
 
12
- # 1) Load CLIP text encoder
 
13
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
14
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
15
  model.eval()
16
 
 
17
  def embed_text(text: str) -> list[float]:
18
- """Turn a string into a normalized CLIP embedding."""
19
- try:
20
- # Clean and preprocess text
21
- text = text.strip()
22
- if not text:
23
- raise ValueError("Empty text input")
24
-
25
- # Tokenize with proper handling
26
- inputs = processor(
27
- text=[text],
28
- return_tensors="pt",
29
- padding=True,
30
- truncation=True,
31
- max_length=77 # CLIP's max token length
32
- )
33
-
34
- with torch.no_grad():
35
- # Get text features
36
- feats = model.get_text_features(**inputs)
37
-
38
- # Normalize to unit vector (L2 normalization)
39
- feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
40
-
41
- # Convert to list and ensure proper shape
42
- embedding = feats.squeeze().cpu().tolist()
43
-
44
- logger.info(f"Generated embedding with shape: {len(embedding)}")
45
- return embedding
46
-
47
- except Exception as e:
48
- logger.error(f"Error in embed_text: {str(e)}")
49
- raise
50
 
51
- # 2) API configuration
 
 
 
 
52
  API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")
53
 
 
54
  def call_search(caption: str, k: int):
55
- """Embed `caption`, POST to /search, return JSON (or error dict)."""
56
- try:
57
- # Input validation
58
- if not caption or not caption.strip():
59
- return {"error": "Please enter a caption to search."}
60
-
61
- caption = caption.strip()
62
- k = max(1, min(int(k), 10)) # Clamp k between 1 and 10
63
-
64
- logger.info(f"Searching for: '{caption}' with k={k}")
65
-
66
- # 1) Embed locally
67
- vec = embed_text(caption)
68
-
69
- # Verify embedding dimensions
70
- if len(vec) != 512:
71
- return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}
72
-
73
- payload = {
74
- "query_vec": vec,
75
- "k": k,
76
- "query_text": caption # Include original text for debugging
77
- }
78
-
79
- # 2) POST to API
80
- headers = {
81
- "Content-Type": "application/json",
82
- "User-Agent": "HuggingFace-Gradio-Client"
83
- }
84
-
85
- response = requests.post(
86
- f"{API_BASE}/search",
87
- json=payload,
88
- headers=headers,
89
- timeout=30 # Increased timeout
90
- )
91
-
92
- response.raise_for_status()
93
- result = response.json()
94
-
95
- logger.info(f"API response status: {response.status_code}")
96
-
97
- # Add metadata to result
98
- if isinstance(result, dict):
99
- result["_metadata"] = {
100
- "query": caption,
101
- "k": k,
102
- "embedding_dim": len(vec),
103
- "api_status": response.status_code
104
- }
105
-
106
- return result
107
-
108
- except requests.exceptions.Timeout:
109
- return {"error": "Request timed out. Please try again."}
110
- except requests.exceptions.ConnectionError:
111
- return {"error": "Could not connect to the API. Please check your internet connection."}
112
- except requests.exceptions.HTTPError as e:
113
- error_msg = f"HTTP {response.status_code}"
114
- try:
115
- error_detail = response.json().get("detail", response.text)
116
- error_msg += f": {error_detail}"
117
- except:
118
- error_msg += f": {response.text}"
119
- return {"error": error_msg}
120
- except Exception as e:
121
- logger.error(f"Unexpected error in call_search: {str(e)}")
122
- return {"error": f"Unexpected error: {str(e)}"}
123
 
124
- def validate_api_connection():
125
- """Test API connection and return status."""
126
  try:
127
- response = requests.get(f"{API_BASE}/health", timeout=10)
128
- return f"API is reachable (Status: {response.status_code})"
 
129
  except Exception as e:
130
- return f"API connection failed: {str(e)}"
 
131
 
 
 
 
 
 
 
 
 
 
 
 
132
  # 3) Gradio UI
133
- with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
134
  gr.Markdown(
135
- "### Image ↔ Text Retrieval (small dataset)\n"
136
- "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
137
- "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
 
 
138
  )
139
-
140
- # API status indicator
141
  with gr.Row():
142
- api_status = gr.Textbox(
143
- value=validate_api_connection(),
144
- label="API Status",
145
- interactive=False
146
  )
147
- refresh_btn = gr.Button("Refresh Status", size="sm")
148
- refresh_btn.click(fn=validate_api_connection, outputs=api_status)
149
-
150
- with gr.Row():
151
- with gr.Column(scale=2):
152
- caption_input = gr.Textbox(
153
- lines=3,
154
- placeholder="type something",
155
- label="Caption",
156
- info="Enter a descriptive text to search for similar images"
157
- )
158
-
159
- with gr.Column(scale=1):
160
- k_input = gr.Slider(
161
- minimum=1,
162
- maximum=10,
163
- value=3,
164
- step=1,
165
- label="Top-K Results"
166
- )
167
-
168
- with gr.Row():
169
- btn = gr.Button("Submit", variant="primary")
170
- clear_btn = gr.Button("Clear", variant="secondary")
171
-
172
- output = gr.JSON(label="Search Results")
173
-
174
- # Event handlers
175
  btn.click(
176
- fn=call_search,
177
- inputs=[caption_input, k_input],
178
- outputs=output
179
- )
180
-
181
- clear_btn.click(
182
- fn=lambda: ("", 3, {}),
183
- outputs=[caption_input, k_input, output]
184
- )
185
-
186
- # Allow Enter key to submit
187
- caption_input.submit(
188
- fn=call_search,
189
- inputs=[caption_input, k_input],
190
- outputs=output
191
  )
192
 
193
  if __name__ == "__main__":
194
- demo.launch(
195
- server_name="0.0.0.0",
196
- server_port=7860,
197
- show_error=True
198
- )
 
1
+ # app.py
2
+
3
  import os
4
  import requests
5
  import gradio as gr
6
  import torch
7
  from transformers import CLIPProcessor, CLIPModel
 
 
 
 
 
8
 
9
+ # -----------------------------------------------------------------------------
10
+ # 1) Load CLIP text‐encoder locally (no GPU required for small demo)
11
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
12
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
13
  model.eval()
14
 
15
+
16
def embed_text(text: str) -> list[float]:
    """Encode *text* into a unit-length 512-dim CLIP text embedding.

    The vector is L2-normalized so that inner products against other
    normalized embeddings behave as cosine similarity.
    """
    tokens = processor(
        text=[text],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        features = model.get_text_features(**tokens)
    # Scale to unit length: cosine-as-inner-product for the FAISS index.
    unit = features / features.norm(p=2, dim=-1, keepdim=True)
    return unit.squeeze().cpu().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+
31
+ # -----------------------------------------------------------------------------
32
+ # 2) Where’s your FastAPI service?
33
+ # In HF Space → Settings → Variables, set:
34
+ # API_URL = https://capstone-retrieval-api.onrender.com
35
  API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")
36
 
37
+
38
def call_search(caption: str, k: int):
    """Embed `caption`, POST to /search, parse JSON, return gallery items.

    Returns a tuple ``(gallery_items, error)``:
    ``gallery_items`` is a list of ``(image_path, label)`` pairs for
    ``gr.Gallery``; ``error`` is ``None`` on success or a user-facing
    message string on failure.
    """
    # Guard against empty or whitespace-only queries before doing any work.
    if not caption or not caption.strip():
        return [], "Please enter a caption."
    caption = caption.strip()

    # Gradio sliders deliver floats; the API expects an int, clamped to 1..10.
    k = max(1, min(int(k), 10))

    # 2a) embed locally
    vec = embed_text(caption)
    payload = {"query_vec": vec, "k": k}

    try:
        r = requests.post(f"{API_BASE}/search", json=payload, timeout=15)
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        # Any network / HTTP / JSON-decoding error surfaces as a message.
        return [], f"Error: {e!s}"

    # 2b) build gallery list [ (path, label), ... ]
    gallery_items = []
    for rec in data.get("results", []):
        path = rec["image_path"]
        label = f"{rec['caption']} ({rec['score']:.3f})"
        gallery_items.append((path, label))

    return gallery_items, None
63
+
64
+
65
+ # -----------------------------------------------------------------------------
66
  # 3) Gradio UI
67
+ with gr.Blocks(title="Image ↔ Text Retrieval") as demo:
68
  gr.Markdown(
69
+ """
70
+ ## Image Text Retrieval
71
+ Type a caption, pick *k*, click **Submit** we embed your text with CLIP,
72
+ call your FastAPI + FAISS service, and show the top-K **images**.
73
+ """
74
  )
75
+
 
76
  with gr.Row():
77
+ caption_in = gr.Textbox(
78
+ label="Caption",
79
+ placeholder="e.g. painting of King Henry VIII carrying an umbrella",
80
+ lines=2,
81
  )
82
+ k_in = gr.Slider(
83
+ label="Top-K",
84
+ minimum=1, maximum=10, step=1, value=3
85
+ )
86
+
87
+ gallery = gr.Gallery(
88
+ label="Results",
89
+ show_label=False,
90
+ elem_id="result_gallery",
91
+ ).style(grid=[3], height="auto") # if this errors in your gradio version, just drop .style()
92
+
93
+ error_box = gr.Markdown(visible=False)
94
+
95
+ def _wrapped(caption, k):
96
+ imgs, err = call_search(caption, k)
97
+ if err:
98
+ return gr.update(visible=True, value=f"**{err}**"), []
99
+ return gr.update(visible=False), imgs
100
+
101
+ btn = gr.Button("Submit")
 
 
 
 
 
 
 
 
102
  btn.click(
103
+ fn=_wrapped,
104
+ inputs=[caption_in, k_in],
105
+ outputs=[error_box, gallery],
 
 
 
 
 
 
 
 
 
 
 
 
106
  )
107
 
108
  if __name__ == "__main__":
109
+ # locally: python app.py → http://localhost:7860
110
+ demo.launch()