stephenebert committed on
Commit
6194281
·
verified ·
1 Parent(s): 1e33c60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -83
app.py CHANGED
@@ -1,110 +1,198 @@
1
- # app.py
2
-
3
  import os
4
  import requests
5
  import gradio as gr
6
  import torch
7
  from transformers import CLIPProcessor, CLIPModel
 
 
 
 
 
8
 
9
- # -----------------------------------------------------------------------------
10
- # 1) Load CLIP text‐encoder locally (no GPU required for small demo)
11
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
12
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
13
  model.eval()
14
 
15
-
16
  def embed_text(text: str) -> list[float]:
17
- """Turn a string into a normalized 512-dim CLIP vector."""
18
- inputs = processor(
19
- text=[text],
20
- return_tensors="pt",
21
- padding=True,
22
- truncation=True,
23
- )
24
- with torch.no_grad():
25
- feats = model.get_text_features(**inputs)
26
- # normalize to unit length for cosine‐as‐inner‐product
27
- feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
28
- return feats.squeeze().cpu().tolist()
29
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # -----------------------------------------------------------------------------
32
- # 2) Where’s your FastAPI service?
33
- # In HF Space → Settings → Variables, set:
34
- # API_URL = https://capstone-retrieval-api.onrender.com
35
  API_BASE = os.getenv("API_URL", "https://capstone-retrieval-api.onrender.com").rstrip("/")
36
 
37
-
38
  def call_search(caption: str, k: int):
39
- """Encode `caption` POST to /search parse JSON → return list of (img, caption)."""
40
- if not caption:
41
- return [], "Please enter a caption."
42
-
43
- # 2a) embed locally
44
- vec = embed_text(caption)
45
- payload = {"query_vec": vec, "k": k}
46
-
47
  try:
48
- r = requests.post(f"{API_BASE}/search", json=payload, timeout=15)
49
- r.raise_for_status()
50
- data = r.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  except Exception as e:
52
- # any network / HTTP error
53
- return [], f"Error: {e!s}"
54
-
55
- # 2b) build gallery list [ (path, label), ... ]
56
- gallery_items = []
57
- for rec in data.get("results", []):
58
- path = rec["image_path"]
59
- label = f"{rec['caption']} ({rec['score']:.3f})"
60
- gallery_items.append((path, label))
61
-
62
- return gallery_items, None
63
 
 
 
 
 
 
 
 
64
 
65
- # -----------------------------------------------------------------------------
66
  # 3) Gradio UI
67
- with gr.Blocks(title="Image ↔ Text Retrieval") as demo:
68
  gr.Markdown(
69
- """
70
- ## Image Text Retrieval
71
- Type a caption, pick *k*, click **Submit** we embed your text with CLIP,
72
- call your FastAPI + FAISS service, and show the top-K **images**.
73
- """
74
  )
75
-
 
76
  with gr.Row():
77
- caption_in = gr.Textbox(
78
- label="Caption",
79
- placeholder="e.g. painting of King Henry VIII carrying an umbrella",
80
- lines=2,
81
  )
82
- k_in = gr.Slider(
83
- label="Top-K",
84
- minimum=1, maximum=10, step=1, value=3
85
- )
86
-
87
- gallery = gr.Gallery(
88
- label="Results",
89
- show_label=False,
90
- elem_id="result_gallery",
91
- ).style(grid=[3], height="auto") # if this errors in your gradio version, just drop .style()
92
-
93
- error_box = gr.Markdown(visible=False)
94
-
95
- def _wrapped(caption, k):
96
- imgs, err = call_search(caption, k)
97
- if err:
98
- return gr.update(visible=True, value=f"**{err}**"), []
99
- return gr.update(visible=False), imgs
100
-
101
- btn = gr.Button("Submit")
 
 
 
 
 
 
 
 
102
  btn.click(
103
- fn=_wrapped,
104
- inputs=[caption_in, k_in],
105
- outputs=[error_box, gallery],
 
 
 
 
 
 
 
 
 
 
 
 
106
  )
107
 
108
  if __name__ == "__main__":
109
- # locally: python app.py → http://localhost:7860
110
- demo.launch()
 
 
 
 
 
 
# Standard library
import logging
import os

# Third-party
import requests
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel

# Set up module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 1) Load CLIP text encoder once at import time.
#    Single source of truth for the checkpoint id (used for both processor
#    and model, so they can never drift apart).
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
model.eval()  # inference only: disables dropout / training-mode layers
def embed_text(text: str) -> list[float]:
    """Turn a string into a normalized CLIP text embedding.

    Args:
        text: Arbitrary user text; surrounding whitespace is stripped.

    Returns:
        The embedding as a plain list of floats (512 values for
        clip-vit-base-patch32).

    Raises:
        ValueError: If `text` is empty after stripping.
    """
    try:
        # Clean and preprocess text
        text = text.strip()
        if not text:
            raise ValueError("Empty text input")

        # Tokenize; CLIP's text tower accepts at most 77 tokens.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77,  # CLIP's max token length
        )

        with torch.no_grad():
            # Get text features
            feats = model.get_text_features(**inputs)

        # L2-normalize so cosine similarity reduces to a plain inner product.
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)

        # Convert to a flat Python list for JSON serialization
        embedding = feats.squeeze().cpu().tolist()

        # Lazy %-style args: no string formatting when INFO is disabled.
        logger.info("Generated embedding with %d dimensions", len(embedding))
        return embedding

    except Exception:
        # logger.exception records the full traceback (logger.error(str(e))
        # would discard it), then re-raises for the caller to handle.
        logger.exception("Error in embed_text")
        raise
# 2) API configuration
# Base URL of the FastAPI retrieval service. Override via the API_URL
# environment variable (HF Space -> Settings -> Variables); trailing
# slashes are stripped so path joins below stay clean.
_DEFAULT_API_URL = "https://capstone-retrieval-api.onrender.com"
API_BASE = os.getenv("API_URL", _DEFAULT_API_URL).rstrip("/")
def call_search(caption: str, k: int):
    """Embed `caption` locally, POST it to /search, and return the JSON result.

    Args:
        caption: Free-text query; must be non-empty after stripping.
        k: Number of results to request; clamped to the range [1, 10].

    Returns:
        The decoded JSON response (augmented with a "_metadata" entry when it
        is a dict), or {"error": <message>} describing what went wrong.
    """
    response = None  # defined before try so the HTTPError handler can use it
    try:
        # Input validation
        if not caption or not caption.strip():
            return {"error": "Please enter a caption to search."}

        caption = caption.strip()
        k = max(1, min(int(k), 10))  # Clamp k between 1 and 10

        logger.info("Searching for: '%s' with k=%d", caption, k)

        # 1) Embed locally
        vec = embed_text(caption)

        # Guard against a model/service dimension mismatch before calling out.
        if len(vec) != 512:
            return {"error": f"Unexpected embedding dimension: {len(vec)} (expected 512)"}

        payload = {
            "query_vec": vec,
            "k": k,
            "query_text": caption,  # Include original text for debugging
        }

        # 2) POST to API
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "HuggingFace-Gradio-Client",
        }

        response = requests.post(
            f"{API_BASE}/search",
            json=payload,
            headers=headers,
            timeout=30,  # generous: Render free tier has slow cold starts
        )
        response.raise_for_status()
        result = response.json()

        logger.info("API response status: %d", response.status_code)

        # Attach query metadata so the UI shows what was actually searched.
        if isinstance(result, dict):
            result["_metadata"] = {
                "query": caption,
                "k": k,
                "embedding_dim": len(vec),
                "api_status": response.status_code,
            }

        return result

    except requests.exceptions.Timeout:
        return {"error": "Request timed out. Please try again."}
    except requests.exceptions.ConnectionError:
        return {"error": "Could not connect to the API. Please check your internet connection."}
    except requests.exceptions.HTTPError:
        # raise_for_status() is the only raiser of HTTPError, so `response`
        # is guaranteed to be bound here.
        error_msg = f"HTTP {response.status_code}"
        try:
            error_detail = response.json().get("detail", response.text)
            error_msg += f": {error_detail}"
        except Exception:  # body is not JSON / not a dict — fall back to raw text
            error_msg += f": {response.text}"
        return {"error": error_msg}
    except Exception as e:
        logger.error("Unexpected error in call_search: %s", e)
        return {"error": f"Unexpected error: {str(e)}"}
 
 
 
 
 
 
 
 
 
123
 
def validate_api_connection():
    """Ping the retrieval API's /health endpoint and describe the outcome."""
    try:
        resp = requests.get(f"{API_BASE}/health", timeout=10)
    except Exception as e:
        return f"API connection failed: {str(e)}"
    return f"API is reachable (Status: {resp.status_code})"
# 3) Gradio UI
with gr.Blocks(title="Image ↔ Text Retrieval (small dataset)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### Image ↔ Text Retrieval (small dataset)\n"
        "Type a caption, pick *k*, click **Submit** – we encode your text with CLIP, "
        "POST it to your FastAPI+FAISS service, and show the top-K JSON results."
    )

    # API status indicator (probed once at build time; refreshable on demand)
    with gr.Row():
        api_status = gr.Textbox(
            value=validate_api_connection(),
            label="API Status",
            interactive=False,
        )
        refresh_btn = gr.Button("Refresh Status", size="sm")
        refresh_btn.click(fn=validate_api_connection, outputs=api_status)

    # Query inputs: caption gets two thirds of the width, slider one third.
    with gr.Row():
        with gr.Column(scale=2):
            caption_input = gr.Textbox(
                lines=3,
                placeholder="type something",
                label="Caption",
                info="Enter a descriptive text to search for similar images",
            )
        with gr.Column(scale=1):
            k_input = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Top-K Results",
            )

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    output = gr.JSON(label="Search Results")

    # Both the Submit button and Enter in the textbox trigger the same search.
    for _trigger in (btn.click, caption_input.submit):
        _trigger(
            fn=call_search,
            inputs=[caption_input, k_input],
            outputs=output,
        )

    # Reset all three widgets to their initial state.
    clear_btn.click(
        fn=lambda: ("", 3, {}),
        outputs=[caption_input, k_input, output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)