SalHargis commited on
Commit
18d9fa9
·
verified ·
1 Parent(s): 75bb77a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -44
app.py CHANGED
@@ -39,27 +39,18 @@ if os.path.exists(GALLERY_FILE):
39
  else:
40
  raise FileNotFoundError("Gallery file missing!")
41
 
42
- # --- BUILD GROUND TRUTH LOOKUP (The Fix) ---
43
- # We map file_size -> (Species, ID) so we can identify images even if Gradio renames them.
44
  GT_LOOKUP = {}
45
  if os.path.exists(TEST_QUERIES_DIR):
46
- print("Building Ground Truth Lookup Table...")
47
  for f in os.listdir(TEST_QUERIES_DIR):
48
  if f.lower().endswith(('.jpg', '.png', '.jpeg')):
49
  full_path = os.path.join(TEST_QUERIES_DIR, f)
50
  try:
51
- # key = file size in bytes
52
  f_size = os.path.getsize(full_path)
53
-
54
- # Parse filename: "Species_ID_QUERY.jpg"
55
  parts = f.split("_")
56
  if len(parts) >= 2:
57
- species = parts[0]
58
- ind_id = parts[1]
59
- GT_LOOKUP[f_size] = (species, ind_id)
60
- except Exception as e:
61
- print(f"Skipping {f}: {e}")
62
- print(f"Indexed {len(GT_LOOKUP)} test images.")
63
 
64
  # Transform
65
  transform = transforms.Compose([
@@ -120,24 +111,18 @@ def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
120
  def predict(input_path):
121
  if input_path is None: return "Upload an image!", None
122
 
123
- # --- 0. IDENTIFY GROUND TRUTH (ROBUST) ---
124
- true_species = "Unknown"
125
- true_id = "Unknown"
126
-
127
- # Method A: Check File Size (Robust against renaming)
128
  try:
129
  input_size = os.path.getsize(input_path)
130
  if input_size in GT_LOOKUP:
131
  true_species, true_id = GT_LOOKUP[input_size]
132
  else:
133
- # Method B: Fallback to filename parsing
134
  filename = os.path.basename(input_path)
135
  if "_QUERY" in filename:
136
  parts = filename.split("_")
137
- true_species = parts[0]
138
- true_id = parts[1]
139
- except:
140
- pass
141
 
142
  # Load Image
143
  input_image = Image.open(input_path).convert("RGB")
@@ -147,18 +132,38 @@ def predict(input_path):
147
  with torch.no_grad():
148
  q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
149
  scores = torch.mm(q_emb, g_embeddings.t())
150
- top_scores, top_indices = torch.topk(scores, k=min(3, len(g_paths)))
 
 
 
151
 
152
- # --- 2. FINE SEARCH (LightGlue) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  feats_q = extractor.extract(load_image(input_path).to(DEVICE))
154
 
155
  log = "🔍 **Analysis Process:**\n"
156
  best_score = -1
157
- best_idx = -1
158
  best_matches_info = None
159
 
160
- for rank, idx in enumerate(top_indices[0]):
161
- idx = idx.item()
162
  path = g_paths[idx]
163
  label = g_labels[idx]
164
  species = g_species[idx]
@@ -173,29 +178,33 @@ def predict(input_path):
173
 
174
  geo_matches = len(matches["matches"])
175
 
176
- log += f"- Candidate {rank+1}: **{species} / {label}** | Keypoints: {geo_matches}\n"
 
 
 
 
 
177
 
 
178
  if geo_matches > best_score:
179
  best_score = geo_matches
180
- best_idx = idx
181
  best_matches_info = (path, feats_q['keypoints'], feats_c['keypoints'], matches['matches'])
182
 
183
  except Exception as e:
184
  log += f"- Error: {e}\n"
185
 
186
- # --- 3. FINAL DECISION ---
187
  CONFIDENCE_THRESHOLD = 15
 
188
 
189
- if best_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
190
- pred_species = g_species[best_idx]
191
- pred_id = g_labels[best_idx]
192
 
193
- # Check correctness
194
  is_correct = (pred_id == true_id)
195
-
196
- # Logic: If Truth is Unknown, we can't say it's Incorrect.
197
  if true_id == "Unknown":
198
- header = f"# ❓ MATCH FOUND (No Ground Truth)\n"
199
  elif is_correct:
200
  header = f"# ✅ CORRECT MATCH!\n"
201
  else:
@@ -203,7 +212,7 @@ def predict(input_path):
203
 
204
  header += f"**Ground Truth:** {true_species} / {true_id}\n"
205
  header += f"**Model Prediction:** {pred_species} / {pred_id}\n"
206
- header += f"*(Confidence: {best_score} geometric keypoints)*"
207
 
208
  winner_path, kpts0, kpts1, matches = best_matches_info
209
  viz_path = create_match_visualization(input_path, winner_path, kpts0, kpts1, matches)
@@ -211,12 +220,10 @@ def predict(input_path):
211
  else:
212
  header = "# ⚠️ UNKNOWN / NO MATCH\n"
213
  header += f"**Ground Truth:** {true_species} / {true_id}\n"
214
- header += f"**Model Prediction:** None (Best candidate only had {best_score} matches)\n"
215
-
216
  if true_id == "Unknown" or true_species != "Unknown":
217
- header += "\n*Model correctly rejected a non-matching image!*"
218
- viz_path = None
219
-
220
  return header + "\n\n" + log, viz_path
221
 
222
  # --- UI SETUP ---
@@ -232,7 +239,7 @@ iface = gr.Interface(
232
  gr.Image(label="Visualization")
233
  ],
234
  title="Wildlife Re-ID: Coarse-to-Fine Demo",
235
- description="Select a test image. The system will reveal the Ground Truth (hidden in filename) and compare it with the Model's Prediction.",
236
  examples=examples_list,
237
  cache_examples=False
238
  )
 
39
  else:
40
  raise FileNotFoundError("Gallery file missing!")
41
 
42
+ # --- BUILD GROUND TRUTH LOOKUP ---
 
43
  GT_LOOKUP = {}
44
  if os.path.exists(TEST_QUERIES_DIR):
 
45
  for f in os.listdir(TEST_QUERIES_DIR):
46
  if f.lower().endswith(('.jpg', '.png', '.jpeg')):
47
  full_path = os.path.join(TEST_QUERIES_DIR, f)
48
  try:
 
49
  f_size = os.path.getsize(full_path)
 
 
50
  parts = f.split("_")
51
  if len(parts) >= 2:
52
+ GT_LOOKUP[f_size] = (parts[0], parts[1])
53
+ except: pass
 
 
 
 
54
 
55
  # Transform
56
  transform = transforms.Compose([
 
111
  def predict(input_path):
112
  if input_path is None: return "Upload an image!", None
113
 
114
+ # --- 0. GROUND TRUTH ---
115
+ true_species, true_id = "Unknown", "Unknown"
 
 
 
116
  try:
117
  input_size = os.path.getsize(input_path)
118
  if input_size in GT_LOOKUP:
119
  true_species, true_id = GT_LOOKUP[input_size]
120
  else:
 
121
  filename = os.path.basename(input_path)
122
  if "_QUERY" in filename:
123
  parts = filename.split("_")
124
+ true_species, true_id = parts[0], parts[1]
125
+ except: pass
 
 
126
 
127
  # Load Image
128
  input_image = Image.open(input_path).convert("RGB")
 
132
  with torch.no_grad():
133
  q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
134
  scores = torch.mm(q_emb, g_embeddings.t())
135
+
136
+ # NEW: Fetch top 50 matches so we can filter duplicates
137
+ # We need enough candidates to find 3 unique individuals
138
+ top_scores, top_indices = torch.topk(scores, k=min(50, len(g_paths)))
139
 
140
+ # --- 2. UNIQUE CANDIDATE FILTERING ---
141
+ unique_candidates = []
142
+ seen_individuals = set()
143
+
144
+ # Loop through results until we have 3 unique people
145
+ for i in range(len(top_indices[0])):
146
+ if len(unique_candidates) >= 3: break
147
+
148
+ idx = top_indices[0][i].item()
149
+ score = top_scores[0][i].item() # This is the ArcFace Similarity!
150
+ label = g_labels[idx]
151
+
152
+ # If we haven't seen this individual yet, add them!
153
+ if label not in seen_individuals:
154
+ seen_individuals.add(label)
155
+ unique_candidates.append((idx, score))
156
+
157
+ # --- 3. FINE SEARCH (LightGlue) ---
158
  feats_q = extractor.extract(load_image(input_path).to(DEVICE))
159
 
160
  log = "🔍 **Analysis Process:**\n"
161
  best_score = -1
162
+ best_candidate_idx = -1
163
  best_matches_info = None
164
 
165
+ # Process only our 3 UNIQUE candidates
166
+ for rank, (idx, arcface_sim) in enumerate(unique_candidates):
167
  path = g_paths[idx]
168
  label = g_labels[idx]
169
  species = g_species[idx]
 
178
 
179
  geo_matches = len(matches["matches"])
180
 
181
+ # Format ArcFace score as percentage
182
+ sim_percent = arcface_sim * 100
183
+
184
+ log += f"- Candidate {rank+1}: **{species} / {label}**\n"
185
+ log += f" • 🧠 Coarse Confidence (ArcFace): **{sim_percent:.1f}%**\n"
186
+ log += f" • 📐 Geometric Matches (LightGlue): **{geo_matches}**\n\n"
187
 
188
+ # We still pick the winner based on LightGlue (Fine-grained)
189
  if geo_matches > best_score:
190
  best_score = geo_matches
191
+ best_candidate_idx = idx
192
  best_matches_info = (path, feats_q['keypoints'], feats_c['keypoints'], matches['matches'])
193
 
194
  except Exception as e:
195
  log += f"- Error: {e}\n"
196
 
197
+ # --- 4. FINAL DECISION ---
198
  CONFIDENCE_THRESHOLD = 15
199
+ viz_path = None
200
 
201
+ if best_candidate_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
202
+ pred_species = g_species[best_candidate_idx]
203
+ pred_id = g_labels[best_candidate_idx]
204
 
 
205
  is_correct = (pred_id == true_id)
 
 
206
  if true_id == "Unknown":
207
+ header = f"# ❓ MATCH FOUND\n"
208
  elif is_correct:
209
  header = f"# ✅ CORRECT MATCH!\n"
210
  else:
 
212
 
213
  header += f"**Ground Truth:** {true_species} / {true_id}\n"
214
  header += f"**Model Prediction:** {pred_species} / {pred_id}\n"
215
+ header += f"*(Confirmed with {best_score} geometric keypoints)*"
216
 
217
  winner_path, kpts0, kpts1, matches = best_matches_info
218
  viz_path = create_match_visualization(input_path, winner_path, kpts0, kpts1, matches)
 
220
  else:
221
  header = "# ⚠️ UNKNOWN / NO MATCH\n"
222
  header += f"**Ground Truth:** {true_species} / {true_id}\n"
223
+ header += f"**Model Prediction:** None\n"
 
224
  if true_id == "Unknown" or true_species != "Unknown":
225
+ header += "\n*Model correctly rejected non-matching candidates.*"
226
+
 
227
  return header + "\n\n" + log, viz_path
228
 
229
  # --- UI SETUP ---
 
239
  gr.Image(label="Visualization")
240
  ],
241
  title="Wildlife Re-ID: Coarse-to-Fine Demo",
242
+ description="Select a test image. The system finds the Top 3 UNIQUE individuals using embeddings, then verifies the best match using geometry.",
243
  examples=examples_list,
244
  cache_examples=False
245
  )