Update to app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,7 @@ from lightglue.utils import load_image, rbd
|
|
| 14 |
# --- CONFIG ---
|
| 15 |
WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"
|
| 16 |
GALLERY_FILE = "mini_gallery.pt"
|
| 17 |
-
TEST_QUERIES_DIR = "test_queries"
|
| 18 |
IMG_SIZE = 384
|
| 19 |
EMBEDDING_DIM = 512
|
| 20 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -48,6 +48,173 @@ transform = transforms.Compose([
|
|
| 48 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 49 |
])
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
|
| 52 |
"""
|
| 53 |
Creates a side-by-side visualization with lines connecting matched keypoints.
|
|
|
|
| 14 |
# --- CONFIG ---
|
| 15 |
WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"
|
| 16 |
GALLERY_FILE = "mini_gallery.pt"
|
| 17 |
+
TEST_QUERIES_DIR = "test_queries"
|
| 18 |
IMG_SIZE = 384
|
| 19 |
EMBEDDING_DIM = 512
|
| 20 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 48 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 49 |
])
|
| 50 |
|
| 51 |
+
def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
    """
    Create a side-by-side visualization with lines connecting matched keypoints.

    Args:
        img_path1: Path on disk of the query image.
        img_path2: Path on disk of the candidate (gallery) image.
        kpts0: Keypoints of image 1, tensor of shape [N, 2] or [1, N, 2].
        kpts1: Keypoints of image 2, tensor of shape [M, 2] or [1, M, 2].
        matches: Tensor of index pairs [K, 2]; column 0 indexes kpts0,
            column 1 indexes kpts1.

    Returns:
        Path of the saved visualization JPEG ("match_viz_result.jpg").
    """
    # Remove batch dimension if present (shape [1, N, 2] -> [N, 2]) so the
    # fancy indexing below works.
    if kpts0.dim() == 3:
        kpts0 = kpts0[0]
    if kpts1.dim() == 3:
        kpts1 = kpts1[0]

    # Load images
    img1 = Image.open(img_path1).convert("RGB")
    img2 = Image.open(img_path2).convert("RGB")

    # Resize both images to a common height so they sit neatly side by side.
    target_h = 400
    w1, h1 = img1.size
    w2, h2 = img2.size
    scale1 = target_h / h1
    scale2 = target_h / h2
    img1 = img1.resize((int(w1 * scale1), target_h))
    img2 = img2.resize((int(w2 * scale2), target_h))

    # FIX: work on an explicit Figure instead of pyplot's global "current
    # figure" (fragile when a web app serves requests), and guarantee the
    # figure is closed even if drawing/saving raises.
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    try:
        ax.axis('off')

        # Concatenate the two resized images onto one canvas.
        img1_np = np.array(img1)
        img2_np = np.array(img2)
        w1_new = img1_np.shape[1]
        w2_new = img2_np.shape[1]
        canvas = np.zeros((target_h, w1_new + w2_new, 3), dtype=np.uint8)
        canvas[:, :w1_new, :] = img1_np
        canvas[:, w1_new:, :] = img2_np
        ax.imshow(canvas)

        # Extract matched points. FIX: detach() first — keypoint tensors that
        # still track gradients cannot be converted with .numpy().
        m_kpts0 = kpts0[matches[..., 0]].detach().cpu().numpy()
        m_kpts1 = kpts1[matches[..., 1]].detach().cpu().numpy()

        # Scale points to match the resized images we are displaying
        # (uniform scale, so x and y use the same factor).
        m_kpts0 *= scale1
        m_kpts1 *= scale2

        # Draw one line per match; x-coordinates of the second image are
        # shifted right by the width of the first.
        for (x0, y0), (x1, y1) in zip(m_kpts0, m_kpts1):
            ax.plot([x0, x1 + w1_new], [y0, y1], color="lime", linewidth=0.5, alpha=0.7)
            ax.scatter([x0, x1 + w1_new], [y0, y1], color="lime", s=2)

        fig.tight_layout()

        # Save to disk and hand the path back to the caller.
        output_path = "match_viz_result.jpg"
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    return output_path
|
| 117 |
+
|
| 118 |
+
def predict(input_image):
    """
    Coarse-to-fine re-identification of the animal in `input_image`.

    1) Coarse: embed the query with the ArcFace model and rank the gallery
       by cosine similarity (both sides are L2-normalized).
    2) Fine: re-rank the top-3 candidates with LightGlue keypoint matching
       and keep the candidate with the most geometric matches.

    Args:
        input_image: PIL image from the Gradio input, or None.

    Returns:
        Tuple (markdown_text, image_path_or_None) for the two Gradio outputs.
    """
    if input_image is None: return "Upload an image!", None

    # Save input temporarily so LightGlue's load_image can read it from disk.
    # FIX: force RGB — saving an RGBA/palette PNG as JPEG raises OSError.
    input_path = "temp_query.jpg"
    input_image.convert("RGB").save(input_path)

    # 1. COARSE SEARCH (ArcFace)
    img_t = transform(input_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
        scores = torch.mm(q_emb, g_embeddings.t())
        # Get Top 3 (fewer if the gallery is smaller than 3).
        top_scores, top_indices = torch.topk(scores, k=min(3, len(g_paths)))

    # 2. FINE SEARCH (LightGlue)
    feats_q = extractor.extract(load_image(input_path).to(DEVICE))

    log = "🔍 **Analysis Process:**\n"

    best_score = -1
    best_idx = -1
    best_matches_info = None

    for rank, idx in enumerate(top_indices[0]):
        idx = idx.item()
        path = g_paths[idx]
        label = g_labels[idx]
        species = g_species[idx]

        try:
            if not os.path.exists(path):
                log += f"- ⚠️ File missing: {path}\n"
                continue

            feats_c = extractor.extract(load_image(path).to(DEVICE))
            with torch.no_grad():
                matches = matcher({"image0": feats_q, "image1": feats_c})
                matches = rbd(matches)  # remove batch dim

            geo_matches = len(matches["matches"])

            log += f"- Candidate {rank+1}: **{species} / {label}** | Keypoints: {geo_matches}\n"

            if geo_matches > best_score:
                best_score = geo_matches
                best_idx = idx
                # Save visualization data for the current winner.
                best_matches_info = (path, feats_q['keypoints'], feats_c['keypoints'], matches['matches'])

        except Exception as e:
            # Best-effort: one bad candidate must not kill the whole request.
            log += f"- Error processing {path}: {e}\n"

    # 3. DECISION & VISUALIZATION
    if best_idx != -1:
        winner_path, kpts0, kpts1, matches = best_matches_info
        winner_label = g_labels[best_idx]
        winner_species = g_species[best_idx]

        # GENERATE VISUALIZATION
        viz_image_path = create_match_visualization(input_path, winner_path, kpts0, kpts1, matches)

        result = f"✅ **MATCH FOUND**\nSpecies: {winner_species}\nIndividual: {winner_label}\n(Verified with {best_score} geometric matches)"
        return result + "\n\n" + log, viz_image_path
    else:
        return "⚠️ No confident match found.\n\n" + log, None
|
| 185 |
+
|
| 186 |
+
# --- UI SETUP ---
# Collect example images shipped alongside the app (if the folder exists).
# FIX: also accept .jpeg, and sort so the example order is deterministic
# (os.listdir order is filesystem-dependent).
examples_list = []
if os.path.exists(TEST_QUERIES_DIR):
    examples_list = sorted(
        os.path.join(TEST_QUERIES_DIR, f)
        for f in os.listdir(TEST_QUERIES_DIR)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Test Image"),
    outputs=[
        gr.Markdown(label="Result"),
        gr.Image(label="Feature Matching Visualization")
    ],
    title="Wildlife Re-ID: Coarse-to-Fine Demo",
    description="Click a test image below to run the identification.",
    examples=examples_list,
    cache_examples=False
)

iface.launch()
|
| 205 |
+
g_paths = data["paths"]
|
| 206 |
+
g_labels = data["labels"]
|
| 207 |
+
g_species = data["species"]
|
| 208 |
+
else:
|
| 209 |
+
raise FileNotFoundError("Gallery file missing!")
|
| 210 |
+
|
| 211 |
+
# Transform
|
| 212 |
+
transform = transforms.Compose([
|
| 213 |
+
transforms.Resize((IMG_SIZE, IMG_SIZE)),
|
| 214 |
+
transforms.ToTensor(),
|
| 215 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 216 |
+
])
|
| 217 |
+
|
| 218 |
def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
|
| 219 |
"""
|
| 220 |
Creates a side-by-side visualization with lines connecting matched keypoints.
|