# app.py — Wildlife Re-ID demo (SalHargis, commit 8c24824)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import random
# LightGlue Imports
from lightglue import LightGlue, ALIKED
from lightglue.utils import load_image, rbd
# Configuration
WEIGHTS_PATH = "MiewID_ArcFace_FineTun.pth"  # fine-tuned ResNet-50 checkpoint (ArcFace training)
GALLERY_FILE = "mini_gallery.pt"  # precomputed gallery: embeddings, paths, labels, species
TEST_QUERIES_DIR = "test_queries"  # example query images; filenames also encode ground truth
IMG_SIZE = 384  # square resize fed to the embedding model
EMBEDDING_DIM = 512  # output width of the replaced ResNet fc head
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_VIZ_LINES = 50  # cap on match lines drawn per visualization, for legibility
# Load the models: a ResNet-50 whose final fc is replaced with a 512-d
# embedding head (coarse stage), plus ALIKED keypoints + LightGlue matching
# (fine, geometric-verification stage).
print("Loading Models...")
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, EMBEDDING_DIM)
# NOTE(review): strict=False silently ignores missing/unexpected checkpoint
# keys — confirm the embedding weights actually load as intended.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE), strict=False)
model.to(DEVICE).eval()
extractor = ALIKED(max_num_keypoints=1024, detection_threshold=0.2).eval().to(DEVICE)
matcher = LightGlue(features='aliked').eval().to(DEVICE)
# Load the precomputed gallery (embeddings plus per-image metadata). The
# gallery is mandatory — fail fast at startup if it is absent.
if not os.path.exists(GALLERY_FILE):
    raise FileNotFoundError("Gallery file missing!")
data = torch.load(GALLERY_FILE, map_location=DEVICE)
g_embeddings = data["embeddings"].to(DEVICE)
g_paths = data["paths"]
g_labels = data["labels"]
g_species = data["species"]
# Ground-truth lookup keyed on file byte size: Gradio copies uploads to a
# temp path, so the size is used to recover the original
# "<species>_<id>_..." filename prefix.
# NOTE(review): size collisions would map two queries to one ground truth —
# presumably acceptable for a small demo set; verify if the set grows.
GT_LOOKUP = {}
if os.path.exists(TEST_QUERIES_DIR):
    for fname in os.listdir(TEST_QUERIES_DIR):
        if not fname.lower().endswith(('.jpg', '.png', '.jpeg')):
            continue
        full_path = os.path.join(TEST_QUERIES_DIR, fname)
        try:
            f_size = os.path.getsize(full_path)
        except OSError:
            # File vanished or is unreadable — skip it rather than abort
            # startup. (Was a bare `except: pass`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            continue
        parts = fname.split("_")
        if len(parts) >= 2:
            GT_LOOKUP[f_size] = (parts[0], parts[1])
# Preprocessing for the embedding model: square resize to IMG_SIZE and
# ImageNet mean/std normalization (matches standard ResNet training stats).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def create_match_visualization(img_path1, img_path2, kpts0, kpts1, matches,
                               output_name, max_lines=MAX_VIZ_LINES):
    """Draw keypoint matches between two images side by side and save a JPEG.

    Args:
        img_path1, img_path2: paths of the query and candidate images.
        kpts0, kpts1: keypoint coordinate tensors (optionally batched, 3-D).
        matches: (M, 2) tensor of index pairs into kpts0 / kpts1.
        output_name: basename (without extension) for the saved image.
        max_lines: cap on drawn match lines; extra matches are randomly
            subsampled so the plot stays legible.

    Returns:
        Path of the saved visualization ("<output_name>.jpg").
    """
    # Drop the batch dimension if present.
    if kpts0.dim() == 3:
        kpts0 = kpts0[0]
    if kpts1.dim() == 3:
        kpts1 = kpts1[0]
    # Density reduction: randomly subsample matches beyond the cap.
    num_matches = len(matches)
    if num_matches > max_lines:
        indices = random.sample(range(num_matches), max_lines)
        matches_to_draw = matches[indices]
    else:
        matches_to_draw = matches
    # Resize both images to a common height so they tile cleanly side by side.
    img1 = Image.open(img_path1).convert("RGB")
    img2 = Image.open(img_path2).convert("RGB")
    target_h = 400
    w1, h1 = img1.size
    w2, h2 = img2.size
    scale1 = target_h / h1
    scale2 = target_h / h2
    img1 = img1.resize((int(w1 * scale1), target_h))
    img2 = img2.resize((int(w2 * scale2), target_h))
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    try:
        ax.axis('off')
        img1_np = np.array(img1)
        img2_np = np.array(img2)
        h1_new, w1_new, _ = img1_np.shape
        h2_new, w2_new, _ = img2_np.shape
        # Horizontal-concatenation canvas holding both resized images.
        canvas = np.zeros((target_h, w1_new + w2_new, 3), dtype=np.uint8)
        canvas[:, :w1_new, :] = img1_np
        canvas[:, w1_new:, :] = img2_np
        ax.imshow(canvas)
        # Gather matched keypoints and rescale them to the resized images
        # (keypoints are (x, y) pairs, so one scalar scale covers both axes).
        m_kpts0 = kpts0[matches_to_draw[..., 0]].cpu().numpy() * scale1
        m_kpts1 = kpts1[matches_to_draw[..., 1]].cpu().numpy() * scale2
        for (x0, y0), (x1, y1) in zip(m_kpts0, m_kpts1):
            # Offset right-image x-coordinates by the left image's width.
            ax.plot([x0, x1 + w1_new], [y0, y1], color="lime", linewidth=0.8, alpha=0.6)
            ax.scatter([x0, x1 + w1_new], [y0, y1], color="lime", s=3)
        plt.tight_layout()
        output_path = f"{output_name}.jpg"
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
    finally:
        # Always release the figure, even if drawing/saving raises, so
        # repeated calls from the UI don't leak matplotlib figures.
        plt.close(fig)
    return output_path
def predict(input_path):
    """Identify the animal in *input_path* via a coarse-to-fine pipeline.

    Stage 1 (coarse): embed the query with the ArcFace-tuned ResNet-50 and
    rank the gallery by cosine similarity, keeping the top-3 UNIQUE labels.
    Stage 2 (fine): verify each candidate with ALIKED + LightGlue geometric
    matching; the candidate with the most matches wins if it clears a
    fixed threshold.

    Returns:
        A 7-tuple (header, log1, viz1, log2, viz2, log3, viz3) matching the
        Gradio output components.
    """
    # Default return values are empty placeholders for the "no input" case.
    default_header = "Please upload an image."
    default_logs = ["", "", ""]
    default_imgs = [None, None, None]
    if input_path is None:
        return default_header, default_logs[0], default_imgs[0], default_logs[1], default_imgs[1], default_logs[2], default_imgs[2]
    # Ground truth: Gradio copies uploads to a temp file, so first try the
    # byte-size lookup built at startup, then fall back to parsing a
    # "<species>_<id>_QUERY..." filename pattern.
    true_species, true_id = "Unknown", "Unknown"
    try:
        input_size = os.path.getsize(input_path)
        if input_size in GT_LOOKUP:
            true_species, true_id = GT_LOOKUP[input_size]
        else:
            filename = os.path.basename(input_path)
            if "_QUERY" in filename:
                parts = filename.split("_")
                true_species, true_id = parts[0], parts[1]
    except (OSError, IndexError):
        # Unreadable file or malformed name — keep "Unknown" ground truth.
        # (Was a bare `except: pass`, which swallowed all exceptions.)
        pass
    # Load the query image.
    input_image = Image.open(input_path).convert("RGB")
    # Coarse search: ArcFace embedding + cosine similarity over the gallery
    # (embeddings are L2-normalized, so the dot product is cosine similarity).
    img_t = transform(input_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        q_emb = torch.nn.functional.normalize(model(img_t), p=2, dim=1)
        scores = torch.mm(q_emb, g_embeddings.t())
        top_scores, top_indices = torch.topk(scores, k=min(50, len(g_paths)))
    # Filter the ranked hits down to the top-3 unique individuals.
    unique_candidates = []
    seen_individuals = set()
    for i in range(len(top_indices[0])):
        if len(unique_candidates) >= 3:
            break
        idx = top_indices[0][i].item()
        score = top_scores[0][i].item()
        label = g_labels[idx]
        if label not in seen_individuals:
            seen_individuals.add(label)
            unique_candidates.append((idx, score))
    # Fine search: geometric verification + visualization per candidate.
    feats_q = extractor.extract(load_image(input_path).to(DEVICE))
    best_score = -1
    best_candidate_idx = -1
    # Output lists: fixed size 3, one slot per candidate panel in the UI.
    cand_logs = ["Waiting for data...", "Waiting for data...", "Waiting for data..."]
    cand_viz_paths = [None, None, None]
    for rank, (idx, arcface_sim) in enumerate(unique_candidates):
        path = g_paths[idx]
        label = g_labels[idx]
        species = g_species[idx]
        try:
            if not os.path.exists(path):
                continue
            feats_c = extractor.extract(load_image(path).to(DEVICE))
            with torch.no_grad():
                matches = matcher({"image0": feats_q, "image1": feats_c})
                matches = rbd(matches)
            geo_matches = len(matches["matches"])
            sim_percent = arcface_sim * 100
            # Per-candidate log string shown above its visualization.
            log_str = f"### Candidate {rank+1}: {species} / {label}\n"
            log_str += f"**Coarse-Search Confidence:** {sim_percent:.1f}%   |   **📐 Geometric Matches:** {geo_matches}"
            cand_logs[rank] = log_str
            viz_name = f"viz_rank_{rank}"
            viz_path = create_match_visualization(
                input_path, path,
                feats_q['keypoints'], feats_c['keypoints'],
                matches['matches'], viz_name
            )
            cand_viz_paths[rank] = viz_path
            # Track the best candidate by geometric match count.
            if geo_matches > best_score:
                best_score = geo_matches
                best_candidate_idx = idx
        except Exception as e:
            # Surface per-candidate failures in the UI instead of aborting
            # the whole prediction.
            cand_logs[rank] = f"Error processing candidate: {e}"
    # Final decision: accept only if the best candidate has enough matches.
    CONFIDENCE_THRESHOLD = 15
    if best_candidate_idx != -1 and best_score > CONFIDENCE_THRESHOLD:
        pred_species = g_species[best_candidate_idx]
        pred_id = g_labels[best_candidate_idx]
        is_correct = (pred_id == true_id)
        if true_id == "Unknown":
            header = f"### ❓ MATCH FOUND (No Ground Truth)\n"
        elif is_correct:
            header = f"### ✅ CORRECT MATCH!\n"
        else:
            header = f"### ❌ INCORRECT MATCH\n"
        header += f"**Ground Truth:** {true_species} / {true_id}   ➡️   **Prediction:** {pred_species} / {pred_id}\n"
        header += f"*(Confirmed with {best_score} geometric keypoints)*"
    else:
        header = "### ⚠️ UNKNOWN / NO MATCH\n"
        header += f"**Ground Truth:** {true_species} / {true_id}\n"
        header += f"**Prediction:** None (Best match only had {best_score} keypoints)\n"
    # Return: header, then one (log, image) pair per candidate 1–3.
    return (header,
            cand_logs[0], cand_viz_paths[0],
            cand_logs[1], cand_viz_paths[1],
            cand_logs[2], cand_viz_paths[2])
# Setup for the user interface: collect example images for the Examples
# widget. Accept the same extensions as the ground-truth scan at startup —
# '.jpeg' was previously omitted here, silently hiding those files.
examples_list = []
if os.path.exists(TEST_QUERIES_DIR):
    examples_list = [
        os.path.join(TEST_QUERIES_DIR, f)
        for f in os.listdir(TEST_QUERIES_DIR)
        if f.lower().endswith(('.jpg', '.png', '.jpeg'))
    ]
# Build the Gradio UI: query input on the left, three stacked candidate
# panels (log markdown + match visualization) on the right.
with gr.Blocks(title="Wildlife Re-ID: Coarse-to-Fine Demo") as demo:
    gr.Markdown("# Wildlife Re-ID: Coarse-to-Fine Demo")
    gr.Markdown("Select a test image. The system finds the Top 3 UNIQUE individuals using embeddings, then verifies them using geometry.")
    with gr.Row():
        # Left Column: input image, example picker, run button.
        with gr.Column(scale=1):
            input_img = gr.Image(type="filepath", label="Test Image", height=300)
            gr.Examples(examples=examples_list, inputs=input_img, label="Test Examples", examples_per_page=4)
            submit_btn = gr.Button("Run Identification", variant="primary", size="lg")
        # Right Column: final-decision header plus a vertical stack of the
        # three candidate groups.
        with gr.Column(scale=2):
            header_md = gr.Markdown(label="Final Decision")
            # Candidate 1 Group
            with gr.Group():
                log1 = gr.Markdown()
                # NOTE: show_download_button and height kwargs were removed
                # here deliberately (per the original author's fix).
                img1 = gr.Image(label="Visualization", show_label=False)
            # Candidate 2 Group
            with gr.Group():
                log2 = gr.Markdown()
                img2 = gr.Image(label="Visualization", show_label=False)
            # Candidate 3 Group
            with gr.Group():
                log3 = gr.Markdown()
                img3 = gr.Image(label="Visualization", show_label=False)
    # Wire the button to predict(); outputs map 1:1 to its 7-tuple return.
    submit_btn.click(
        fn=predict,
        inputs=input_img,
        outputs=[header_md, log1, img1, log2, img2, log3, img3]
    )
# allowed_paths lets Gradio serve example images from the queries directory.
demo.launch(allowed_paths=[TEST_QUERIES_DIR])