|
|
import sqlite3 |
|
|
import time |
|
|
from feature_matcher_utilities import extract_keypoints, feature_matching, unrotate_kps_W |
|
|
import os |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
from PIL import Image |
|
|
import torchvision.transforms.functional as TF |
|
|
|
|
|
from lightglue import LightGlue |
|
|
from lightglue.utils import rbd |
|
|
from lightglue import SuperPoint, SIFT |
|
|
from lightglue.utils import load_image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_colmap_db(db_path):
    """Open an existing COLMAP SQLite database.

    Args:
        db_path: Filesystem path to the COLMAP ``.db`` file.

    Returns:
        A ``(connection, cursor)`` tuple for the opened database.

    Raises:
        FileNotFoundError: If no file exists at *db_path*.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Database file not found: {db_path}")
    connection = sqlite3.connect(db_path)
    return connection, connection.cursor()
|
|
|
|
|
def create_pair_id(image_id1, image_id2):
    """Compute COLMAP's canonical pair id for two image ids.

    The smaller id is always used as the multiplicand, so the result is
    symmetric in its two arguments.
    """
    low, high = sorted((image_id1, image_id2))
    return low * 2147483647 + high
|
|
|
|
|
def clean_database(cursor):
    """Remove existing features and matches to ensure a clean overwrite.

    Deletes all rows from the ``keypoints``, ``descriptors`` and ``matches``
    tables. The caller is responsible for committing the transaction.

    Args:
        cursor: SQLite cursor connected to a COLMAP database.
    """
    # BUGFIX: 'matches' was missing from this list even though both the
    # docstring and the log message claimed matches were removed, which
    # left stale match rows in the database after an "overwrite".
    tables = ["keypoints", "descriptors", "matches"]
    for table in tables:
        cursor.execute(f"DELETE FROM {table};")
    print("Database cleaned (keypoints, descriptors, matches removed).")
|
|
|
|
|
def insert_keypoints(cursor, image_id, keypoints, descriptors):
    """Store keypoints and descriptors for one image in the COLMAP db.

    Args:
        cursor: SQLite cursor connected to a COLMAP database.
        image_id: Id of the image the features belong to.
        keypoints: (N, 2) numpy array, float32, pixel coordinates.
        descriptors: (N, D) numpy array, float32, feature descriptors.
    """
    # Both tables share the same (image_id, rows, cols, data) layout, so
    # one parameterized insert covers the two arrays.
    for table, array in (("keypoints", keypoints), ("descriptors", descriptors)):
        num_rows, num_cols = array.shape
        cursor.execute(
            f"INSERT INTO {table}(image_id, rows, cols, data) VALUES(?, ?, ?, ?)",
            (image_id, num_rows, num_cols, array.tobytes()),
        )
|
|
|
|
|
def insert_matches(cursor, image_id1, image_id2, matches):
    """Write the match row for one image pair into the COLMAP db.

    Args:
        cursor: SQLite cursor connected to a COLMAP database.
        image_id1, image_id2: Ids of the two matched images.
        matches: (K, 2) numpy array, uint32. Col 0 is index in image1,
            Col 1 is index in image2.

    NOTE(review): the pair id is order-normalized by create_pair_id, but the
    match columns are not swapped here — callers must supply ids in the
    stored column order; verify against how this is invoked.
    """
    num_rows, num_cols = matches.shape
    cursor.execute(
        "INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?)",
        (create_pair_id(image_id1, image_id2), num_rows, num_cols, matches.tobytes()),
    )
|
|
|
|
|
def verify_matches_visual(cursor, image_id1, image_id2, image_dir):
    """
    Reads matches and keypoints from the COLMAP db and plots them.

    Args:
        cursor: SQLite cursor connected to the database.
        image_id1: ID of the first image.
        image_id2: ID of the second image.
        image_dir: Path to the directory containing the images.
    """
    # COLMAP stores each pair once under the ordered (smaller, larger) pair
    # id; remember whether the caller's order was flipped so the match
    # columns can be restored afterwards.
    if image_id1 > image_id2:
        image_id1, image_id2 = image_id2, image_id1
        swapped = True
    else:
        swapped = False

    # CONSISTENCY: reuse the shared helper instead of re-inlining the formula.
    pair_id = create_pair_id(image_id1, image_id2)

    cursor.execute("SELECT data FROM matches WHERE pair_id = ?", (pair_id,))
    match_row = cursor.fetchone()

    if match_row is None:
        print(f"No matches found in DB for pair {image_id1}-{image_id2}")
        return

    # Matches are stored as a flat uint32 blob of (idx_in_img1, idx_in_img2).
    matches = np.frombuffer(match_row[0], dtype=np.uint32).reshape(-1, 2)

    # Undo the id swap so column 0 again indexes the first image argument.
    if swapped:
        matches = matches[:, [1, 0]]

    def get_keypoints_and_name(img_id):
        # Resolve the image file name and its stored keypoints for one id.
        cursor.execute("SELECT name FROM images WHERE image_id = ?", (img_id,))
        name = cursor.fetchone()[0]

        cursor.execute("SELECT data FROM keypoints WHERE image_id = ?", (img_id,))
        kp_row = cursor.fetchone()

        kpts = np.frombuffer(kp_row[0], dtype=np.float32).reshape(-1, 2)
        return name, kpts

    name1, kpts1 = get_keypoints_and_name(image_id1)
    name2, kpts2 = get_keypoints_and_name(image_id2)

    # Keep only the keypoint coordinates that participate in a match.
    valid_kpts1 = kpts1[matches[:, 0]]
    valid_kpts2 = kpts2[matches[:, 1]]

    path1 = os.path.join(image_dir, name1)
    path2 = os.path.join(image_dir, name2)

    # ROBUSTNESS: check the files exist (consistent with plot_matches_from_db)
    # and that cv2.imread actually decoded them — it returns None silently on
    # failure, which previously crashed in cvtColor with a cryptic error.
    if not os.path.exists(path1) or not os.path.exists(path2):
        print(f"Error: Could not find image files at \n{path1}\n{path2}")
        return

    img1 = cv2.imread(path1)
    img2 = cv2.imread(path2)
    if img1 is None or img2 is None:
        print(f"Error: Failed to read image files at \n{path1}\n{path2}")
        return

    # OpenCV loads BGR; matplotlib expects RGB.
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

    h1, w1, _ = img1.shape
    h2, w2, _ = img2.shape

    # Place both images side by side on one canvas.
    height = max(h1, h2)
    width = w1 + w2
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    canvas[:h1, :w1, :] = img1
    canvas[:h2, w1:w1+w2, :] = img2

    plt.figure(figsize=(15, 10))
    plt.imshow(canvas)

    # One line per match; image2 keypoints are shifted right by w1 pixels.
    for (x1, y1), (x2, y2) in zip(valid_kpts1, valid_kpts2):
        plt.plot([x1, x2 + w1], [y1, y2], 'c-', alpha=0.6, linewidth=0.5)
        plt.plot(x1, y1, 'r.', markersize=2)
        plt.plot(x2 + w1, y2, 'r.', markersize=2)

    plt.title(f"DB Verification: {name1} (ID:{image_id1}) <-> {name2} (ID:{image_id2}) | Matches: {len(matches)}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
|
|
|
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import cv2 |
|
|
import os |
|
|
import sqlite3 |
|
|
|
|
|
def plot_matches_from_db(cursor, image_id1, image_id2, image_dir):
    """
    Reads matches and keypoints for a specific pair from the COLMAP DB and plots them.

    Args:
        cursor: SQLite cursor.
        image_id1, image_id2: The IDs of the two images to plot.
        image_dir: Path to the directory containing the actual image files.
    """
    # COLMAP keys the matches table by an ordered (smaller, larger) pair id;
    # remember whether the caller's argument order was flipped so the match
    # columns can be restored after reading.
    if image_id1 > image_id2:
        id_a, id_b = image_id2, image_id1
        swapped = True
    else:
        id_a, id_b = image_id1, image_id2
        swapped = False

    # Same pair-id formula as create_pair_id elsewhere in this file.
    pair_id = id_a * 2147483647 + id_b

    print(f"Fetching matches for pair {image_id1}-{image_id2} (PairID: {pair_id})...")
    cursor.execute("SELECT data, rows, cols FROM matches WHERE pair_id = ?", (pair_id,))
    match_row = cursor.fetchone()

    if match_row is None:
        print(f"No matches found in database for Pair {image_id1}-{image_id2}")
        return

    # Matches are stored as a flat uint32 blob of (idx_in_img1, idx_in_img2).
    matches_blob = match_row[0]
    matches = np.frombuffer(matches_blob, dtype=np.uint32).reshape(-1, 2)

    # Undo the id swap so column 0 again indexes the first image argument.
    if swapped:
        matches = matches[:, [1, 0]]

    def get_image_data(img_id):
        # Resolve the image file name and its stored keypoints for one id.
        cursor.execute("SELECT name FROM images WHERE image_id = ?", (img_id,))
        res = cursor.fetchone()
        if not res:
            raise ValueError(f"Image ID {img_id} not found in 'images' table.")
        name = res[0]

        cursor.execute("SELECT data FROM keypoints WHERE image_id = ?", (img_id,))
        kp_res = cursor.fetchone()
        if not kp_res:
            raise ValueError(f"No keypoints found for Image ID {img_id}.")

        # Keypoints are stored as a flat float32 blob of (x, y) coordinates.
        kpts = np.frombuffer(kp_res[0], dtype=np.float32).reshape(-1, 2)
        return name, kpts

    name1, kpts1 = get_image_data(image_id1)
    name2, kpts2 = get_image_data(image_id2)

    # Keep only the keypoint coordinates that participate in a match.
    valid_kpts1 = kpts1[matches[:, 0]]
    valid_kpts2 = kpts2[matches[:, 1]]

    path1 = os.path.join(image_dir, name1)
    path2 = os.path.join(image_dir, name2)

    if not os.path.exists(path1) or not os.path.exists(path2):
        print(f"Error: Could not find image files at \n{path1}\n{path2}")
        return

    # NOTE(review): cv2.imread returns None on a corrupt/unreadable file even
    # when it exists; cvtColor would then raise — confirm inputs are valid.
    img1 = cv2.imread(path1)
    img2 = cv2.imread(path2)
    # OpenCV loads BGR; matplotlib expects RGB.
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

    # Place both images side by side on one canvas.
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]
    height = max(h1, h2)
    width = w1 + w2
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[:h1, :w1] = img1
    canvas[:h2, w1:w1+w2] = img2

    plt.figure(figsize=(20, 10))
    plt.imshow(canvas)

    # One line per match; image2 keypoints are shifted right by w1 pixels.
    for (x1, y1), (x2, y2) in zip(valid_kpts1, valid_kpts2):
        plt.plot([x1, x2 + w1], [y1, y2], 'g-', alpha=0.5, linewidth=1.5)
        plt.plot(x1, y1, 'r.', markersize=4)
        plt.plot(x2 + w1, y2, 'r.', markersize=4)

    plt.title(f"{name1} <-> {name2} | Total Matches: {len(matches)}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Hoisted out of the per-image loop, where it was re-executed each pass.
    import gc

    parser = argparse.ArgumentParser()

    parser.add_argument("--database", type=Path, required=True)
    parser.add_argument("--rgb_path", type=Path, required=True)
    parser.add_argument("--feature", type=str, required=True)
    parser.add_argument("--matcher", type=str, required=True)

    args, _ = parser.parse_known_args()

    DB_PATH = args.database
    IMAGE_DIR = args.rgb_path
    FEATURE_TYPE = args.feature  # NOTE(review): currently unused below
    MATCHER_TYPE = args.matcher  # NOTE(review): currently unused below
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    matches_file_path = os.path.join(os.path.dirname(DB_PATH), "matches.txt")

    conn, cursor = load_colmap_db(DB_PATH)
    cursor.execute("SELECT image_id, name FROM images")
    images_info = {row[0]: row[1] for row in cursor.fetchall()}
    image_ids = sorted(images_info.keys())

    # Wipe previous features before re-extracting everything.
    clean_database(cursor)
    conn.commit()

    # BUGFIX: the models were previously forced onto the GPU with .cuda()
    # even though DEVICE deliberately falls back to 'cpu' when CUDA is
    # unavailable; .to(DEVICE) honors that fallback.
    extractor = SuperPoint(max_num_keypoints=128, detection_threshold=0.0).eval().to(DEVICE)
    matcher = LightGlue(width_confidence=-1).eval().to(DEVICE)

    total_time = 0.0
    with open(matches_file_path, "w") as f_match:
        # Exhaustive pairing: image i is matched against every later image j.
        for i, id_i in enumerate(tqdm(image_ids, desc="Outer Loop")):
            fname_i = images_info[id_i]
            path_i = os.path.join(IMAGE_DIR, fname_i)
            img_i = Image.open(path_i).convert("RGB")
            t_i = TF.to_tensor(img_i)
            imgs_i = []
            imgs_j = []
            ids_j = []
            for j, id_j in enumerate(tqdm(image_ids[i+1:], desc="Inner Loop", leave=False), start=i+1):
                fname_j = images_info[id_j]
                path_j = os.path.join(IMAGE_DIR, fname_j)
                img_j = Image.open(path_j).convert("RGB")
                t_j = TF.to_tensor(img_j)
                imgs_j.append(t_j)
                # The reference image is repeated so the two batches align
                # element-wise for the matcher.
                imgs_i.append(t_i)
                ids_j.append(id_j)

            # The last image has no later partners to match against.
            if len(imgs_j) == 0:
                continue
            print(f"Processing batch: Image {fname_i} with {len(imgs_j)} images.")
            batch_i = torch.stack(imgs_i, dim=0).to(DEVICE)
            batch_j = torch.stack(imgs_j, dim=0).to(DEVICE)

            with torch.no_grad():
                feats_i = extractor({"image": batch_i})
                feats_j = extractor({"image": batch_j})

            # batch_i repeats the same image, so entry 0's features suffice
            # for the database insert.
            kpts = feats_i['keypoints'][0].squeeze(0).cpu().numpy().astype(np.float32)
            descs = feats_i['descriptors'][0].squeeze(0).cpu().numpy().astype(np.float32)
            insert_keypoints(cursor, id_i, kpts, descs)

            # LightGlue expects a nested {'image0': ..., 'image1': ...} dict.
            data = {}
            data['image0'] = {}
            data['image1'] = {}
            data['image0']['keypoints'] = feats_i['keypoints']
            data['image0']['descriptors'] = feats_i['descriptors']
            data['image1']['keypoints'] = feats_j['keypoints']
            data['image1']['descriptors'] = feats_j['descriptors']

            t0 = time.perf_counter()
            matches01 = matcher(data)
            t1 = time.perf_counter()
            elapsed = t1 - t0
            print(f"Matching took {elapsed:.4f} seconds")
            total_time += elapsed

            # Dump the valid match index pairs of each image pair to
            # matches.txt (pairs separated by a blank line).
            for k in range(len(matches01["matches0"])):
                m0 = matches01["matches0"][k]
                valid = m0 > -1  # -1 marks unmatched keypoints
                if valid.any():
                    fname_j = images_info[ids_j[k]]
                    f_match.write(f"{fname_i} {fname_j}\n")
                    idx0 = torch.nonzero(valid, as_tuple=False).squeeze(1)
                    idx1 = m0[valid].long()
                    matches_np = torch.stack([idx0, idx1], dim=1).cpu().numpy().astype(int)
                    np.savetxt(f_match, matches_np, fmt="%d")
                    f_match.write("\n")

            # Release the large batch tensors before the next allocation to
            # keep peak memory down.
            del batch_i, batch_j, feats_i, feats_j, data, matches01, imgs_i, imgs_j
            # BUGFIX: these CUDA calls were unconditional and raised on
            # CPU-only machines despite the DEVICE fallback above.
            if DEVICE == 'cuda':
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
            gc.collect()

            conn.commit()

    conn.close()
    print("Database overwrite complete.")
    print(f"Total matching time: {total_time:.2f} seconds.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|