Jaywalker061707 commited on
Commit
288963f
·
verified ·
1 Parent(s): b1440f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import torch
7
  from transformers import CLIPModel, CLIPProcessor
8
  import torch.nn.functional as F
 
9
 
10
  # ---------- utils ----------
11
  def flux_to_gray(flux_array):
@@ -89,6 +90,65 @@ def search(text_query, image_query, k=5):
89
  return items, f"Returned {k} results."
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # ---------- UI ----------
93
  with gr.Blocks() as demo:
94
  gr.Markdown("JWST multimodal search — build the index")
@@ -106,6 +166,64 @@ with gr.Blocks() as demo:
106
  k = gr.Slider(1, 12, value=6, step=1, label="Top K")
107
 
108
  search_btn = gr.Button("Search")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  gallery = gr.Gallery(label="Results", columns=6, height=300)
110
  info2 = gr.Textbox(label="Search status", lines=1)
111
 
 
6
  import torch
7
  from transformers import CLIPModel, CLIPProcessor
8
  import torch.nn.functional as F
9
+ import os, json, time
10
 
11
  # ---------- utils ----------
12
  def flux_to_gray(flux_array):
 
90
  return items, f"Returned {k} results."
91
 
92
 
93
+ # ---------- evaluation helpers ----------
94
+ def _search_topk_for_eval(text_query, image_query, k=5):
95
+ if INDEX["feats"] is None:
96
+ return [], [], "Build the index first."
97
+ with torch.no_grad():
98
+ if text_query and str(text_query).strip():
99
+ inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
100
+ q = model.get_text_features(**inputs)
101
+ elif image_query is not None:
102
+ pil = image_query.convert("RGB")
103
+ inputs = processor(images=pil, return_tensors="pt")
104
+ q = model.get_image_features(**inputs)
105
+ else:
106
+ return [], [], "Enter text or upload an image."
107
+ q = F.normalize(q, p=2, dim=-1)[0]
108
+ sims = (INDEX["feats"] @ q).cpu()
109
+ k = min(int(k), sims.shape[0])
110
+ topk = torch.topk(sims, k=k)
111
+ idxs = topk.indices.tolist()
112
+ # reuse thumbs and captions like your main search
113
+ items = []
114
+ for idx in idxs:
115
+ cap = f"id: {INDEX['ids'][idx]} score: {float(sims[idx]):.3f} band: {INDEX['bands'][idx]}"
116
+ items.append((INDEX["thumbs"][idx], cap))
117
+ return items, idxs, f"Eval preview: top {k} ready."
118
+
119
+ def _format_eval_summary(query, k, hits, p_at_k):
120
+ lines = []
121
+ lines.append(f"Query: {query or '[image query]'}")
122
+ lines.append(f"K: {k}")
123
+ lines.append(f"Relevant marked: {hits} of {k}")
124
+ lines.append(f"Precision@{k}: {p_at_k:.2f}")
125
+ lines.append("Saved to eval_runs.jsonl")
126
+ return "\n".join(lines)
127
+
128
+ def _save_eval_run(record):
129
+ try:
130
+ with open("eval_runs.jsonl", "a", encoding="utf-8") as f:
131
+ f.write(json.dumps(record) + "\n")
132
+ except Exception:
133
+ pass
134
+
135
+ def _compute_avg_from_file():
136
+ try:
137
+ total = 0.0
138
+ n = 0
139
+ with open("eval_runs.jsonl", "r", encoding="utf-8") as f:
140
+ for line in f:
141
+ rec = json.loads(line)
142
+ if "precision_at_k" in rec:
143
+ total += float(rec["precision_at_k"])
144
+ n += 1
145
+ if n == 0:
146
+ return "No runs recorded yet."
147
+ return f"Macro average Precision@K across {n} runs: {total/n:.2f}"
148
+ except FileNotFoundError:
149
+ return "No eval_runs.jsonl yet. Run at least one evaluation."
150
+
151
+
152
  # ---------- UI ----------
153
  with gr.Blocks() as demo:
154
  gr.Markdown("JWST multimodal search — build the index")
 
166
  k = gr.Slider(1, 12, value=6, step=1, label="Top K")
167
 
168
  search_btn = gr.Button("Search")
169
+
170
+ # ---------- evaluation UI ----------
171
+ with gr.Accordion("Evaluation", open=False):
172
+ eval_query = gr.Textbox(label="Evaluation query", placeholder="Type a query or leave empty and upload an image")
173
+ eval_img = gr.Image(label="Evaluation image (optional)", type="pil")
174
+ eval_k = gr.Slider(1, 12, value=6, step=1, label="K for evaluation")
175
+
176
+ run_and_label = gr.Button("Run and label this query")
177
+
178
+ eval_gallery = gr.Gallery(label="Eval top K results", columns=6, height=300)
179
+ relevant_picker = gr.CheckboxGroup(label="Select indices of relevant results (1..K)")
180
+ eval_md = gr.Markdown()
181
+
182
+ eval_state = gr.State({"result_indices": [], "k": 5, "query": ""})
183
+
184
+ def _run_eval_query(q_txt, q_img_in, k_in, state):
185
+ items, idxs, _ = _search_topk_for_eval(q_txt, q_img_in, k_in)
186
+ state["result_indices"] = idxs
187
+ state["k"] = int(k_in)
188
+ state["query"] = q_txt if (q_txt and q_txt.strip()) else "[image query]"
189
+ choice_labels = [str(i+1) for i in range(len(idxs))]
190
+ return items, gr.update(choices=choice_labels, value=[]), "Mark relevant then click Compute metrics.", state
191
+
192
+ run_and_label.click(
193
+ fn=_run_eval_query,
194
+ inputs=[eval_query, eval_img, eval_k, eval_state],
195
+ outputs=[eval_gallery, relevant_picker, eval_md, eval_state]
196
+ )
197
+
198
+ compute_btn = gr.Button("Compute metrics")
199
+
200
+ def _compute_pk(selected_indices, state):
201
+ k = int(state.get("k", 5))
202
+ query = state.get("query", "")
203
+ # user marks which of the K are relevant; count is the hits
204
+ hits = len(selected_indices)
205
+ p_at_k = hits / max(k, 1)
206
+ record = {
207
+ "ts": int(time.time()),
208
+ "query": query,
209
+ "k": k,
210
+ "relevant_indices": sorted([int(s) for s in selected_indices]),
211
+ "precision_at_k": p_at_k
212
+ }
213
+ _save_eval_run(record)
214
+ return _format_eval_summary(query, k, hits, p_at_k)
215
+
216
+ compute_btn.click(
217
+ fn=_compute_pk,
218
+ inputs=[relevant_picker, eval_state],
219
+ outputs=eval_md
220
+ )
221
+
222
+ avg_btn = gr.Button("Compute average across saved runs")
223
+ avg_md = gr.Markdown()
224
+
225
+ avg_btn.click(fn=_compute_avg_from_file, outputs=avg_md)
226
+
227
  gallery = gr.Gallery(label="Results", columns=6, height=300)
228
  info2 = gr.Textbox(label="Search status", lines=1)
229