File size: 10,027 Bytes
8125f92
 
 
 
 
 
 
 
127a9d6
8125f92
 
 
 
 
8c24824
8125f92
 
d5e7e28
8125f92
 
 
29db6b8
8125f92
8c24824
8125f92
 
 
 
 
 
 
 
 
8c24824
8125f92
 
 
 
 
 
d5e7e28
 
 
8c24824
75bb77a
 
 
 
 
 
 
 
 
18d9fa9
 
75bb77a
d5e7e28
 
 
 
 
 
127a9d6
d5e7e28
 
 
8c24824
127a9d6
 
 
 
 
 
 
d5e7e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127a9d6
 
d5e7e28
 
 
 
 
 
 
127a9d6
 
d5e7e28
 
127a9d6
 
d5e7e28
 
 
e050a1b
8c24824
a41b0ce
 
 
 
 
 
d5e7e28
8c24824
18d9fa9
75bb77a
 
 
 
 
 
 
 
18d9fa9
 
75bb77a
8c24824
e050a1b
d5e7e28
8c24824
d5e7e28
 
 
 
18d9fa9
d5e7e28
8c24824
18d9fa9
 
 
 
 
127a9d6
18d9fa9
 
 
 
 
8c24824
d5e7e28
 
 
18d9fa9
a41b0ce
 
 
127a9d6
d5e7e28
18d9fa9
d5e7e28
 
 
 
 
e050a1b
d5e7e28
 
 
 
e050a1b
d5e7e28
 
18d9fa9
 
8c24824
a41b0ce
8c24824
a41b0ce
d5e7e28
127a9d6
 
 
 
 
 
 
 
d5e7e28
 
18d9fa9
d5e7e28
 
a41b0ce
d5e7e28
8c24824
e050a1b
18d9fa9
 
 
e050a1b
a41b0ce
 
 
 
 
 
18d9fa9
d5e7e28
a41b0ce
e050a1b
a41b0ce
18d9fa9
a41b0ce
 
 
 
 
d5e7e28
8c24824
d5e7e28
 
 
 
0362400
29db6b8
a41b0ce
29db6b8
 
8c24824
29db6b8
 
a41b0ce
29db6b8
 
8c24824
29db6b8
a41b0ce
 
 
 
 
0611d3e
 
a41b0ce
 
 
 
0611d3e
a41b0ce
 
 
 
0611d3e
29db6b8
 
 
 
a41b0ce
29db6b8
d5e7e28
0362400
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import random

# LightGlue Imports
from lightglue import LightGlue, ALIKED
from lightglue.utils import load_image, rbd

# Configuration
WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"
GALLERY_FILE = "mini_gallery.pt"
TEST_QUERIES_DIR = "test_queries"
IMG_SIZE = 384
EMBEDDING_DIM = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_VIZ_LINES = 50  # cap on match lines drawn per visualization

# Coarse-search backbone: ResNet-50 with the final FC replaced by a
# 512-d embedding head, loaded with the fine-tuned ArcFace weights.
print("Loading Models...")
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, EMBEDDING_DIM)
# NOTE(review): strict=False silently skips missing/unexpected checkpoint
# keys — confirm the state dict really matches this architecture.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE), strict=False)
model.to(DEVICE).eval()

# Fine-verification stage: ALIKED keypoint extractor + LightGlue matcher.
extractor = ALIKED(max_num_keypoints=1024, detection_threshold=0.2).eval().to(DEVICE)
matcher = LightGlue(features='aliked').eval().to(DEVICE)

# Load the precomputed gallery (embeddings + metadata), or fail fast.
if os.path.exists(GALLERY_FILE):
    data = torch.load(GALLERY_FILE, map_location=DEVICE)
    g_embeddings = data["embeddings"].to(DEVICE)
    g_paths = data["paths"]
    g_labels = data["labels"]
    g_species = data["species"]
else:
    raise FileNotFoundError("Gallery file missing!")

# Ground-truth lookup keyed by file size. Uploads arrive under temporary
# names, so the original "<species>_<id>_..." filename is lost; the byte
# size serves as a (collision-prone) fingerprint to recover ground truth.
GT_LOOKUP = {}
if os.path.exists(TEST_QUERIES_DIR):
    for f in os.listdir(TEST_QUERIES_DIR):
        if f.lower().endswith(('.jpg', '.png', '.jpeg')):
            full_path = os.path.join(TEST_QUERIES_DIR, f)
            try:
                f_size = os.path.getsize(full_path)
                parts = f.split("_")
                if len(parts) >= 2:
                    GT_LOOKUP[f_size] = (parts[0], parts[1])
            except OSError:
                # File vanished or is unreadable — skip it, don't abort startup.
                pass

# Preprocessing for the embedding model (ImageNet mean/std normalization).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches, output_name, max_lines=MAX_VIZ_LINES):
    """Render the two images side by side with lines joining matched keypoints.

    Saves the figure as '<output_name>.jpg' and returns that path. At most
    `max_lines` randomly sampled matches are drawn to keep the overlay legible.
    """
    # Strip a possible batch dimension from the keypoint tensors.
    if kpts0.dim() == 3:
        kpts0 = kpts0[0]
    if kpts1.dim() == 3:
        kpts1 = kpts1[0]

    # Thin out dense match sets before drawing.
    total = len(matches)
    drawn = matches[random.sample(range(total), max_lines)] if total > max_lines else matches

    left = Image.open(img_path1).convert("RGB")
    right = Image.open(img_path2).convert("RGB")

    target_h = 400
    lw, lh = left.size
    rw, rh = right.size
    left_scale = target_h / lh
    right_scale = target_h / rh

    left = left.resize((int(lw * left_scale), target_h))
    right = right.resize((int(rw * right_scale), target_h))

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.axis('off')

    left_np = np.array(left)
    right_np = np.array(right)
    lw_new = left_np.shape[1]
    rw_new = right_np.shape[1]

    # Paste both resized images onto one canvas: query on the left,
    # candidate on the right.
    canvas = np.zeros((target_h, lw_new + rw_new, 3), dtype=np.uint8)
    canvas[:, :lw_new, :] = left_np
    canvas[:, lw_new:, :] = right_np
    ax.imshow(canvas)

    pts_left = kpts0[drawn[..., 0]].cpu().numpy()
    pts_right = kpts1[drawn[..., 1]].cpu().numpy()

    # Rescale keypoint coordinates to match the resized images.
    pts_left *= left_scale
    pts_right *= right_scale

    # Right-image x coordinates are offset by the left image's width.
    for (x0, y0), (x1, y1) in zip(pts_left, pts_right):
        ax.plot([x0, x1 + lw_new], [y0, y1], color="lime", linewidth=0.8, alpha=0.6)
        ax.scatter([x0, x1 + lw_new], [y0, y1], color="lime", s=3)

    plt.tight_layout()
    output_path = f"{output_name}.jpg"
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
    plt.close()
    return output_path

def predict(input_path):
    """Run the coarse-to-fine re-ID pipeline on one query image.

    Coarse stage: embed the query with the ArcFace model and rank gallery
    entries by cosine similarity, keeping the top 3 distinct individuals.
    Fine stage: verify each candidate with ALIKED keypoints + LightGlue
    geometric matching and render a match visualization.

    Returns a 7-tuple matching the Gradio output wiring:
    (header markdown, log1, viz1, log2, viz2, log3, viz3).
    """
    # Default return values are empty
    default_header = "Please upload an image."
    default_logs = ["", "", ""]
    default_imgs = [None, None, None]
    
    if input_path is None: 
        return default_header, default_logs[0], default_imgs[0], default_logs[1], default_imgs[1], default_logs[2], default_imgs[2]
    
    # Ground truth: uploads get renamed, so first try the file-size
    # fingerprint, then fall back to parsing the original filename.
    true_species, true_id = "Unknown", "Unknown"
    try:
        input_size = os.path.getsize(input_path)
        if input_size in GT_LOOKUP:
            true_species, true_id = GT_LOOKUP[input_size]
        else:
            filename = os.path.basename(input_path)
            if "_QUERY" in filename:
                # "_QUERY" in the name guarantees split("_") yields >= 2 parts.
                parts = filename.split("_")
                true_species, true_id = parts[0], parts[1]
    except OSError:
        # Path unreadable — proceed without ground truth.
        pass

    # load Image
    input_image = Image.open(input_path).convert("RGB")

    # Coarse Search (ArcFace): cosine similarity of L2-normalized embeddings
    # (gallery embeddings are assumed pre-normalized — TODO confirm).
    img_t = transform(input_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
        scores = torch.mm(q_emb, g_embeddings.t())
        top_scores, top_indices = torch.topk(scores, k=min(50, len(g_paths)))

    # Filter for unique individual candidates: the gallery may contain
    # several images of the same animal; keep the top 3 distinct labels.
    unique_candidates = []
    seen_individuals = set()
    for i in range(len(top_indices[0])):
        if len(unique_candidates) >= 3: break
        idx = top_indices[0][i].item()
        score = top_scores[0][i].item()
        label = g_labels[idx]
        if label not in seen_individuals:
            seen_individuals.add(label)
            unique_candidates.append((idx, score))

    # Fine search and visualization: extract query keypoints once, then
    # match against each candidate.
    feats_q = extractor.extract(load_image(input_path).to(DEVICE))
    
    best_score = -1
    best_candidate_idx = -1
    
    # Initialize output lists (size 3)
    cand_logs = ["Waiting for data...", "Waiting for data...", "Waiting for data..."]
    cand_viz_paths = [None, None, None]

    for rank, (idx, arcface_sim) in enumerate(unique_candidates):
        path = g_paths[idx]
        label = g_labels[idx]
        species = g_species[idx]
        
        try:
            if not os.path.exists(path): continue

            feats_c = extractor.extract(load_image(path).to(DEVICE))
            with torch.no_grad():
                matches = matcher({"image0": feats_q, "image1": feats_c})
                matches = rbd(matches)
            
            geo_matches = len(matches["matches"])
            sim_percent = arcface_sim * 100
            
            # Create individual log string
            log_str = f"### Candidate {rank+1}: {species} / {label}\n"
            log_str += f"**Coarse-Search Confidence:** {sim_percent:.1f}%   |   **📐 Geometric Matches:** {geo_matches}"
            cand_logs[rank] = log_str
            
            viz_name = f"viz_rank_{rank}"
            viz_path = create_match_visualization(
                input_path, path, 
                feats_q['keypoints'], feats_c['keypoints'], 
                matches['matches'], viz_name
            )
            cand_viz_paths[rank] = viz_path

            # Winner is the candidate with the most geometric matches.
            if geo_matches > best_score:
                best_score = geo_matches
                best_candidate_idx = idx
                
        except Exception as e:
            # Per-candidate boundary: surface the error in this candidate's
            # panel instead of aborting the whole prediction.
            cand_logs[rank] = f"Error processing candidate: {e}"

    # Final decision: accept the best candidate only if it cleared the
    # minimum number of geometric keypoint matches.
    CONFIDENCE_THRESHOLD = 15
    if best_candidate_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
        pred_species = g_species[best_candidate_idx]
        pred_id = g_labels[best_candidate_idx]
        is_correct = (pred_id == true_id)
        
        if true_id == "Unknown": header = f"### ❓ MATCH FOUND (No Ground Truth)\n"
        elif is_correct: header = f"### ✅ CORRECT MATCH!\n"
        else: header = f"### ❌ INCORRECT MATCH\n"
        
        header += f"**Ground Truth:** {true_species} / {true_id}   ➡️   **Prediction:** {pred_species} / {pred_id}\n"
        header += f"*(Confirmed with {best_score} geometric keypoints)*"
    else:
        header = "### ⚠️ UNKNOWN / NO MATCH\n"
        header += f"**Ground Truth:** {true_species} / {true_id}\n"
        header += f"**Prediction:** None (Best match only had {best_score} keypoints)\n"

    # Return: Header, then (Log, Img) for Cand 1, then (Log, Img) for Cand 2, etc.
    return (header, 
            cand_logs[0], cand_viz_paths[0],
            cand_logs[1], cand_viz_paths[1],
            cand_logs[2], cand_viz_paths[2])

# Gather example images for the Gradio examples widget (directory order
# is preserved as returned by os.listdir).
examples_list = []
if os.path.exists(TEST_QUERIES_DIR):
    for fname in os.listdir(TEST_QUERIES_DIR):
        if fname.lower().endswith(('.jpg', '.png')):
            examples_list.append(os.path.join(TEST_QUERIES_DIR, fname))

# Build the Gradio UI: input column on the left, final decision plus three
# candidate panels stacked vertically on the right.
with gr.Blocks(title="Wildlife Re-ID: Coarse-to-Fine Demo") as demo:
    gr.Markdown("# Wildlife Re-ID: Coarse-to-Fine Demo")
    gr.Markdown("Select a test image. The system finds the Top 3 UNIQUE individuals using embeddings, then verifies them using geometry.")

    with gr.Row():
        # Input side: image picker, clickable examples, run button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="filepath", label="Test Image", height=300)
            gr.Examples(examples=examples_list, inputs=input_img, label="Test Examples", examples_per_page=4)
            submit_btn = gr.Button("Run Identification", variant="primary", size="lg")

        # Results side: header, then one (log, image) group per candidate.
        with gr.Column(scale=2):
            header_md = gr.Markdown(label="Final Decision")

            candidate_outputs = []
            for _ in range(3):
                with gr.Group():
                    cand_log = gr.Markdown()
                    # No fixed height / download button on the viz images.
                    cand_img = gr.Image(label="Visualization", show_label=False)
                candidate_outputs.extend([cand_log, cand_img])

    # predict returns (header, log1, img1, log2, img2, log3, img3),
    # matching [header_md] + the interleaved candidate components.
    submit_btn.click(
        fn=predict,
        inputs=input_img,
        outputs=[header_md] + candidate_outputs
    )

demo.launch(allowed_paths=[TEST_QUERIES_DIR])