| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| from PIL import Image |
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import random |
|
|
| |
| from lightglue import LightGlue, ALIKED |
| from lightglue.utils import load_image, rbd |
|
|
| |
# --- Configuration ---
WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"  # fine-tuned ResNet-50 embedding checkpoint
GALLERY_FILE = "mini_gallery.pt"  # precomputed gallery: embeddings + paths/labels/species
TEST_QUERIES_DIR = "test_queries"  # directory of example query images (also ground-truth source)
IMG_SIZE = 384  # square input resolution for the embedding model
EMBEDDING_DIM = 512  # output dimension of the replacement fc head
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_VIZ_LINES = 50  # cap on match lines drawn per visualization
|
|
| |
# --- Model loading (runs once at startup) ---
print("Loading Models...")
# Coarse-search backbone: ResNet-50 with its classifier replaced by a
# 512-D embedding head; weights come from the fine-tuned checkpoint.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, EMBEDDING_DIM)
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# checkpoint -- confirm the checkpoint really matches this architecture.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE), strict=False)
model.to(DEVICE).eval()


# Fine-verification pair: ALIKED keypoint extractor + LightGlue matcher.
extractor = ALIKED(max_num_keypoints=1024, detection_threshold=0.2).eval().to(DEVICE)
matcher = LightGlue(features='aliked').eval().to(DEVICE)
|
|
| |
# --- Gallery data (required; the app cannot run without it) ---
# The gallery file bundles precomputed embeddings with parallel lists of
# image paths, individual labels, and species names (one entry per image).
if os.path.exists(GALLERY_FILE):
    data = torch.load(GALLERY_FILE, map_location=DEVICE)
    g_embeddings = data["embeddings"].to(DEVICE)  # presumably (N, EMBEDDING_DIM), L2-normalized -- TODO confirm
    g_paths = data["paths"]      # gallery image file paths
    g_labels = data["labels"]    # individual-animal IDs
    g_species = data["species"]  # species name per gallery entry
else:
    raise FileNotFoundError("Gallery file missing!")
|
|
| |
# --- Ground-truth lookup, keyed by file size ---
# Gradio copies uploads to temp paths (losing the original filename), so the
# query file's byte size is used as a fingerprint back to its source file.
# NOTE(review): two distinct queries with identical sizes would collide and
# return the wrong ground truth -- confirm this is acceptable for the demo.
GT_LOOKUP = {}
if os.path.exists(TEST_QUERIES_DIR):
    for f in os.listdir(TEST_QUERIES_DIR):
        if f.lower().endswith(('.jpg', '.png', '.jpeg')):
            full_path = os.path.join(TEST_QUERIES_DIR, f)
            try:
                f_size = os.path.getsize(full_path)
            except OSError:
                # File vanished or is unreadable: skip it rather than crash
                # (previously a bare `except` that hid all errors).
                continue
            parts = f.split("_")
            if len(parts) >= 2:
                # Filename convention: "<species>_<individual-id>_...".
                GT_LOOKUP[f_size] = (parts[0], parts[1])
|
|
# Preprocessing for the embedding model: fixed-size resize plus the standard
# ImageNet mean/std normalization used by torchvision ResNet weights.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
|
|
def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches, output_name, max_lines=MAX_VIZ_LINES):
    """Render a side-by-side keypoint-match image and save it as a JPEG.

    Args:
        img_path1: path of the query image (drawn on the left).
        img_path2: path of the candidate image (drawn on the right).
        kpts0, kpts1: keypoint tensors, (N, 2) or batched (1, N, 2), in
            original-image pixel coordinates.
        matches: (M, 2) tensor of index pairs into kpts0 / kpts1.
        output_name: basename (no extension) for the output file.
        max_lines: cap on match lines drawn; a random subset is sampled
            when more matches exist.

    Returns:
        Path of the saved "<output_name>.jpg" visualization.
    """
    # Drop a leading batch dimension if present.
    if kpts0.dim() == 3: kpts0 = kpts0[0]
    if kpts1.dim() == 3: kpts1 = kpts1[0]

    # Subsample matches so the plot stays readable.
    num_matches = len(matches)
    if num_matches > max_lines:
        indices = random.sample(range(num_matches), max_lines)
        matches_to_draw = matches[indices]
    else:
        matches_to_draw = matches

    img1 = Image.open(img_path1).convert("RGB")
    img2 = Image.open(img_path2).convert("RGB")

    # Scale both images to a common height so they can sit side by side.
    target_h = 400
    w1, h1 = img1.size
    w2, h2 = img2.size
    scale1 = target_h / h1
    scale2 = target_h / h2

    img1 = img1.resize((int(w1 * scale1), target_h))
    img2 = img2.resize((int(w2 * scale2), target_h))

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.axis('off')

    img1_np = np.array(img1)
    img2_np = np.array(img2)
    h1_new, w1_new, _ = img1_np.shape
    h2_new, w2_new, _ = img2_np.shape

    # Paste the two images onto one canvas; image 2 starts at x = w1_new.
    width = w1_new + w2_new
    canvas = np.zeros((target_h, width, 3), dtype=np.uint8)
    canvas[:, :w1_new, :] = img1_np
    canvas[:, w1_new:, :] = img2_np

    ax.imshow(canvas)

    # Gather matched keypoints. Advanced indexing copies, so the in-place
    # scaling below cannot mutate the caller's tensors.
    m_kpts0 = kpts0[matches_to_draw[..., 0]].cpu().numpy()
    m_kpts1 = kpts1[matches_to_draw[..., 1]].cpu().numpy()

    # Uniform resize: x and y shrink by the same per-image factor, so the
    # whole (M, 2) array can be scaled in one operation.
    m_kpts0 *= scale1
    m_kpts1 *= scale2

    for (x0, y0), (x1, y1) in zip(m_kpts0, m_kpts1):
        ax.plot([x0, x1 + w1_new], [y0, y1], color="lime", linewidth=0.8, alpha=0.6)
        ax.scatter([x0, x1 + w1_new], [y0, y1], color="lime", s=3)

    plt.tight_layout()
    output_path = f"{output_name}.jpg"
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
    # Close THIS figure explicitly: bare plt.close() closes whichever figure
    # is "current", which may be another request's figure under concurrent
    # Gradio calls, leaking this one.
    plt.close(fig)
    return output_path
|
|
def predict(input_path):
    """Run coarse-to-fine re-identification on one query image.

    Pipeline: (1) embed the query with the fine-tuned ResNet-50 and rank the
    gallery by cosine similarity, (2) keep the top-3 UNIQUE individuals,
    (3) verify each with ALIKED + LightGlue geometric matching, and
    (4) accept the best candidate only if it clears the keypoint threshold.

    Args:
        input_path: filesystem path of the query image (Gradio "filepath"
            component), or None when nothing was uploaded.

    Returns:
        A 7-tuple matching the Gradio outputs:
        (header_md, log1, viz1, log2, viz2, log3, viz3).
    """
    default_header = "Please upload an image."
    default_logs = ["", "", ""]
    default_imgs = [None, None, None]

    if input_path is None:
        return default_header, default_logs[0], default_imgs[0], default_logs[1], default_imgs[1], default_logs[2], default_imgs[2]

    # --- Ground-truth recovery ---
    # Gradio copies uploads to temp files, so match by file size first, then
    # fall back to parsing a "<species>_<id>_QUERY..." style filename.
    true_species, true_id = "Unknown", "Unknown"
    try:
        input_size = os.path.getsize(input_path)
        if input_size in GT_LOOKUP:
            true_species, true_id = GT_LOOKUP[input_size]
        else:
            filename = os.path.basename(input_path)
            if "_QUERY" in filename:
                parts = filename.split("_")
                if len(parts) >= 2:  # guard against malformed names (was unguarded)
                    true_species, true_id = parts[0], parts[1]
    except OSError:
        # File unreadable/missing: proceed with "Unknown" ground truth
        # (previously a bare `except` that hid all errors).
        pass

    input_image = Image.open(input_path).convert("RGB")

    # --- Stage 1: coarse search via L2-normalized embeddings ---
    img_t = transform(input_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
        scores = torch.mm(q_emb, g_embeddings.t())  # cosine similarity row
        top_scores, top_indices = torch.topk(scores, k=min(50, len(g_paths)))

    # Keep the 3 best-scoring DISTINCT individuals (gallery may hold several
    # images per animal).
    unique_candidates = []
    seen_individuals = set()
    for i in range(len(top_indices[0])):
        if len(unique_candidates) >= 3: break
        idx = top_indices[0][i].item()
        score = top_scores[0][i].item()
        label = g_labels[idx]
        if label not in seen_individuals:
            seen_individuals.add(label)
            unique_candidates.append((idx, score))

    # --- Stage 2: geometric verification with ALIKED + LightGlue ---
    feats_q = extractor.extract(load_image(input_path).to(DEVICE))

    best_score = -1          # highest geometric-match count seen so far
    best_candidate_idx = -1  # gallery index of that candidate

    cand_logs = ["Waiting for data...", "Waiting for data...", "Waiting for data..."]
    cand_viz_paths = [None, None, None]

    for rank, (idx, arcface_sim) in enumerate(unique_candidates):
        path = g_paths[idx]
        label = g_labels[idx]
        species = g_species[idx]

        try:
            if not os.path.exists(path): continue

            feats_c = extractor.extract(load_image(path).to(DEVICE))
            with torch.no_grad():
                matches = matcher({"image0": feats_q, "image1": feats_c})
            matches = rbd(matches)  # remove batch dimension

            geo_matches = len(matches["matches"])
            sim_percent = arcface_sim * 100

            log_str = f"### Candidate {rank+1}: {species} / {label}\n"
            log_str += f"**Coarse-Search Confidence:** {sim_percent:.1f}% | **📐 Geometric Matches:** {geo_matches}"
            cand_logs[rank] = log_str

            viz_name = f"viz_rank_{rank}"
            viz_path = create_match_visualization(
                input_path, path,
                feats_q['keypoints'], feats_c['keypoints'],
                matches['matches'], viz_name
            )
            cand_viz_paths[rank] = viz_path

            if geo_matches > best_score:
                best_score = geo_matches
                best_candidate_idx = idx

        except Exception as e:
            # Surface per-candidate failures in the UI instead of aborting
            # the whole query.
            cand_logs[rank] = f"Error processing candidate: {e}"

    # --- Final decision: accept only above the keypoint threshold ---
    CONFIDENCE_THRESHOLD = 15
    if best_candidate_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
        pred_species = g_species[best_candidate_idx]
        pred_id = g_labels[best_candidate_idx]
        is_correct = (pred_id == true_id)

        if true_id == "Unknown": header = f"### ❓ MATCH FOUND (No Ground Truth)\n"
        elif is_correct: header = f"### ✅ CORRECT MATCH!\n"
        else: header = f"### ❌ INCORRECT MATCH\n"

        header += f"**Ground Truth:** {true_species} / {true_id} ➡️ **Prediction:** {pred_species} / {pred_id}\n"
        header += f"*(Confirmed with {best_score} geometric keypoints)*"
    else:
        header = "### ⚠️ UNKNOWN / NO MATCH\n"
        header += f"**Ground Truth:** {true_species} / {true_id}\n"
        header += f"**Prediction:** None (Best match only had {best_score} keypoints)\n"

    return (header,
            cand_logs[0], cand_viz_paths[0],
            cand_logs[1], cand_viz_paths[1],
            cand_logs[2], cand_viz_paths[2])
|
|
| |
# Populate the Gradio example picker from the test-queries directory.
# Include '.jpeg' so the filter matches the GT_LOOKUP scan above
# (previously '.jpeg' files appeared in GT_LOOKUP but not in the examples).
examples_list = []
if os.path.exists(TEST_QUERIES_DIR):
    examples_list = [os.path.join(TEST_QUERIES_DIR, f) for f in os.listdir(TEST_QUERIES_DIR) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
|
|
# --- Gradio UI: one input column, three candidate panels ---
with gr.Blocks(title="Wildlife Re-ID: Coarse-to-Fine Demo") as demo:
    gr.Markdown("# Wildlife Re-ID: Coarse-to-Fine Demo")
    gr.Markdown("Select a test image. The system finds the Top 3 UNIQUE individuals using embeddings, then verifies them using geometry.")

    with gr.Row():
        # Left column: query image input, clickable examples, trigger button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="filepath", label="Test Image", height=300)
            gr.Examples(examples=examples_list, inputs=input_img, label="Test Examples", examples_per_page=4)
            submit_btn = gr.Button("Run Identification", variant="primary", size="lg")

        # Right column: final verdict plus one (log, visualization) pair per
        # candidate -- mirrors the 7-tuple returned by predict().
        with gr.Column(scale=2):
            header_md = gr.Markdown(label="Final Decision")

            with gr.Group():
                log1 = gr.Markdown()

                img1 = gr.Image(label="Visualization", show_label=False)

            with gr.Group():
                log2 = gr.Markdown()
                img2 = gr.Image(label="Visualization", show_label=False)

            with gr.Group():
                log3 = gr.Markdown()
                img3 = gr.Image(label="Visualization", show_label=False)

    # Output order must stay in sync with predict()'s return tuple.
    submit_btn.click(
        fn=predict,
        inputs=input_img,
        outputs=[header_md, log1, img1, log2, img2, log3, img3]
    )


# allowed_paths lets Gradio serve the example images from outside its
# own temp/static directories.
demo.launch(allowed_paths=[TEST_QUERIES_DIR])