fontan committed on
Commit
c2b8483
·
1 Parent(s): 45446db

lightglue matcher

Browse files
.gitignore CHANGED
@@ -1,3 +1,6 @@
1
  vocab_tree_flickr100K_words1M.bin
2
  vocab_tree_flickr100K_words256K.bin
3
- vocab_tree_flickr100K_words32K.bin
 
 
 
 
1
  vocab_tree_flickr100K_words1M.bin
2
  vocab_tree_flickr100K_words256K.bin
3
+ vocab_tree_flickr100K_words32K.bin
4
+
5
+ LightGlue/
6
+ __pycache__/
colmap_matcher.sh CHANGED
@@ -26,7 +26,7 @@ python3 Baselines/colmap/create_colmap_image_list.py "$rgb_csv" "$colmap_image_l
26
  # Create Colmap Database
27
  database="${exp_folder_colmap}/colmap_database.db"
28
  rm -rf ${database}
29
- colmap database_creator --database_path ${database}
30
 
31
  # Feature extractor
32
  echo " colmap feature_extractor ..."
@@ -141,4 +141,15 @@ then
141
  --SequentialMatching.loop_detection 1 \
142
  --SequentialMatching.vocab_tree_path ${vocabulary_tree} \
143
  --FeatureMatching.use_gpu "${use_gpu}"
144
- fi
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Create Colmap Database
27
  database="${exp_folder_colmap}/colmap_database.db"
28
  rm -rf ${database}
29
+ colmap database_creator --database_path ${database}
30
 
31
  # Feature extractor
32
  echo " colmap feature_extractor ..."
 
141
  --SequentialMatching.loop_detection 1 \
142
  --SequentialMatching.vocab_tree_path ${vocabulary_tree} \
143
  --FeatureMatching.use_gpu "${use_gpu}"
144
+ fi
145
+
146
+ # LightGlue Feature Matcher
147
+ if [ "${matcher_type}" == "lightglue" ]
148
+ then
149
+ pixi run -e colmap-sp python3 Baselines/colmap/lightglue_matcher.py
150
+ colmap matches_importer \
151
+ --database_path ${database} \
152
+ --match_list_path "${exp_folder_colmap}/matches.txt" \
153
+ --match_type raw
154
+ fi
155
+
lightglue_matcher.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from utilities import lightglue_keypoints, lightglue_matching, unrotate_kps_W
3
+ import os
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import cv2
9
+ import random
10
+
11
+ # ==========================================
12
+ # CONFIGURATION
13
+ # ==========================================
14
+ DB_PATH = "/home/alejandro/VSLAM-LAB-NEXT-ITERATION/VSLAM-LAB-Evaluation/demo/SESOKO/sskall-s01/colmap_00000/colmap_database.db"
15
+ IMAGE_DIR = "/home/alejandro/VSLAM-LAB-NEXT-ITERATION/VSLAM-LAB-Benchmark/SESOKO/sskall-s01/rgb_0"
16
+ FEATURE_TYPE = 'superpoint'
17
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+ matches_file_path = os.path.join(os.path.dirname(DB_PATH), "matches.txt")
19
+
20
+ # ==========================================
21
+ # ==========================================
22
+ # DATABASE UTILITIES
23
+ # ==========================================
24
def load_colmap_db(db_path):
    """Open an existing COLMAP SQLite database.

    Args:
        db_path: Path to the .db file; must already exist.

    Returns:
        A (connection, cursor) pair for the opened database.

    Raises:
        FileNotFoundError: If db_path does not point to an existing file.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Database file not found: {db_path}")
    connection = sqlite3.connect(db_path)
    return connection, connection.cursor()
30
+
31
def create_pair_id(image_id1, image_id2):
    """Compute COLMAP's canonical pair id for an unordered image pair.

    COLMAP keys each pair as min_id * 2147483647 + max_id, so the two ids
    are ordered before combining (2147483647 == MAX_IMAGE_ID, 2**31 - 1).
    """
    low, high = sorted((image_id1, image_id2))
    return low * 2147483647 + high
35
+
36
def clean_database(cursor):
    """Remove existing keypoints and descriptors so re-extraction overwrites cleanly.

    NOTE: 'matches' and 'two_view_geometry' are intentionally NOT cleared
    (see the commented-out table names below) — matches are exported to a
    text file and brought in via `colmap matches_importer` instead.
    """
    tables = ["keypoints", "descriptors"]  # , "matches", "two_view_geometry"]
    for table in tables:
        cursor.execute(f"DELETE FROM {table};")
    # Fixed message: the old one claimed matches were removed, which was false.
    print("Database cleaned (keypoints and descriptors removed).")
42
+
43
def insert_keypoints(cursor, image_id, keypoints, descriptors):
    """Store keypoints and descriptors for one image in the COLMAP database.

    Args:
        cursor: SQLite cursor on an open COLMAP database.
        image_id: COLMAP image id the features belong to.
        keypoints: (N, 2) float32 numpy array of pixel coordinates.
        descriptors: (N, D) float32 numpy array of feature descriptors.
    """
    kp_rows, kp_cols = keypoints.shape
    de_rows, de_cols = descriptors.shape

    cursor.execute(
        "INSERT INTO keypoints(image_id, rows, cols, data) VALUES(?, ?, ?, ?)",
        (image_id, kp_rows, kp_cols, keypoints.tobytes()),
    )
    # Descriptors are stored alongside so other COLMAP tooling can reuse them.
    cursor.execute(
        "INSERT INTO descriptors(image_id, rows, cols, data) VALUES(?, ?, ?, ?)",
        (image_id, de_rows, de_cols, descriptors.tobytes()),
    )
62
+
63
def insert_matches(cursor, image_id1, image_id2, matches):
    """Store a match list for an image pair in the COLMAP 'matches' table.

    Args:
        cursor: SQLite cursor on an open COLMAP database.
        image_id1, image_id2: The two image ids of the pair.
        matches: (K, 2) uint32 numpy array; column 0 indexes keypoints of
            image_id1, column 1 indexes keypoints of image_id2.
    """
    cursor.execute(
        "INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?)",
        (
            create_pair_id(image_id1, image_id2),
            matches.shape[0],
            matches.shape[1],
            matches.tobytes(),
        ),
    )
75
+
76
def verify_matches_visual(cursor, image_id1, image_id2, image_dir):
    """Read matches and keypoints from the COLMAP db and plot them side by side.

    Args:
        cursor: SQLite cursor connected to the database.
        image_id1: ID of the first image (left side of the plot).
        image_id2: ID of the second image (right side of the plot).
        image_dir: Path to the directory containing the image files.
    """
    # COLMAP stores each pair under min_id * 2147483647 + max_id, with match
    # column 0 belonging to the smaller id. BUGFIX: the previous version both
    # reassigned image_id1/image_id2 AND swapped the match columns, so when
    # id1 > id2 the match indices were paired with the wrong keypoint arrays.
    # Keep the requested ids intact and only derive the ordered pair here.
    if image_id1 > image_id2:
        id_a, id_b = image_id2, image_id1
        swapped = True
    else:
        id_a, id_b = image_id1, image_id2
        swapped = False

    pair_id = id_a * 2147483647 + id_b

    # Fetch matches for the pair.
    cursor.execute("SELECT data FROM matches WHERE pair_id = ?", (pair_id,))
    match_row = cursor.fetchone()
    if match_row is None:
        print(f"No matches found in DB for pair {image_id1}-{image_id2}")
        return

    # Decode matches: uint32 (N, 2); column 0 indexes keypoints of id_a.
    matches = np.frombuffer(match_row[0], dtype=np.uint32).reshape(-1, 2)
    if swapped:
        # Re-order columns so matches[:, 0] indexes the requested image_id1.
        matches = matches[:, [1, 0]]

    def get_keypoints_and_name(img_id):
        # Image name and its FLOAT32 (N, 2) keypoint array.
        cursor.execute("SELECT name FROM images WHERE image_id = ?", (img_id,))
        name = cursor.fetchone()[0]
        cursor.execute("SELECT data FROM keypoints WHERE image_id = ?", (img_id,))
        kp_row = cursor.fetchone()
        kpts = np.frombuffer(kp_row[0], dtype=np.float32).reshape(-1, 2)
        return name, kpts

    name1, kpts1 = get_keypoints_and_name(image_id1)
    name2, kpts2 = get_keypoints_and_name(image_id2)

    # Select the matched keypoints in corresponding order.
    valid_kpts1 = kpts1[matches[:, 0]]
    valid_kpts2 = kpts2[matches[:, 1]]

    # Load both images.
    path1 = os.path.join(image_dir, name1)
    path2 = os.path.join(image_dir, name2)
    img1 = cv2.imread(path1)
    img2 = cv2.imread(path2)

    # Convert BGR (OpenCV) to RGB (Matplotlib).
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

    # Concatenate images side-by-side on one canvas.
    h1, w1, _ = img1.shape
    h2, w2, _ = img2.shape
    height = max(h1, h2)
    width = w1 + w2
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[:h1, :w1, :] = img1
    canvas[:h2, w1:w1 + w2, :] = img2

    plt.figure(figsize=(15, 10))
    plt.imshow(canvas)

    # Draw one line per match; x-coords of image 2 are shifted right by w1.
    for (x1, y1), (x2, y2) in zip(valid_kpts1, valid_kpts2):
        plt.plot([x1, x2 + w1], [y1, y2], 'c-', alpha=0.6, linewidth=0.5)
        plt.plot(x1, y1, 'r.', markersize=2)
        plt.plot(x2 + w1, y2, 'r.', markersize=2)

    plt.title(f"DB Verification: {name1} (ID:{image_id1}) <-> {name2} (ID:{image_id2}) | Matches: {len(matches)}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
172
+
173
+ import numpy as np
174
+ import matplotlib.pyplot as plt
175
+ import cv2
176
+ import os
177
+ import sqlite3
178
+
179
def plot_matches_from_db(cursor, image_id1, image_id2, image_dir):
    """Visualize the stored matches between two images from a COLMAP database.

    Args:
        cursor: SQLite cursor.
        image_id1, image_id2: IDs of the two images to plot.
        image_dir: Directory containing the actual image files.
    """
    # COLMAP keys each unordered pair as min_id * 2147483647 + max_id.
    swapped = image_id1 > image_id2
    id_a, id_b = (image_id2, image_id1) if swapped else (image_id1, image_id2)
    pair_id = id_a * 2147483647 + id_b

    print(f"Fetching matches for pair {image_id1}-{image_id2} (PairID: {pair_id})...")
    cursor.execute("SELECT data, rows, cols FROM matches WHERE pair_id = ?", (pair_id,))
    match_row = cursor.fetchone()
    if match_row is None:
        print(f"No matches found in database for Pair {image_id1}-{image_id2}")
        return

    # Decode the uint32 (N, 2) match blob; column 0 belongs to the smaller id.
    matches = np.frombuffer(match_row[0], dtype=np.uint32).reshape(-1, 2)
    if swapped:
        # Flip columns so matches[:, 0] indexes the requested image_id1.
        matches = matches[:, [1, 0]]

    def get_image_data(img_id):
        # Resolve image name and decode its FLOAT32 (N, 2) keypoint array.
        cursor.execute("SELECT name FROM images WHERE image_id = ?", (img_id,))
        res = cursor.fetchone()
        if not res:
            raise ValueError(f"Image ID {img_id} not found in 'images' table.")
        cursor.execute("SELECT data FROM keypoints WHERE image_id = ?", (img_id,))
        kp_res = cursor.fetchone()
        if not kp_res:
            raise ValueError(f"No keypoints found for Image ID {img_id}.")
        return res[0], np.frombuffer(kp_res[0], dtype=np.float32).reshape(-1, 2)

    name1, kpts1 = get_image_data(image_id1)
    name2, kpts2 = get_image_data(image_id2)

    # Gather the matched keypoints in corresponding order.
    matched1 = kpts1[matches[:, 0]]
    matched2 = kpts2[matches[:, 1]]

    path1 = os.path.join(image_dir, name1)
    path2 = os.path.join(image_dir, name2)
    if not os.path.exists(path1) or not os.path.exists(path2):
        print(f"Error: Could not find image files at \n{path1}\n{path2}")
        return

    # Load and convert BGR (OpenCV) to RGB (Matplotlib).
    img1 = cv2.cvtColor(cv2.imread(path1), cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(cv2.imread(path2), cv2.COLOR_BGR2RGB)

    # Side-by-side canvas; image 2 is shifted right by image 1's width.
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]
    canvas = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    canvas[:h1, :w1] = img1
    canvas[:h2, w1:w1 + w2] = img2

    plt.figure(figsize=(20, 10))
    plt.imshow(canvas)

    # One green line per match, red dots on the endpoints.
    for (xa, ya), (xb, yb) in zip(matched1, matched2):
        plt.plot([xa, xb + w1], [ya, yb], 'g-', alpha=0.5, linewidth=1.5)
        plt.plot(xa, ya, 'r.', markersize=4)
        plt.plot(xb + w1, yb, 'r.', markersize=4)

    plt.title(f"{name1} <-> {name2} | Total Matches: {len(matches)}")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
278
+
279
if __name__ == "__main__":

    # Open the COLMAP database and enumerate registered images.
    conn, cursor = load_colmap_db(DB_PATH)
    cursor.execute("SELECT image_id, name FROM images")
    images_info = {row[0]: row[1] for row in cursor.fetchall()}
    image_ids = sorted(images_info.keys())

    # Image dimensions used to map keypoints from rotated frames back to the
    # original frame. NOTE(review): hard-coded for this dataset — should be
    # read from the images themselves.
    h = 505
    w = 607
    # plot_matches_from_db(cursor, image_ids[0], image_ids[1], IMAGE_DIR)
    # exit(0)

    # Drop any stale keypoints/descriptors before re-extraction.
    clean_database(cursor)
    conn.commit()

    # --- Feature extraction (rotation-augmented, via LightGlue helpers) ---
    fts = {}
    for image_id in tqdm(image_ids, desc="Feature Extraction"):
        path = os.path.join(IMAGE_DIR, images_info[image_id])

        # Use the configured FEATURE_TYPE instead of a hard-coded string.
        feats_dict = lightglue_keypoints(path, features=FEATURE_TYPE)
        fts[image_id] = feats_dict

        kpts = feats_dict['keypoints'].squeeze(0).cpu().numpy().astype(np.float32)
        descs = feats_dict['descriptors'].squeeze(0).cpu().numpy().astype(np.float32)

        # Keypoints were detected on rotated copies of the image; map them
        # back into original (unrotated) pixel coordinates before storing.
        kpts_rot = unrotate_kps_W(
            kpts,
            feats_dict['rotations'].squeeze(0).cpu().numpy().astype(np.float32),
            h, w,
        )
        insert_keypoints(cursor, image_id, kpts_rot, descs)

    conn.commit()

    # --- Exhaustive pairwise matching, exported in COLMAP raw-match format ---
    with open(matches_file_path, "w") as f_match:
        for i in tqdm(range(len(image_ids)), desc="Feature Matching"):
            id1 = image_ids[i]
            fname1 = images_info[id1]
            path1 = os.path.join(IMAGE_DIR, fname1)

            # j always starts at i + 1, so no self-pair check is needed
            # (the old `if j == i: continue` was dead code).
            for j in range(i + 1, len(image_ids)):
                id2 = image_ids[j]
                fname2 = images_info[id2]
                path2 = os.path.join(IMAGE_DIR, fname2)

                matches_tensor = lightglue_matching(
                    fts[id1], fts[id2], plot=False, features=FEATURE_TYPE,
                    path_to_image0=path1, path_to_image1=path2,
                )

                if matches_tensor is not None and len(matches_tensor) > 0:
                    matches_np = matches_tensor.cpu().numpy().astype(np.uint32)
                    # insert_matches(cursor, id1, id2, matches_np)

                    # One block per pair: "name1 name2", then one "idx1 idx2"
                    # line per match, blank-line terminated — the format
                    # `colmap matches_importer --match_type raw` expects.
                    f_match.write(f"{fname1} {fname2}\n")
                    np.savetxt(f_match, matches_np, fmt="%d")
                    f_match.write("\n")

                # verify_matches_visual(cursor, image_ids[i], image_ids[j], IMAGE_DIR)

    conn.commit()

    # plot_matches_from_db(cursor, image_ids[0], image_ids[1], IMAGE_DIR)

    conn.close()
    print("Database overwrite complete.")
lightglue_matcher_utilities.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ from lightglue import LightGlue
5
+ from lightglue.utils import rbd
6
+
7
def unrotate_kps_W(kps_rot, k, H, W):
    """Map keypoints detected on 90°-rotated images back to original coords.

    Vectorized variant: each keypoint carries its own rotation index, so a
    single call handles a mixed set of rotations.

    Args:
        kps_rot: (N, 2) array of (x, y) in each keypoint's rotated frame
            (numpy array or torch tensor; tensors are moved to CPU numpy).
        k: (N,) array of per-keypoint CCW quarter-turn counts (0..3).
        H, W: Height and width of the ORIGINAL (unrotated) image.

    Returns:
        (N, 2) numpy array of coordinates in the original image frame.
    """
    import numpy as np

    # Accept torch tensors transparently.
    if hasattr(kps_rot, 'cpu'):
        kps_rot = kps_rot.cpu().numpy()
    if hasattr(k, 'cpu'):
        k = k.cpu().numpy()

    # Collapse any singleton batch dimensions.
    if k.ndim > 1:
        k = k.squeeze()
    if kps_rot.ndim > 2:
        kps_rot = kps_rot.squeeze()

    x_rot = kps_rot[:, 0]
    y_rot = kps_rot[:, 1]
    x_out = np.zeros_like(x_rot)
    y_out = np.zeros_like(y_rot)

    # Apply the inverse of a k * 90° CCW rotation, per keypoint.
    sel = (k == 0)
    x_out[sel] = x_rot[sel]
    y_out[sel] = y_rot[sel]

    sel = (k == 1)  # undo 90° CCW
    x_out[sel] = (W - 1) - y_rot[sel]
    y_out[sel] = x_rot[sel]

    sel = (k == 2)  # undo 180°
    x_out[sel] = (W - 1) - x_rot[sel]
    y_out[sel] = (H - 1) - y_rot[sel]

    sel = (k == 3)  # undo 270° CCW
    x_out[sel] = y_rot[sel]
    y_out[sel] = (H - 1) - x_rot[sel]

    return np.stack([x_out, y_out], axis=-1)
37
+
38
def unrotate_kps(kps_rot, k, H, W):
    """Invert a k x 90° CCW rotation for a whole keypoint set (single k).

    Args:
        kps_rot: (N, 2) torch tensor of (x, y) in the rotated image frame.
        k: Number of CCW quarter turns applied to create the rotated image.
        H, W: Height and width of the ORIGINAL (unrotated) image.

    Returns:
        (N, 2) torch tensor in original image coordinates.

    Raises:
        ValueError: If k is not one of 0, 1, 2, 3.
    """
    import torch

    if k not in (0, 1, 2, 3):
        raise ValueError("k must be 0..3")

    x_r = kps_rot[:, 0].clone()
    y_r = kps_rot[:, 1].clone()

    if k == 0:
        x, y = x_r, y_r
    elif k == 1:  # undo 90° CCW
        x, y = (W - 1) - y_r, x_r
    elif k == 2:  # undo 180°
        x, y = (W - 1) - x_r, (H - 1) - y_r
    else:  # k == 3, undo 270° CCW
        x, y = y_r, (H - 1) - x_r

    return torch.stack([x, y], dim=-1)
56
+
57
+ # def lightglue_matching(path_to_image0, path_to_image1, plot=False, features='superpoint'):
58
+ # from lightglue import LightGlue, SuperPoint, SIFT
59
+ # from lightglue.utils import load_image, rbd
60
+ # from lightglue import viz2d
61
+ # import torch
62
+
63
+ # # --- Models on GPU ---
64
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
65
+
66
+ # if features == 'superpoint':
67
+ # extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
68
+ # if features == 'sift':
69
+ # extractor = SIFT(max_num_keypoints=2048).eval().to(device)
70
+
71
+ # matcher = LightGlue(features=features).eval().to(device)
72
+
73
+ # # --- Load images as Torch tensors (3,H,W) in [0,1] ---
74
+ # timg0 = load_image(path_to_image0).to(device)
75
+ # timg1 = load_image(path_to_image1).to(device)
76
+
77
+ # # --- Extract local features ---
78
+ # feats0 = extractor.extract(timg0) # auto-resize inside
79
+
80
+ # max_num_matches = -1
81
+ # best_k = 0
82
+ # best_feats0 = None
83
+ # best_feats1 = None
84
+ # for k in range(4):
85
+ # timg1_rotated = torch.rot90(timg1, k, dims=(1, 2))
86
+ # feats1_k = extractor.extract(timg1_rotated)
87
+ # out_k = matcher({'image0': feats0, 'image1': feats1_k})
88
+ # feats0_k, feats1_k, out_k = [rbd(x) for x in [feats0, feats1_k, out_k]] # remove batch dim
89
+ # matches_k = out_k['matches'] # (K,2) long
90
+ # num_k = len(matches_k)
91
+ # if num_k > max_num_matches:
92
+ # max_num_matches = num_k
93
+ # matches = matches_k
94
+ # best_feats0 = feats0_k
95
+ # best_feats1 = feats1_k
96
+ # best_k = k
97
+
98
+ # # --- Keypoints in matched order (Torch tensors on CPU) ---
99
+ # H1, W1 = timg1.shape[-2], timg1.shape[-1]
100
+
101
+ # kpts0 = best_feats0['keypoints'][matches[:, 0]]
102
+ # kpts1 = best_feats1['keypoints'][matches[:, 1]]
103
+ # kpts1 = unrotate_kps(kpts1, best_k, H1, W1) # (K,2) mapped to original image1 coords
104
+
105
+ # desc0 = best_feats0['descriptors'][matches[:, 0]]
106
+ # desc1 = best_feats1['descriptors'][matches[:, 1]]
107
+
108
+ # if plot:
109
+ # if len(kpts0) == 0 or len(kpts1) == 0:
110
+ # print("No matches found.")
111
+ # return None, None
112
+ # ax = viz2d.plot_images([timg0.cpu(), timg1.cpu()])
113
+ # viz2d.plot_matches(kpts0.cpu(), kpts1.cpu(), color=None, lw=0.8, axes=ax)
114
+ # #ax0 = ax[0] if isinstance(ax, (list, tuple, np.ndarray)) else ax
115
+ # #fig = ax0.figure
116
+
117
+ # #return kpts0, kpts1 #, fig, ax
118
+
119
+
120
+ # return kpts0, kpts1, desc0, desc1
121
+
122
def lightglue_keypoints(path_to_image0, features='superpoint', rotations=(0, 1, 2, 3)):
    """Extract SuperPoint/SIFT features from an image and its rotated copies.

    The image is rotated by each k in `rotations` (k x 90° CCW), features are
    extracted per rotation, and everything is concatenated. Keypoints are kept
    in their ROTATED frames; the returned 'rotations' tensor records each
    keypoint's k so callers can map back with unrotate_kps_W / unrotate_kps.

    Args:
        path_to_image0: Path to the image file.
        features: 'superpoint' or 'sift' (anything else leaves the extractor
            undefined — preserved from the original behavior).
        rotations: Iterable of quarter-turn counts to extract at. The default
            is now an immutable tuple (was a mutable list default).

    Returns:
        Dict of batched tensors: 'keypoints' (1, N, 2) in rotated frames,
        'keypoint_scores' (1, N), 'descriptors' (1, N, D), 'rotations' (1, N)
        long tensor of quarter-turn indices, and 'image_size' (1, 2) as (w, h).
    """
    from lightglue import SuperPoint, SIFT
    from lightglue.utils import load_image
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if features == 'superpoint':
        extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
    if features == 'sift':
        extractor = SIFT(max_num_keypoints=2048).eval().to(device)

    # Load image as a (3, H, W) tensor in [0, 1].
    timg = load_image(path_to_image0).to(device)
    _, h, w = timg.shape

    # Extract features from each rotated copy of the image.
    feats = {}
    for k in rotations:
        timg_rotated = torch.rot90(timg, k, dims=(1, 2))
        feats[k] = extractor.extract(timg_rotated)
        print(f"Extracted {feats[k]['keypoints'].shape[1]} keypoints for rotation {k}")

    # Merge all rotations. Keypoints intentionally stay in their rotated
    # frames; a parallel 'rotations' tensor records which frame each keypoint
    # uses. (A previous version also computed rotation-corrected coordinates
    # here, but that result was never used — correction happens downstream.)
    all_keypoints = []
    all_scores = []
    all_descriptors = []
    all_rotations = []
    for k, feat in feats.items():
        num_kpts = feat['keypoints'].shape[1]
        all_keypoints.append(feat['keypoints'])
        all_scores.append(feat['keypoint_scores'])
        all_descriptors.append(feat['descriptors'])
        all_rotations.append(
            torch.full((1, num_kpts), k, dtype=torch.long, device=device)
        )

    # Concatenate along the keypoint dimension (dim=1).
    feats_merged = {
        'keypoints': torch.cat(all_keypoints, dim=1),
        'keypoint_scores': torch.cat(all_scores, dim=1),
        'descriptors': torch.cat(all_descriptors, dim=1),
        'rotations': torch.cat(all_rotations, dim=1),
        # LightGlue normalizes keypoints with the original image size (w, h).
        'image_size': torch.tensor([w, h], device=device).unsqueeze(0),
    }
    return feats_merged
195
+
196
def lightglue_matching(feats0, feats1, plot=False, features='superpoint', path_to_image0=None, path_to_image1=None):
    """Match two pre-extracted feature sets with LightGlue.

    Args:
        feats0, feats1: Batched feature dicts as produced by lightglue_keypoints.
        plot: If True, the source images are loaded to the device (the actual
            visualization code was removed; see verify_matches_visual /
            plot_matches_from_db for plotting from the database instead).
        features: Feature type the LightGlue weights were trained for.
        path_to_image0, path_to_image1: Image paths, only used when plot=True.

    Returns:
        (K, 2) long tensor of match index pairs; column 0 indexes feats0
        keypoints, column 1 indexes feats1 keypoints.
    """
    from lightglue import LightGlue
    from lightglue.utils import load_image, rbd
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    matcher = LightGlue(features=features).eval().to(device)

    # Kept for parity with the original debug path: load images when plotting
    # was requested (visualization itself is disabled).
    if plot:
        load_image(path_to_image0).to(device)
        load_image(path_to_image1).to(device)

    # The original rotation-search loop had collapsed to `for k in range(1)`
    # (rotation handling now lives in lightglue_keypoints), so the dead
    # best-of-k bookkeeping is removed and we match directly.
    out = rbd(matcher({'image0': feats0, 'image1': feats1}))  # drop batch dim
    matches = out['matches']  # (K, 2) long

    print(f"LightGlue found {len(matches)} matches.")
    return matches