SalHargis committed on
Commit
d5e7e28
·
verified ·
1 Parent(s): 8125f92

Update to app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -1
app.py CHANGED
@@ -14,7 +14,7 @@ from lightglue.utils import load_image, rbd
14
  # --- CONFIG ---
15
  WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"
16
  GALLERY_FILE = "mini_gallery.pt"
17
- TEST_QUERIES_DIR = "test_queries" # Folder containing your example images
18
  IMG_SIZE = 384
19
  EMBEDDING_DIM = 512
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -48,6 +48,173 @@ transform = transforms.Compose([
48
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
49
  ])
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
52
  """
53
  Creates a side-by-side visualization with lines connecting matched keypoints.
 
14
  # --- CONFIG ---
15
  WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"
16
  GALLERY_FILE = "mini_gallery.pt"
17
+ TEST_QUERIES_DIR = "test_queries"
18
  IMG_SIZE = 384
19
  EMBEDDING_DIM = 512
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
48
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
49
  ])
50
 
51
def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches, target_h=400):
    """
    Create a side-by-side visualization with lines connecting matched keypoints.

    Args:
        img_path1: Path to the first (query) image.
        img_path2: Path to the second (candidate/gallery) image.
        kpts0: Keypoint tensor for image 1, shape [N, 2] or batched [1, N, 2].
            Coordinates are assumed to be (x, y) in the original pixel space of
            img_path1 — TODO confirm against the extractor's output convention.
        kpts1: Keypoint tensor for image 2, same convention as kpts0.
        matches: Index-pair tensor of shape [M, 2]; column 0 indexes kpts0,
            column 1 indexes kpts1.
        target_h: Display height in pixels both images are resized to
            (default 400, matching the previously hard-coded value).

    Returns:
        Path of the saved JPEG visualization ("match_viz_result.jpg").
    """
    # Remove batch dimension if present (shape [1, N, 2] -> [N, 2]) so the
    # fancy indexing with `matches` below works.
    if kpts0.dim() == 3:
        kpts0 = kpts0[0]
    if kpts1.dim() == 3:
        kpts1 = kpts1[0]

    # Load both images; force RGB so the uint8 canvas below is always 3-channel.
    img1 = Image.open(img_path1).convert("RGB")
    img2 = Image.open(img_path2).convert("RGB")

    # Resize both images to a common height so they sit neatly side by side.
    w1, h1 = img1.size
    w2, h2 = img2.size
    scale1 = target_h / h1
    scale2 = target_h / h2
    img1 = img1.resize((int(w1 * scale1), target_h))
    img2 = img2.resize((int(w2 * scale2), target_h))

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.axis('off')

    # Concatenate the two images horizontally onto a single canvas.
    img1_np = np.array(img1)
    img2_np = np.array(img2)
    w1_new = img1_np.shape[1]
    w2_new = img2_np.shape[1]
    canvas = np.zeros((target_h, w1_new + w2_new, 3), dtype=np.uint8)
    canvas[:, :w1_new, :] = img1_np
    canvas[:, w1_new:, :] = img2_np
    ax.imshow(canvas)

    # Gather matched keypoint pairs (advanced indexing copies, so scaling
    # below cannot mutate the caller's tensors).
    m_kpts0 = kpts0[matches[..., 0]].cpu().numpy()
    m_kpts1 = kpts1[matches[..., 1]].cpu().numpy()

    # Scale points into the resized display coordinates (same scale applies
    # to both x and y of each image).
    m_kpts0 *= scale1
    m_kpts1 *= scale2

    # One line per match; the second image's x-coordinates are shifted right
    # by the first image's display width.
    for (x0, y0), (x1, y1) in zip(m_kpts0, m_kpts1):
        ax.plot([x0, x1 + w1_new], [y0, y1], color="lime", linewidth=0.5, alpha=0.7)
        ax.scatter([x0, x1 + w1_new], [y0, y1], color="lime", s=2)

    fig.tight_layout()

    # Save/close via the figure handle rather than pyplot global state —
    # safer if other figures exist (e.g. concurrent requests in a web app).
    output_path = "match_viz_result.jpg"
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    return output_path
117
+
118
def predict(input_image):
    """
    Two-stage animal re-identification for the Gradio app.

    Stage 1 (coarse): embed the query with the ArcFace model and rank the
    gallery by cosine similarity (dot product of L2-normalized embeddings).
    Stage 2 (fine): geometrically verify the top candidates with keypoint
    matching (LightGlue) and keep the candidate with the most matches.

    Args:
        input_image: PIL image from the Gradio input, or None.

    Returns:
        Tuple of (markdown result text, visualization image path or None),
        matching the two Gradio outputs.
    """
    if input_image is None:
        return "Upload an image!", None

    # Save the input temporarily so LightGlue's load_image can read it.
    # FIX: force RGB first — PNG uploads can be RGBA, which PIL refuses to
    # save as JPEG ("cannot write mode RGBA as JPEG").
    input_image = input_image.convert("RGB")
    input_path = "temp_query.jpg"
    input_image.save(input_path)

    # 1. COARSE SEARCH (ArcFace)
    img_t = transform(input_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
        scores = torch.mm(q_emb, g_embeddings.t())
        # Top 3 candidates (fewer if the gallery is smaller).
        top_scores, top_indices = torch.topk(scores, k=min(3, len(g_paths)))

    # 2. FINE SEARCH (LightGlue) — extract query features once, reuse per candidate.
    feats_q = extractor.extract(load_image(input_path).to(DEVICE))

    log = "🔍 **Analysis Process:**\n"

    best_score = -1
    best_idx = -1
    best_matches_info = None

    for rank, idx in enumerate(top_indices[0]):
        idx = idx.item()
        path = g_paths[idx]
        label = g_labels[idx]
        species = g_species[idx]

        try:
            if not os.path.exists(path):
                log += f"- ⚠️ File missing: {path}\n"
                continue

            feats_c = extractor.extract(load_image(path).to(DEVICE))
            with torch.no_grad():
                matches = matcher({"image0": feats_q, "image1": feats_c})
            matches = rbd(matches)  # remove batch dim

            geo_matches = len(matches["matches"])

            log += f"- Candidate {rank+1}: **{species} / {label}** | Keypoints: {geo_matches}\n"

            if geo_matches > best_score:
                best_score = geo_matches
                best_idx = idx
                # Keep everything needed to draw the winner after the loop.
                best_matches_info = (path, feats_q['keypoints'], feats_c['keypoints'], matches['matches'])

        except Exception as e:
            # Best-effort: one bad candidate must not abort the whole query;
            # surface the error in the log instead.
            log += f"- Error processing {path}: {e}\n"

    # 3. DECISION & VISUALIZATION
    if best_idx != -1:
        winner_path, kpts0, kpts1, matches = best_matches_info
        winner_label = g_labels[best_idx]
        winner_species = g_species[best_idx]

        # GENERATE VISUALIZATION
        viz_image_path = create_match_visualization(input_path, winner_path, kpts0, kpts1, matches)

        result = f"✅ **MATCH FOUND**\nSpecies: {winner_species}\nIndividual: {winner_label}\n(Verified with {best_score} geometric matches)"
        return result + "\n\n" + log, viz_image_path
    else:
        return "⚠️ No confident match found.\n\n" + log, None
185
+
186
# --- UI SETUP ---
# Collect example images from the test-queries folder (if present) so the
# interface shows clickable examples.
# FIX: also accept ".jpeg" (previously only .jpg/.png matched), and sort the
# listing so the example order is deterministic across filesystems.
examples_list = []
if os.path.exists(TEST_QUERIES_DIR):
    examples_list = sorted(
        os.path.join(TEST_QUERIES_DIR, f)
        for f in os.listdir(TEST_QUERIES_DIR)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    )

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Test Image"),
    outputs=[
        gr.Markdown(label="Result"),
        gr.Image(label="Feature Matching Visualization")
    ],
    title="Wildlife Re-ID: Coarse-to-Fine Demo",
    description="Click a test image below to run the identification.",
    examples=examples_list,
    cache_examples=False
)
203
+
204
# --- GALLERY LOADING ---
# NOTE(review): in the committed diff this section was fused onto the
# iface.launch() line ("iface.launch() g_embeddings = ...") — a syntax error —
# and the opening "if" of the conditional was lost. Reconstructed here with a
# guard, and moved BEFORE launch() so the gallery globals exist before the app
# serves its first request. It appears to duplicate setup code near the top of
# the file — confirm against the full file and keep only one copy.
if os.path.exists(GALLERY_FILE):
    # TODO(review): confirm torch.load arguments against the original setup
    # section of the file (e.g. map_location for CPU-only hosts).
    data = torch.load(GALLERY_FILE, map_location=DEVICE)
    g_embeddings = data["embeddings"].to(DEVICE)
    g_paths = data["paths"]
    g_labels = data["labels"]
    g_species = data["species"]
else:
    raise FileNotFoundError("Gallery file missing!")

# Preprocessing transform for the ArcFace embedding model (ImageNet
# normalization constants).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

iface.launch()
217
+
218
  def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches):
219
  """
220
  Creates a side-by-side visualization with lines connecting matched keypoints.