"""Wildlife re-identification demo.

Pipeline: a fine-tuned ResNet-50 (ArcFace-style embedding) performs a coarse
nearest-neighbour search over a precomputed gallery; the top unique candidates
are then verified with ALIKED keypoints + LightGlue geometric matching.
Served as a Gradio web app.
"""

import os
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# LightGlue imports (local-feature extraction + matching)
from lightglue import LightGlue, ALIKED
from lightglue.utils import load_image, rbd

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"  # fine-tuned embedding backbone
GALLERY_FILE = "mini_gallery.pt"             # precomputed gallery embeddings
TEST_QUERIES_DIR = "test_queries"
IMG_SIZE = 384
EMBEDDING_DIM = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_VIZ_LINES = 50  # cap on match lines drawn per visualization

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
print("Loading Models...")
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, EMBEDDING_DIM)
# strict=False tolerates checkpoints saved with extra or missing heads.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE), strict=False)
model.to(DEVICE).eval()

extractor = ALIKED(max_num_keypoints=1024, detection_threshold=0.2).eval().to(DEVICE)
matcher = LightGlue(features='aliked').eval().to(DEVICE)

# ---------------------------------------------------------------------------
# Gallery of known individuals
# ---------------------------------------------------------------------------
if os.path.exists(GALLERY_FILE):
    data = torch.load(GALLERY_FILE, map_location=DEVICE)
    g_embeddings = data["embeddings"].to(DEVICE)
    g_paths = data["paths"]
    g_labels = data["labels"]
    g_species = data["species"]
else:
    raise FileNotFoundError("Gallery file missing!")

# ---------------------------------------------------------------------------
# Ground-truth lookup.
# Gradio copies uploads to temp files and discards the original filename, so
# ground truth is keyed on *file size* (assumes no two test images share a
# size — acceptable for a small demo set). Filenames are expected to look
# like "<species>_<id>_....jpg".
# ---------------------------------------------------------------------------
GT_LOOKUP = {}
if os.path.exists(TEST_QUERIES_DIR):
    for f in os.listdir(TEST_QUERIES_DIR):
        if f.lower().endswith(('.jpg', '.png', '.jpeg')):
            full_path = os.path.join(TEST_QUERIES_DIR, f)
            try:
                f_size = os.path.getsize(full_path)
                parts = f.split("_")
                if len(parts) >= 2:
                    GT_LOOKUP[f_size] = (parts[0], parts[1])
            except OSError:
                # File vanished or is unreadable; skip it (was a bare except).
                pass

# ImageNet normalization to match the backbone's pretraining statistics.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches,
                               output_name, max_lines=MAX_VIZ_LINES):
    """Render the two images side by side with lines joining matched keypoints.

    Args:
        img_path1: path to the query image (drawn on the left).
        img_path2: path to the candidate image (drawn on the right).
        kpts0, kpts1: keypoint tensors for each image; an optional leading
            batch dimension is stripped.
        matches: (K, 2) tensor of index pairs into kpts0/kpts1.
        output_name: basename (no extension) of the JPEG written to disk.
        max_lines: at most this many randomly sampled matches are drawn.

    Returns:
        Path of the saved visualization JPEG.
    """
    # Drop an optional leading batch dimension.
    if kpts0.dim() == 3:
        kpts0 = kpts0[0]
    if kpts1.dim() == 3:
        kpts1 = kpts1[0]

    # Density reduction: draw at most max_lines randomly sampled matches so
    # dense results stay readable.
    num_matches = len(matches)
    if num_matches > max_lines:
        indices = random.sample(range(num_matches), max_lines)
        matches_to_draw = matches[indices]
    else:
        matches_to_draw = matches

    img1 = Image.open(img_path1).convert("RGB")
    img2 = Image.open(img_path2).convert("RGB")

    # Scale both images to a common height so they can share one canvas.
    target_h = 400
    w1, h1 = img1.size
    w2, h2 = img2.size
    scale1 = target_h / h1
    scale2 = target_h / h2
    img1 = img1.resize((int(w1 * scale1), target_h))
    img2 = img2.resize((int(w2 * scale2), target_h))

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.axis('off')

    img1_np = np.array(img1)
    img2_np = np.array(img2)
    h1_new, w1_new, _ = img1_np.shape
    h2_new, w2_new, _ = img2_np.shape
    width = w1_new + w2_new
    canvas = np.zeros((target_h, width, 3), dtype=np.uint8)
    canvas[:, :w1_new, :] = img1_np
    canvas[:, w1_new:, :] = img2_np
    ax.imshow(canvas)

    # Gather matched keypoints and rescale them to canvas coordinates.
    m_kpts0 = kpts0[matches_to_draw[..., 0]].cpu().numpy()
    m_kpts1 = kpts1[matches_to_draw[..., 1]].cpu().numpy()
    m_kpts0[:, 0] *= scale1
    m_kpts0[:, 1] *= scale1
    m_kpts1[:, 0] *= scale2
    m_kpts1[:, 1] *= scale2

    for (x0, y0), (x1, y1) in zip(m_kpts0, m_kpts1):
        # Right-image x coordinates are offset by the left image's width.
        ax.plot([x0, x1 + w1_new], [y0, y1], color="lime", linewidth=0.8, alpha=0.6)
        ax.scatter([x0, x1 + w1_new], [y0, y1], color="lime", s=3)

    plt.tight_layout()
    output_path = f"{output_name}.jpg"
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
    plt.close()  # release the figure to avoid leaking memory across calls
    return output_path


def predict(input_path):
    """Run the coarse-to-fine identification pipeline on one query image.

    Returns a 7-tuple matching the Gradio outputs: a markdown header followed
    by (log, visualization-path) pairs for each of the three candidate slots.
    """
    # Default return values are empty.
    default_header = "Please upload an image."
    default_logs = ["", "", ""]
    default_imgs = [None, None, None]
    if input_path is None:
        return (default_header,
                default_logs[0], default_imgs[0],
                default_logs[1], default_imgs[1],
                default_logs[2], default_imgs[2])

    # Recover ground truth: first via the file-size lookup (Gradio temp
    # copies), then by parsing the filename if it still carries "_QUERY".
    true_species, true_id = "Unknown", "Unknown"
    try:
        input_size = os.path.getsize(input_path)
        if input_size in GT_LOOKUP:
            true_species, true_id = GT_LOOKUP[input_size]
        else:
            filename = os.path.basename(input_path)
            if "_QUERY" in filename:
                parts = filename.split("_")
                true_species, true_id = parts[0], parts[1]
    except (OSError, IndexError):
        # Ground truth is best-effort only; fall back to "Unknown"
        # (was a bare except).
        pass

    # Load the image.
    input_image = Image.open(input_path).convert("RGB")

    # -- Coarse search (ArcFace embedding, cosine similarity) --------------
    img_t = transform(input_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
        scores = torch.mm(q_emb, g_embeddings.t())
    top_scores, top_indices = torch.topk(scores, k=min(50, len(g_paths)))

    # Keep the top 3 *unique* individuals (the gallery may hold several
    # images of the same animal).
    unique_candidates = []
    seen_individuals = set()
    for i in range(len(top_indices[0])):
        if len(unique_candidates) >= 3:
            break
        idx = top_indices[0][i].item()
        score = top_scores[0][i].item()
        label = g_labels[idx]
        if label not in seen_individuals:
            seen_individuals.add(label)
            unique_candidates.append((idx, score))

    # -- Fine search (LightGlue geometric verification) + visualization ----
    feats_q = extractor.extract(load_image(input_path).to(DEVICE))
    best_score = -1
    best_candidate_idx = -1

    # Output lists, one slot per candidate (size 3).
    cand_logs = ["Waiting for data...", "Waiting for data...", "Waiting for data..."]
    cand_viz_paths = [None, None, None]

    for rank, (idx, arcface_sim) in enumerate(unique_candidates):
        path = g_paths[idx]
        label = g_labels[idx]
        species = g_species[idx]
        try:
            if not os.path.exists(path):
                continue
            feats_c = extractor.extract(load_image(path).to(DEVICE))
            with torch.no_grad():
                matches = matcher({"image0": feats_q, "image1": feats_c})
            matches = rbd(matches)  # remove batch dimension
            geo_matches = len(matches["matches"])
            sim_percent = arcface_sim * 100

            # Per-candidate markdown log.
            log_str = f"### Candidate {rank+1}: {species} / {label}\n"
            log_str += f"**Coarse-Search Confidence:** {sim_percent:.1f}%   |   **📐 Geometric Matches:** {geo_matches}"
            cand_logs[rank] = log_str

            viz_name = f"viz_rank_{rank}"
            viz_path = create_match_visualization(
                input_path, path,
                feats_q['keypoints'], feats_c['keypoints'],
                matches['matches'], viz_name
            )
            cand_viz_paths[rank] = viz_path

            # The winner is the candidate with the most geometric matches.
            if geo_matches > best_score:
                best_score = geo_matches
                best_candidate_idx = idx
        except Exception as e:
            # Surface the failure in this candidate's panel rather than
            # aborting the whole prediction.
            cand_logs[rank] = f"Error processing candidate: {e}"

    # -- Final decision -----------------------------------------------------
    CONFIDENCE_THRESHOLD = 15  # minimum geometric matches to accept a match
    if best_candidate_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
        pred_species = g_species[best_candidate_idx]
        pred_id = g_labels[best_candidate_idx]
        is_correct = (pred_id == true_id)
        if true_id == "Unknown":
            header = f"### ❓ MATCH FOUND (No Ground Truth)\n"
        elif is_correct:
            header = f"### ✅ CORRECT MATCH!\n"
        else:
            header = f"### ❌ INCORRECT MATCH\n"
        header += f"**Ground Truth:** {true_species} / {true_id}   ➡️   **Prediction:** {pred_species} / {pred_id}\n"
        header += f"*(Confirmed with {best_score} geometric keypoints)*"
    else:
        header = "### ⚠️ UNKNOWN / NO MATCH\n"
        header += f"**Ground Truth:** {true_species} / {true_id}\n"
        header += f"**Prediction:** None (Best match only had {best_score} keypoints)\n"

    # Return: Header, then (Log, Img) for Cand 1, then (Log, Img) for Cand 2, etc.
    return (header,
            cand_logs[0], cand_viz_paths[0],
            cand_logs[1], cand_viz_paths[1],
            cand_logs[2], cand_viz_paths[2])


# ---------------------------------------------------------------------------
# User interface
# ---------------------------------------------------------------------------
examples_list = []
if os.path.exists(TEST_QUERIES_DIR):
    examples_list = [os.path.join(TEST_QUERIES_DIR, f)
                     for f in os.listdir(TEST_QUERIES_DIR)
                     if f.lower().endswith(('.jpg', '.png'))]

with gr.Blocks(title="Wildlife Re-ID: Coarse-to-Fine Demo") as demo:
    gr.Markdown("# Wildlife Re-ID: Coarse-to-Fine Demo")
    gr.Markdown("Select a test image. The system finds the Top 3 UNIQUE individuals using embeddings, then verifies them using geometry.")
    with gr.Row():
        # Left Column: Input
        with gr.Column(scale=1):
            input_img = gr.Image(type="filepath", label="Test Image", height=300)
            gr.Examples(examples=examples_list, inputs=input_img,
                        label="Test Examples", examples_per_page=4)
            submit_btn = gr.Button("Run Identification", variant="primary", size="lg")
        # Right Column: Vertical Stack of Candidates
        with gr.Column(scale=2):
            header_md = gr.Markdown(label="Final Decision")
            # Candidate 1 Group
            with gr.Group():
                log1 = gr.Markdown()
                # NOTE: show_download_button / height intentionally omitted
                # (unsupported by the targeted Gradio version).
                img1 = gr.Image(label="Visualization", show_label=False)
            # Candidate 2 Group
            with gr.Group():
                log2 = gr.Markdown()
                img2 = gr.Image(label="Visualization", show_label=False)
            # Candidate 3 Group
            with gr.Group():
                log3 = gr.Markdown()
                img3 = gr.Image(label="Visualization", show_label=False)

    submit_btn.click(
        fn=predict,
        inputs=input_img,
        outputs=[header_md, log1, img1, log2, img2, log3, img3]
    )

demo.launch(allowed_paths=[TEST_QUERIES_DIR])