Update app.py
Browse files
app.py
CHANGED
|
@@ -39,27 +39,18 @@ if os.path.exists(GALLERY_FILE):
|
|
| 39 |
else:
|
| 40 |
raise FileNotFoundError("Gallery file missing!")
|
| 41 |
|
| 42 |
-
# --- BUILD GROUND TRUTH LOOKUP
|
| 43 |
-
# We map file_size -> (Species, ID) so we can identify images even if Gradio renames them.
|
| 44 |
GT_LOOKUP = {}
|
| 45 |
if os.path.exists(TEST_QUERIES_DIR):
|
| 46 |
-
print("Building Ground Truth Lookup Table...")
|
| 47 |
for f in os.listdir(TEST_QUERIES_DIR):
|
| 48 |
if f.lower().endswith(('.jpg', '.png', '.jpeg')):
|
| 49 |
full_path = os.path.join(TEST_QUERIES_DIR, f)
|
| 50 |
try:
|
| 51 |
-
# key = file size in bytes
|
| 52 |
f_size = os.path.getsize(full_path)
|
| 53 |
-
|
| 54 |
-
# Parse filename: "Species_ID_QUERY.jpg"
|
| 55 |
parts = f.split("_")
|
| 56 |
if len(parts) >= 2:
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
GT_LOOKUP[f_size] = (species, ind_id)
|
| 60 |
-
except Exception as e:
|
| 61 |
-
print(f"Skipping {f}: {e}")
|
| 62 |
-
print(f"Indexed {len(GT_LOOKUP)} test images.")
|
| 63 |
|
| 64 |
# Transform
|
| 65 |
transform = transforms.Compose([
|
|
@@ -120,24 +111,18 @@ def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
|
|
| 120 |
def predict(input_path):
|
| 121 |
if input_path is None: return "Upload an image!", None
|
| 122 |
|
| 123 |
-
# --- 0.
|
| 124 |
-
true_species = "Unknown"
|
| 125 |
-
true_id = "Unknown"
|
| 126 |
-
|
| 127 |
-
# Method A: Check File Size (Robust against renaming)
|
| 128 |
try:
|
| 129 |
input_size = os.path.getsize(input_path)
|
| 130 |
if input_size in GT_LOOKUP:
|
| 131 |
true_species, true_id = GT_LOOKUP[input_size]
|
| 132 |
else:
|
| 133 |
-
# Method B: Fallback to filename parsing
|
| 134 |
filename = os.path.basename(input_path)
|
| 135 |
if "_QUERY" in filename:
|
| 136 |
parts = filename.split("_")
|
| 137 |
-
true_species = parts[0]
|
| 138 |
-
|
| 139 |
-
except:
|
| 140 |
-
pass
|
| 141 |
|
| 142 |
# Load Image
|
| 143 |
input_image = Image.open(input_path).convert("RGB")
|
|
@@ -147,18 +132,38 @@ def predict(input_path):
|
|
| 147 |
with torch.no_grad():
|
| 148 |
q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
|
| 149 |
scores = torch.mm(q_emb, g_embeddings.t())
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
# --- 2.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
feats_q = extractor.extract(load_image(input_path).to(DEVICE))
|
| 154 |
|
| 155 |
log = "🔍 **Analysis Process:**\n"
|
| 156 |
best_score = -1
|
| 157 |
-
|
| 158 |
best_matches_info = None
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
path = g_paths[idx]
|
| 163 |
label = g_labels[idx]
|
| 164 |
species = g_species[idx]
|
|
@@ -173,29 +178,33 @@ def predict(input_path):
|
|
| 173 |
|
| 174 |
geo_matches = len(matches["matches"])
|
| 175 |
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
|
|
|
| 178 |
if geo_matches > best_score:
|
| 179 |
best_score = geo_matches
|
| 180 |
-
|
| 181 |
best_matches_info = (path, feats_q['keypoints'], feats_c['keypoints'], matches['matches'])
|
| 182 |
|
| 183 |
except Exception as e:
|
| 184 |
log += f"- Error: {e}\n"
|
| 185 |
|
| 186 |
-
# ---
|
| 187 |
CONFIDENCE_THRESHOLD = 15
|
|
|
|
| 188 |
|
| 189 |
-
if
|
| 190 |
-
pred_species = g_species[
|
| 191 |
-
pred_id = g_labels[
|
| 192 |
|
| 193 |
-
# Check correctness
|
| 194 |
is_correct = (pred_id == true_id)
|
| 195 |
-
|
| 196 |
-
# Logic: If Truth is Unknown, we can't say it's Incorrect.
|
| 197 |
if true_id == "Unknown":
|
| 198 |
-
header = f"# ❓ MATCH FOUND
|
| 199 |
elif is_correct:
|
| 200 |
header = f"# ✅ CORRECT MATCH!\n"
|
| 201 |
else:
|
|
@@ -203,7 +212,7 @@ def predict(input_path):
|
|
| 203 |
|
| 204 |
header += f"**Ground Truth:** {true_species} / {true_id}\n"
|
| 205 |
header += f"**Model Prediction:** {pred_species} / {pred_id}\n"
|
| 206 |
-
header += f"*(
|
| 207 |
|
| 208 |
winner_path, kpts0, kpts1, matches = best_matches_info
|
| 209 |
viz_path = create_match_visualization(input_path, winner_path, kpts0, kpts1, matches)
|
|
@@ -211,12 +220,10 @@ def predict(input_path):
|
|
| 211 |
else:
|
| 212 |
header = "# ⚠️ UNKNOWN / NO MATCH\n"
|
| 213 |
header += f"**Ground Truth:** {true_species} / {true_id}\n"
|
| 214 |
-
header += f"**Model Prediction:** None
|
| 215 |
-
|
| 216 |
if true_id == "Unknown" or true_species != "Unknown":
|
| 217 |
-
header += "\n*Model correctly rejected
|
| 218 |
-
|
| 219 |
-
|
| 220 |
return header + "\n\n" + log, viz_path
|
| 221 |
|
| 222 |
# --- UI SETUP ---
|
|
@@ -232,7 +239,7 @@ iface = gr.Interface(
|
|
| 232 |
gr.Image(label="Visualization")
|
| 233 |
],
|
| 234 |
title="Wildlife Re-ID: Coarse-to-Fine Demo",
|
| 235 |
-
description="Select a test image. The system
|
| 236 |
examples=examples_list,
|
| 237 |
cache_examples=False
|
| 238 |
)
|
|
|
|
| 39 |
else:
|
| 40 |
raise FileNotFoundError("Gallery file missing!")
|
| 41 |
|
| 42 |
+
# --- BUILD GROUND TRUTH LOOKUP ---
|
|
|
|
| 43 |
GT_LOOKUP = {}
|
| 44 |
if os.path.exists(TEST_QUERIES_DIR):
|
|
|
|
| 45 |
for f in os.listdir(TEST_QUERIES_DIR):
|
| 46 |
if f.lower().endswith(('.jpg', '.png', '.jpeg')):
|
| 47 |
full_path = os.path.join(TEST_QUERIES_DIR, f)
|
| 48 |
try:
|
|
|
|
| 49 |
f_size = os.path.getsize(full_path)
|
|
|
|
|
|
|
| 50 |
parts = f.split("_")
|
| 51 |
if len(parts) >= 2:
|
| 52 |
+
GT_LOOKUP[f_size] = (parts[0], parts[1])
|
| 53 |
+
except: pass
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
# Transform
|
| 56 |
transform = transforms.Compose([
|
|
|
|
| 111 |
def predict(input_path):
|
| 112 |
if input_path is None: return "Upload an image!", None
|
| 113 |
|
| 114 |
+
# --- 0. GROUND TRUTH ---
|
| 115 |
+
true_species, true_id = "Unknown", "Unknown"
|
|
|
|
|
|
|
|
|
|
| 116 |
try:
|
| 117 |
input_size = os.path.getsize(input_path)
|
| 118 |
if input_size in GT_LOOKUP:
|
| 119 |
true_species, true_id = GT_LOOKUP[input_size]
|
| 120 |
else:
|
|
|
|
| 121 |
filename = os.path.basename(input_path)
|
| 122 |
if "_QUERY" in filename:
|
| 123 |
parts = filename.split("_")
|
| 124 |
+
true_species, true_id = parts[0], parts[1]
|
| 125 |
+
except: pass
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# Load Image
|
| 128 |
input_image = Image.open(input_path).convert("RGB")
|
|
|
|
| 132 |
with torch.no_grad():
|
| 133 |
q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
|
| 134 |
scores = torch.mm(q_emb, g_embeddings.t())
|
| 135 |
+
|
| 136 |
+
# NEW: Fetch top 50 matches so we can filter duplicates
|
| 137 |
+
# We need enough candidates to find 3 unique individuals
|
| 138 |
+
top_scores, top_indices = torch.topk(scores, k=min(50, len(g_paths)))
|
| 139 |
|
| 140 |
+
# --- 2. UNIQUE CANDIDATE FILTERING ---
|
| 141 |
+
unique_candidates = []
|
| 142 |
+
seen_individuals = set()
|
| 143 |
+
|
| 144 |
+
# Loop through results until we have 3 unique people
|
| 145 |
+
for i in range(len(top_indices[0])):
|
| 146 |
+
if len(unique_candidates) >= 3: break
|
| 147 |
+
|
| 148 |
+
idx = top_indices[0][i].item()
|
| 149 |
+
score = top_scores[0][i].item() # This is the ArcFace Similarity!
|
| 150 |
+
label = g_labels[idx]
|
| 151 |
+
|
| 152 |
+
# If we haven't seen this individual yet, add them!
|
| 153 |
+
if label not in seen_individuals:
|
| 154 |
+
seen_individuals.add(label)
|
| 155 |
+
unique_candidates.append((idx, score))
|
| 156 |
+
|
| 157 |
+
# --- 3. FINE SEARCH (LightGlue) ---
|
| 158 |
feats_q = extractor.extract(load_image(input_path).to(DEVICE))
|
| 159 |
|
| 160 |
log = "🔍 **Analysis Process:**\n"
|
| 161 |
best_score = -1
|
| 162 |
+
best_candidate_idx = -1
|
| 163 |
best_matches_info = None
|
| 164 |
|
| 165 |
+
# Process only our 3 UNIQUE candidates
|
| 166 |
+
for rank, (idx, arcface_sim) in enumerate(unique_candidates):
|
| 167 |
path = g_paths[idx]
|
| 168 |
label = g_labels[idx]
|
| 169 |
species = g_species[idx]
|
|
|
|
| 178 |
|
| 179 |
geo_matches = len(matches["matches"])
|
| 180 |
|
| 181 |
+
# Format ArcFace score as percentage
|
| 182 |
+
sim_percent = arcface_sim * 100
|
| 183 |
+
|
| 184 |
+
log += f"- Candidate {rank+1}: **{species} / {label}**\n"
|
| 185 |
+
log += f" • 🧠 Coarse Confidence (ArcFace): **{sim_percent:.1f}%**\n"
|
| 186 |
+
log += f" • 📐 Geometric Matches (LightGlue): **{geo_matches}**\n\n"
|
| 187 |
|
| 188 |
+
# We still pick the winner based on LightGlue (Fine-grained)
|
| 189 |
if geo_matches > best_score:
|
| 190 |
best_score = geo_matches
|
| 191 |
+
best_candidate_idx = idx
|
| 192 |
best_matches_info = (path, feats_q['keypoints'], feats_c['keypoints'], matches['matches'])
|
| 193 |
|
| 194 |
except Exception as e:
|
| 195 |
log += f"- Error: {e}\n"
|
| 196 |
|
| 197 |
+
# --- 4. FINAL DECISION ---
|
| 198 |
CONFIDENCE_THRESHOLD = 15
|
| 199 |
+
viz_path = None
|
| 200 |
|
| 201 |
+
if best_candidate_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
|
| 202 |
+
pred_species = g_species[best_candidate_idx]
|
| 203 |
+
pred_id = g_labels[best_candidate_idx]
|
| 204 |
|
|
|
|
| 205 |
is_correct = (pred_id == true_id)
|
|
|
|
|
|
|
| 206 |
if true_id == "Unknown":
|
| 207 |
+
header = f"# ❓ MATCH FOUND\n"
|
| 208 |
elif is_correct:
|
| 209 |
header = f"# ✅ CORRECT MATCH!\n"
|
| 210 |
else:
|
|
|
|
| 212 |
|
| 213 |
header += f"**Ground Truth:** {true_species} / {true_id}\n"
|
| 214 |
header += f"**Model Prediction:** {pred_species} / {pred_id}\n"
|
| 215 |
+
header += f"*(Confirmed with {best_score} geometric keypoints)*"
|
| 216 |
|
| 217 |
winner_path, kpts0, kpts1, matches = best_matches_info
|
| 218 |
viz_path = create_match_visualization(input_path, winner_path, kpts0, kpts1, matches)
|
|
|
|
| 220 |
else:
|
| 221 |
header = "# ⚠️ UNKNOWN / NO MATCH\n"
|
| 222 |
header += f"**Ground Truth:** {true_species} / {true_id}\n"
|
| 223 |
+
header += f"**Model Prediction:** None\n"
|
|
|
|
| 224 |
if true_id == "Unknown" or true_species != "Unknown":
|
| 225 |
+
header += "\n*Model correctly rejected non-matching candidates.*"
|
| 226 |
+
|
|
|
|
| 227 |
return header + "\n\n" + log, viz_path
|
| 228 |
|
| 229 |
# --- UI SETUP ---
|
|
|
|
| 239 |
gr.Image(label="Visualization")
|
| 240 |
],
|
| 241 |
title="Wildlife Re-ID: Coarse-to-Fine Demo",
|
| 242 |
+
description="Select a test image. The system finds the Top 3 UNIQUE individuals using embeddings, then verifies the best match using geometry.",
|
| 243 |
examples=examples_list,
|
| 244 |
cache_examples=False
|
| 245 |
)
|