Monserza commited on
Commit
c453978
·
verified ·
1 Parent(s): 729d1a9

Upload 45 files

Browse files
Files changed (46) hide show
  1. .gitattributes +1 -0
  2. senior-demo/checkpoints/depth_pro.pt +3 -0
  3. senior-demo/demo.py +596 -0
  4. senior-demo/depth-pro.ipynb +0 -0
  5. senior-demo/example/padkaprod.jpg +3 -0
  6. senior-demo/model/yolo-seg.pt +3 -0
  7. senior-demo/nutrition_data.json +92 -0
  8. senior-demo/presetdata.json +50 -0
  9. senior-demo/src/depth_pro.egg-info/PKG-INFO +113 -0
  10. senior-demo/src/depth_pro.egg-info/SOURCES.txt +28 -0
  11. senior-demo/src/depth_pro.egg-info/dependency_links.txt +1 -0
  12. senior-demo/src/depth_pro.egg-info/entry_points.txt +2 -0
  13. senior-demo/src/depth_pro.egg-info/requires.txt +6 -0
  14. senior-demo/src/depth_pro.egg-info/top_level.txt +1 -0
  15. senior-demo/src/depth_pro/__init__.py +5 -0
  16. senior-demo/src/depth_pro/__pycache__/__init__.cpython-310.pyc +0 -0
  17. senior-demo/src/depth_pro/__pycache__/__init__.cpython-39.pyc +0 -0
  18. senior-demo/src/depth_pro/__pycache__/depth_pro.cpython-310.pyc +0 -0
  19. senior-demo/src/depth_pro/__pycache__/depth_pro.cpython-39.pyc +0 -0
  20. senior-demo/src/depth_pro/__pycache__/utils.cpython-310.pyc +0 -0
  21. senior-demo/src/depth_pro/__pycache__/utils.cpython-39.pyc +0 -0
  22. senior-demo/src/depth_pro/cli/__init__.py +4 -0
  23. senior-demo/src/depth_pro/cli/run.py +154 -0
  24. senior-demo/src/depth_pro/depth_pro.py +298 -0
  25. senior-demo/src/depth_pro/eval/boundary_metrics.py +332 -0
  26. senior-demo/src/depth_pro/eval/dis5k_sample_list.txt +200 -0
  27. senior-demo/src/depth_pro/network/__init__.py +2 -0
  28. senior-demo/src/depth_pro/network/__pycache__/__init__.cpython-310.pyc +0 -0
  29. senior-demo/src/depth_pro/network/__pycache__/__init__.cpython-39.pyc +0 -0
  30. senior-demo/src/depth_pro/network/__pycache__/decoder.cpython-310.pyc +0 -0
  31. senior-demo/src/depth_pro/network/__pycache__/decoder.cpython-39.pyc +0 -0
  32. senior-demo/src/depth_pro/network/__pycache__/encoder.cpython-310.pyc +0 -0
  33. senior-demo/src/depth_pro/network/__pycache__/encoder.cpython-39.pyc +0 -0
  34. senior-demo/src/depth_pro/network/__pycache__/fov.cpython-310.pyc +0 -0
  35. senior-demo/src/depth_pro/network/__pycache__/fov.cpython-39.pyc +0 -0
  36. senior-demo/src/depth_pro/network/__pycache__/vit.cpython-310.pyc +0 -0
  37. senior-demo/src/depth_pro/network/__pycache__/vit.cpython-39.pyc +0 -0
  38. senior-demo/src/depth_pro/network/__pycache__/vit_factory.cpython-310.pyc +0 -0
  39. senior-demo/src/depth_pro/network/__pycache__/vit_factory.cpython-39.pyc +0 -0
  40. senior-demo/src/depth_pro/network/decoder.py +206 -0
  41. senior-demo/src/depth_pro/network/encoder.py +332 -0
  42. senior-demo/src/depth_pro/network/fov.py +82 -0
  43. senior-demo/src/depth_pro/network/vit.py +123 -0
  44. senior-demo/src/depth_pro/network/vit_factory.py +124 -0
  45. senior-demo/src/depth_pro/utils.py +112 -0
  46. senior-demo/test.py +6 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ senior-demo/example/padkaprod.jpg filter=lfs diff=lfs merge=lfs -text
senior-demo/checkpoints/depth_pro.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3eb35ca68168ad3d14cb150f8947a4edf85589941661fdb2686259c80685c0ce
3
+ size 1904446787
senior-demo/demo.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo.py — Depth Pro + YOLO segmentation + Portion & Nutrition post-processing (tables version)
2
+
3
+ import sys
4
+ import json
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ from PIL import Image
9
+ import gradio as gr
10
+ from ultralytics import YOLO
11
+
12
+ # -----------------------------------------------------------
13
+ # 1. Import depth_pro (adjust path if needed)
14
+ # -----------------------------------------------------------
15
+ # If depth_pro is in a local folder "ml-depth-pro/src" next to this file:
16
+ sys.path.append("ml-depth-pro/src")
17
+ import depth_pro # noqa: E402
18
+
19
+ # -----------------------------------------------------------
20
+ # 2. Device selection
21
+ # -----------------------------------------------------------
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ print(f"[INFO] Using device: {device}")
24
+
25
+ # -----------------------------------------------------------
26
+ # 3. Load Depth Pro model
27
+ # -----------------------------------------------------------
28
+ print("[INFO] Loading Depth Pro model...")
29
+ dp_model, dp_transform = depth_pro.create_model_and_transforms()
30
+ dp_model = dp_model.to(device)
31
+ dp_model.eval()
32
+ print("[INFO] Depth Pro ready.")
33
+
34
+ # -----------------------------------------------------------
35
+ # 4. Load YOLO segmentation model
36
+ # -----------------------------------------------------------
37
+ # TODO: change this to your actual best.pt path
38
+ YOLO_MODEL_PATH = r"C:\Users\monol\Desktop\Senior_demo\ml-depth-pro\model\yolo-seg.pt"
39
+ print(f"[INFO] Loading YOLO model from: {YOLO_MODEL_PATH}")
40
+ yolo_model = YOLO(YOLO_MODEL_PATH)
41
+ print("[INFO] YOLO ready.")
42
+
43
+ # -----------------------------------------------------------
44
+ # 5. Load preset + nutrition metadata
45
+ # -----------------------------------------------------------
46
+ try:
47
+ with open("presetdata.json", "r", encoding="utf-8") as f:
48
+ PRESET_LIST = json.load(f)
49
+ PRESET_BY_CLASS = {item["class"]: item for item in PRESET_LIST}
50
+ print(f"[INFO] Loaded {len(PRESET_LIST)} preset entries.")
51
+ except Exception as e:
52
+ print("[WARN] Could not load presetdata.json:", e)
53
+ PRESET_LIST = []
54
+ PRESET_BY_CLASS = {}
55
+
56
+ try:
57
+ with open("nutrition_data.json", "r", encoding="utf-8") as f:
58
+ NUTRITION_LIST = json.load(f)
59
+ NUTR_BY_CLASS = {item["class"]: item for item in NUTRITION_LIST}
60
+ print(f"[INFO] Loaded {len(NUTRITION_LIST)} nutrition entries.")
61
+ except Exception as e:
62
+ print("[WARN] Could not load nutrition_data.json:", e)
63
+ NUTRITION_LIST = []
64
+ NUTR_BY_CLASS = {}
65
+
66
+ # -----------------------------------------------------------
67
+ # 6. Helper: make depth visualization (RGB uint8)
68
+ # -----------------------------------------------------------
69
+ def make_depth_vis(depth: np.ndarray) -> np.ndarray:
70
+ """
71
+ depth: HxW float (meters), may contain NaNs
72
+ returns: HxWx3 uint8 RGB image
73
+ """
74
+ d = depth.copy()
75
+ d[~np.isfinite(d)] = np.nan
76
+
77
+ if not np.isfinite(d).any():
78
+ return np.zeros((*depth.shape, 3), dtype=np.uint8)
79
+
80
+ d_min = np.nanpercentile(d, 1)
81
+ d_max = np.nanpercentile(d, 99)
82
+ if d_max <= d_min:
83
+ d_max = d_min + 1e-6
84
+
85
+ d_norm = (d - d_min) / (d_max - d_min)
86
+ d_norm = np.clip(d_norm, 0.0, 1.0)
87
+ d_uint8 = (d_norm * 255).astype(np.uint8)
88
+
89
+ depth_color_bgr = cv2.applyColorMap(d_uint8, cv2.COLORMAP_INFERNO)
90
+ depth_color_rgb = cv2.cvtColor(depth_color_bgr, cv2.COLOR_BGR2RGB)
91
+ return depth_color_rgb
92
+
93
+
94
+ # -----------------------------------------------------------
95
+ # 7. Portion + nutrition helper functions
96
+ # Using your equation:
97
+ # Mass_in = Mass_ref * (%area_in / %area_ref) * (Z_in / Z_ref)^2
98
+ # -----------------------------------------------------------
99
+ def estimate_portion_for_class(cls_name, area_in_pct, z_in_m, default_z_in=None):
100
+ """
101
+ Estimate portion (grams) for one class using preset reference + depth.
102
+ area_in_pct: percentage area of image (0-100)
103
+ z_in_m: median depth for that class (meters)
104
+ """
105
+ preset = PRESET_BY_CLASS.get(cls_name)
106
+ if not preset:
107
+ return None
108
+
109
+ try:
110
+ mass_ref = float(preset["portion"]) # grams
111
+ area_ref = float(preset["mask_region"]) # % area in reference
112
+ z_ref = float(preset["center_depth"]) # meters
113
+ except (KeyError, ValueError, TypeError):
114
+ return None
115
+
116
+ if area_ref <= 0 or z_ref <= 0:
117
+ return None
118
+
119
+ if z_in_m is None:
120
+ z_in_m = default_z_in
121
+ if z_in_m is None or not np.isfinite(z_in_m) or z_in_m <= 0:
122
+ return None
123
+
124
+ # Apply your scaling equation
125
+ mass_in = mass_ref * (area_in_pct / area_ref) * (z_in_m / z_ref) ** 2
126
+
127
+ return {
128
+ "class": cls_name,
129
+ "estimated_portion_g": float(mass_in),
130
+ "area_in_pct": float(area_in_pct),
131
+ "area_ref_pct": float(area_ref),
132
+ "z_in_m": float(z_in_m),
133
+ "z_ref_m": float(z_ref),
134
+ "mass_ref_g": float(mass_ref),
135
+ }
136
+
137
+
138
+ def estimate_nutrition_for_mass(class_name, mass_g):
139
+ """
140
+ Use nutrition_data.json to scale nutrition by mass.
141
+ Typically data is per 100 g.
142
+ """
143
+ nutr = NUTR_BY_CLASS.get(class_name)
144
+ if not nutr:
145
+ return None
146
+
147
+ try:
148
+ ref_mass = float(nutr["amount"])
149
+ calories = float(nutr["calories"])
150
+ protein = float(nutr["protein"])
151
+ fat = float(nutr["fat"])
152
+ carbs = float(nutr["carbohydrates"])
153
+ sodium = float(nutr["sodium"])
154
+ except (KeyError, ValueError, TypeError):
155
+ return None
156
+
157
+ if ref_mass <= 0:
158
+ return None
159
+
160
+ factor = mass_g / ref_mass
161
+
162
+ return {
163
+ "class": class_name,
164
+ "mass_g": float(mass_g),
165
+ "calories": calories * factor,
166
+ "protein": protein * factor,
167
+ "fat": fat * factor,
168
+ "carbohydrates": carbs * factor,
169
+ "sodium": sodium * factor,
170
+ }
171
+
172
+
173
+ def breakdown_ingredients(dish_class_name, dish_mass_g):
174
+ """
175
+ Split a dish (e.g., pad kaprao) into ingredients using presetdata.json,
176
+ then compute ingredient-level nutrition if available in nutrition_data.json.
177
+ """
178
+ preset = PRESET_BY_CLASS.get(dish_class_name)
179
+ if not preset or "ingredients" not in preset:
180
+ return [], []
181
+
182
+ try:
183
+ portion_ref = float(preset["portion"])
184
+ except (KeyError, ValueError, TypeError):
185
+ return [], []
186
+
187
+ if portion_ref <= 0:
188
+ return [], []
189
+
190
+ ingredient_masses = []
191
+ ingredient_nutrition = []
192
+
193
+ for ing in preset["ingredients"]:
194
+ ing_name = ing.get("name")
195
+ try:
196
+ ing_ref_mass = float(ing["amount"])
197
+ except (KeyError, ValueError, TypeError):
198
+ continue
199
+
200
+ ratio = ing_ref_mass / portion_ref
201
+ ing_mass_in = dish_mass_g * ratio
202
+
203
+ ingredient_masses.append({
204
+ "dish_class": dish_class_name,
205
+ "ingredient": ing_name,
206
+ "mass_g": float(ing_mass_in),
207
+ })
208
+
209
+ nutr = estimate_nutrition_for_mass(ing_name, ing_mass_in)
210
+ if nutr:
211
+ nutr["dish_class"] = dish_class_name
212
+ ingredient_nutrition.append(nutr)
213
+
214
+ return ingredient_masses, ingredient_nutrition
215
+
216
+
217
+ def postprocess_ai_results(rows, center_depth_m):
218
+ """
219
+ rows: list of [class_name, area_pct, median_depth_m]
220
+ center_depth_m: depth at center of image (meters)
221
+
222
+ Returns:
223
+ - portions_json: list of dicts like
224
+ {
225
+ "class": "pad kaprao",
226
+ "portion": 100,
227
+ "portion_label": "gram",
228
+ "center_depth": "0.47",
229
+ "mask_region": "5.07"
230
+ }
231
+ - dish_nutr_json: list of dish-level nutrition dicts
232
+ - ingredient_nutr_json: list of ingredient-level nutrition dicts
233
+ """
234
+ portions_json = []
235
+ dish_nutr_json = []
236
+ ingredient_nutr_json = []
237
+
238
+ for cls_name, area_pct, md in rows:
239
+ if area_pct is None:
240
+ continue
241
+
242
+ # Use median depth if available; otherwise use global center depth
243
+ if md is not None and np.isfinite(md):
244
+ z_in = md
245
+ else:
246
+ z_in = center_depth_m
247
+
248
+ portion_info = estimate_portion_for_class(
249
+ cls_name=cls_name,
250
+ area_in_pct=area_pct,
251
+ z_in_m=z_in,
252
+ default_z_in=center_depth_m,
253
+ )
254
+ if portion_info is None:
255
+ continue
256
+
257
+ # Portion JSON in your requested-ish format
258
+ portions_json.append({
259
+ "class": portion_info["class"],
260
+ "portion": round(portion_info["estimated_portion_g"], 2),
261
+ "portion_label": "gram",
262
+ "center_depth": f"{portion_info['z_in_m']:.2f}",
263
+ "mask_region": f"{portion_info['area_in_pct']:.2f}",
264
+ })
265
+
266
+ # Dish-level nutrition
267
+ dish_n = estimate_nutrition_for_mass(
268
+ cls_name,
269
+ portion_info["estimated_portion_g"]
270
+ )
271
+ if dish_n:
272
+ dish_nutr_json.append({
273
+ "class": dish_n["class"],
274
+ "mass_g": round(dish_n["mass_g"], 2),
275
+ "calories": round(dish_n["calories"], 1),
276
+ "protein": round(dish_n["protein"], 1),
277
+ "fat": round(dish_n["fat"], 1),
278
+ "carbohydrates": round(dish_n["carbohydrates"], 1),
279
+ "sodium": round(dish_n["sodium"], 1),
280
+ })
281
+
282
+ # Ingredient-level nutrition (show ALL ingredients, even if we don’t know nutrition)
283
+ ing_masses, ing_nutrition = breakdown_ingredients(
284
+ dish_class_name=cls_name,
285
+ dish_mass_g=portion_info["estimated_portion_g"],
286
+ )
287
+
288
+ # Build a quick lookup: (dish_class, ingredient_name) -> nutrition dict
289
+ nutr_lookup = {}
290
+ for n in ing_nutrition:
291
+ key = (n.get("dish_class", cls_name), n["class"])
292
+ nutr_lookup[key] = n
293
+
294
+ for mass_rec in ing_masses:
295
+ dish_cls = mass_rec["dish_class"]
296
+ ing_name = mass_rec["ingredient"]
297
+ mass_g = mass_rec["mass_g"]
298
+
299
+ key = (dish_cls, ing_name)
300
+ n = nutr_lookup.get(key)
301
+
302
+ if n is not None:
303
+ # We have nutrition data for this ingredient
304
+ ingredient_nutr_json.append({
305
+ "dish_class": dish_cls,
306
+ "ingredient": ing_name,
307
+ "mass_g": round(mass_g, 2),
308
+ "calories": round(n["calories"], 1),
309
+ "protein": round(n["protein"], 1),
310
+ "fat": round(n["fat"], 1),
311
+ "carbohydrates": round(n["carbohydrates"], 1),
312
+ "sodium": round(n["sodium"], 1),
313
+ })
314
+ else:
315
+ # No nutrition data -> still show ingredient with mass, leave nutrients blank
316
+ ingredient_nutr_json.append({
317
+ "dish_class": dish_cls,
318
+ "ingredient": ing_name,
319
+ "mass_g": round(mass_g, 2),
320
+ "calories": None,
321
+ "protein": None,
322
+ "fat": None,
323
+ "carbohydrates": None,
324
+ "sodium": None,
325
+ })
326
+
327
+
328
+ return portions_json, dish_nutr_json, ingredient_nutr_json
329
+
330
+
331
+ # -----------------------------------------------------------
332
+ # 8. Main pipeline: Depth Pro + YOLO segmentation + post-processing
333
+ # -----------------------------------------------------------
334
+ def analyze_image(pil_img: Image.Image):
335
+ # ---------- safety ----------
336
+ if pil_img is None:
337
+ blank = np.zeros((10, 10, 3), dtype=np.uint8)
338
+ return blank, blank, "Please upload an image first.", [], [], [], []
339
+
340
+ # Ensure RGB
341
+ pil_img = pil_img.convert("RGB")
342
+ rgb_np = np.array(pil_img)
343
+ H_s, W_s, _ = rgb_np.shape
344
+
345
+ # =======================================================
346
+ # A) YOLO segmentation (for mask & class percentages)
347
+ # =======================================================
348
+ seg_vis = rgb_np.copy()
349
+ class_to_mask = {} # class_name -> combined bool mask H_s x W_s
350
+
351
+ # YOLO expects BGR typically; convert
352
+ bgr_np = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2BGR)
353
+
354
+ try:
355
+ results = yolo_model.predict(
356
+ source=bgr_np,
357
+ save=False, # we don't save images to disk
358
+ conf=0.25,
359
+ iou=0.7,
360
+ verbose=False
361
+ )
362
+ r = results[0]
363
+
364
+ # visualization (BGR -> RGB)
365
+ seg_plot_bgr = r.plot()
366
+ seg_vis = cv2.cvtColor(seg_plot_bgr, cv2.COLOR_BGR2RGB)
367
+
368
+ if r.masks is not None and len(r.masks.data) > 0:
369
+ masks = r.masks.data.cpu().numpy() # [N, H, W] in YOLO image space
370
+ boxes = r.boxes
371
+ for i in range(len(masks)):
372
+ cls_id = int(boxes.cls[i])
373
+ cls_name = yolo_model.names[cls_id]
374
+ mask_i = masks[i] > 0.5 # bool H_s x W_s
375
+ if cls_name not in class_to_mask:
376
+ class_to_mask[cls_name] = mask_i
377
+ else:
378
+ class_to_mask[cls_name] |= mask_i
379
+ else:
380
+ print("[YOLO] No masks found.")
381
+ except Exception as e:
382
+ print("[YOLO ERROR]", e)
383
+
384
+ seg_vis = seg_vis.astype(np.uint8)
385
+
386
+ # =======================================================
387
+ # B) Depth Pro (distance from camera)
388
+ # =======================================================
389
+ try:
390
+ dp_in = dp_transform(pil_img).to(device)
391
+ with torch.no_grad():
392
+ pred = dp_model.infer(dp_in, f_px=None)
393
+
394
+ depth = pred["depth"]
395
+ if isinstance(depth, torch.Tensor):
396
+ depth = depth.squeeze().cpu().numpy()
397
+ except Exception as e:
398
+ blank = np.zeros((10, 10, 3), dtype=np.uint8)
399
+ return blank, seg_vis, f"Depth estimation error: {e}", [], [], [], []
400
+
401
+ if depth is None or not np.isfinite(depth).any():
402
+ blank = np.zeros((10, 10, 3), dtype=np.uint8)
403
+ return blank, seg_vis, "Depth map invalid (NaN/empty).", [], [], [], []
404
+
405
+ H_d, W_d = depth.shape
406
+
407
+ # depth visualization (resized to original image size)
408
+ depth_vis = make_depth_vis(depth)
409
+ depth_vis_big = cv2.resize(depth_vis, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
410
+ depth_vis_big = depth_vis_big.astype(np.uint8)
411
+
412
+ # -------------------------------------------------------
413
+ # Global depth summary (center + ROI)
414
+ # -------------------------------------------------------
415
+ cx_d, cy_d = W_d // 2, H_d // 2
416
+ center_depth = float(depth[cy_d, cx_d])
417
+
418
+ roi = depth[int(H_d * 0.4):int(H_d * 0.6), int(W_d * 0.4):int(W_d * 0.6)]
419
+ roi = roi[np.isfinite(roi)]
420
+ roi_depth = float(np.median(roi)) if roi.size > 0 else float("nan")
421
+
422
+ depth_lines = [
423
+ "### Depth Estimate",
424
+ f"- Center depth: **{center_depth:.2f} m**",
425
+ ]
426
+ if np.isfinite(roi_depth):
427
+ depth_lines.append(f"- Center ROI median depth: **{roi_depth:.2f} m**")
428
+
429
+ # =======================================================
430
+ # C) Compute % area + median depth per class
431
+ # =======================================================
432
+ total_pixels = H_s * W_s
433
+ rows = [] # for segmentation stats table: [class, area%, median_depth]
434
+
435
+ for cls_name, mask in class_to_mask.items():
436
+ # percentage of image area
437
+ area_px = int(mask.sum())
438
+ area_pct = 100.0 * area_px / total_pixels if total_pixels > 0 else 0.0
439
+
440
+ # resize mask to depth resolution to sample depth correctly
441
+ mask_u8 = (mask.astype(np.uint8) * 255)
442
+ mask_depth = cv2.resize(
443
+ mask_u8, (W_d, H_d), interpolation=cv2.INTER_NEAREST
444
+ ) > 0
445
+
446
+ obj_depths = depth[mask_depth & np.isfinite(depth)]
447
+ if obj_depths.size > 0:
448
+ median_depth = float(np.median(obj_depths))
449
+ else:
450
+ median_depth = float("nan")
451
+
452
+ rows.append([
453
+ cls_name,
454
+ round(area_pct, 2),
455
+ None if not np.isfinite(median_depth) else round(median_depth, 2)
456
+ ])
457
+
458
+ # Post-processing: portions + nutrition based on rows + center_depth
459
+ portions_json, dish_nutr_json, ingredient_nutr_json = postprocess_ai_results(
460
+ rows, center_depth
461
+ )
462
+
463
+ if rows:
464
+ depth_lines.append("\n### Object distances (per class)")
465
+ for cls_name, area_pct, md in rows:
466
+ if md is None:
467
+ depth_lines.append(
468
+ f"- {cls_name}: {area_pct:.2f}% of image, depth: N/A"
469
+ )
470
+ else:
471
+ depth_lines.append(
472
+ f"- {cls_name}: {area_pct:.2f}% of image, median depth **{md:.2f} m**"
473
+ )
474
+ else:
475
+ depth_lines.append("\n_No segmentation masks detected._")
476
+
477
+ depth_text = "\n".join(depth_lines)
478
+
479
+ # -------------------------------------------------------
480
+ # Convert JSON-like results to table rows for Dataframe
481
+ # -------------------------------------------------------
482
+ # Portions table: class, portion(g), center_depth(m), mask_region(%)
483
+ portions_table_rows = [
484
+ [
485
+ p["class"],
486
+ p["portion"],
487
+ p["portion_label"],
488
+ p["center_depth"],
489
+ p["mask_region"],
490
+ ]
491
+ for p in portions_json
492
+ ]
493
+
494
+ # Dish nutrition table: class, mass_g, kcal, protein, fat, carbs, sodium
495
+ dish_table_rows = [
496
+ [
497
+ d["class"],
498
+ d["mass_g"],
499
+ d["calories"],
500
+ d["protein"],
501
+ d["fat"],
502
+ d["carbohydrates"],
503
+ d["sodium"],
504
+ ]
505
+ for d in dish_nutr_json
506
+ ]
507
+
508
+ # Ingredient nutrition table:
509
+ # dish_class, ingredient, mass_g, kcal, protein, fat, carbs, sodium
510
+ ingredient_table_rows = [
511
+ [
512
+ ing["dish_class"],
513
+ ing["ingredient"],
514
+ ing["mass_g"],
515
+ ing["calories"],
516
+ ing["protein"],
517
+ ing["fat"],
518
+ ing["carbohydrates"],
519
+ ing["sodium"],
520
+ ]
521
+ for ing in ingredient_nutr_json
522
+ ]
523
+
524
+ return (
525
+ depth_vis_big,
526
+ seg_vis,
527
+ depth_text,
528
+ rows,
529
+ portions_table_rows,
530
+ dish_table_rows,
531
+ ingredient_table_rows,
532
+ )
533
+
534
+
535
+ # -----------------------------------------------------------
536
+ # 9. Gradio UI (using tables/Dataframe instead of JSON)
537
+ # -----------------------------------------------------------
538
+ with gr.Blocks() as demo:
539
+ gr.Markdown(
540
+ "<h2 style='text-align:center;'>Depth Pro + YOLO Segmentation + Nutrition Demo</h2>"
541
+ "<p style='text-align:center;'>"
542
+ "Upload a food image → get depth map, object distance, estimated portion, and nutrition per dish & ingredient."
543
+ "</p>"
544
+ )
545
+
546
+ with gr.Row():
547
+ input_img = gr.Image(label="Upload food image", type="pil")
548
+
549
+ with gr.Row():
550
+ depth_out = gr.Image(label="Depth overlay", type="numpy")
551
+ seg_out = gr.Image(label="Segmentation result", type="numpy")
552
+
553
+ with gr.Row():
554
+ depth_info = gr.Markdown(label="Depth estimate")
555
+
556
+ seg_table = gr.Dataframe(
557
+ headers=["Class", "Area % of image", "Median depth (m)"],
558
+ datatype=["str", "number", "number"],
559
+ label="Segmentation stats"
560
+ )
561
+
562
+ portions_table = gr.Dataframe(
563
+ headers=["Class", "Portion (g)", "Unit", "Center depth (m)", "Mask region (%)"],
564
+ datatype=["str", "number", "str", "str", "str"],
565
+ label="Estimated Portions (per class)",
566
+ )
567
+
568
+ dish_nutrition_table = gr.Dataframe(
569
+ headers=["Class", "Mass (g)", "Calories", "Protein (g)", "Fat (g)", "Carbs (g)", "Sodium (mg)"],
570
+ datatype=["str", "number", "number", "number", "number", "number", "number"],
571
+ label="Dish Nutrition (per class)",
572
+ )
573
+
574
+ ingredient_nutrition_table = gr.Dataframe(
575
+ headers=["Dish", "Ingredient", "Mass (g)", "Calories", "Protein (g)", "Fat (g)", "Carbs (g)", "Sodium (mg)"],
576
+ datatype=["str", "str", "number", "number", "number", "number", "number", "number"],
577
+ label="Ingredient Nutrition (per ingredient)",
578
+ )
579
+
580
+ run_btn = gr.Button("Run analysis")
581
+
582
+ run_btn.click(
583
+ fn=analyze_image,
584
+ inputs=input_img,
585
+ outputs=[
586
+ depth_out,
587
+ seg_out,
588
+ depth_info,
589
+ seg_table,
590
+ portions_table,
591
+ dish_nutrition_table,
592
+ ingredient_nutrition_table,
593
+ ],
594
+ )
595
+
596
+ demo.launch(server_name="0.0.0.0", server_port=7860)
senior-demo/depth-pro.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
senior-demo/example/padkaprod.jpg ADDED

Git LFS Details

  • SHA256: c45011c928370f28f076a7f70963ca80f51727e2fa094fdb47d84f39f34f4336
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB
senior-demo/model/yolo-seg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:863665409b804cefc8121aceb416f558f2f0d9e652c4261d4f366310649a9389
3
+ size 54903708
senior-demo/nutrition_data.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "class": "pad kaprao",
4
+ "amount": 100,
5
+ "unit": "gram",
6
+ "calories": 231,
7
+ "protein": 32.2,
8
+ "fat": 6.4,
9
+ "carbohydrates": 8.9,
10
+ "sodium": 852.5
11
+ },
12
+ {
13
+ "class": "jasmine rice",
14
+ "amount": 100,
15
+ "unit": "gram",
16
+ "calories": 130,
17
+ "protein": 2.7,
18
+ "fat": 0.3,
19
+ "carbohydrates": 28.2,
20
+ "sodium": 1
21
+ },
22
+ {
23
+ "class": "cucumber",
24
+ "amount": 100,
25
+ "unit": "gram",
26
+ "calories": 18,
27
+ "protein": 0.6,
28
+ "fat": 0.1,
29
+ "carbohydrates": 3.6,
30
+ "sodium": 5
31
+ },
32
+ {
33
+ "class": "pork",
34
+ "amount": 100,
35
+ "unit": "gram",
36
+ "calories": 242,
37
+ "protein": 27,
38
+ "fat": 14,
39
+ "carbohydrates": 0,
40
+ "sodium": 62
41
+ },
42
+ {
43
+ "class": "chicken",
44
+ "amount": 100,
45
+ "unit": "gram",
46
+ "calories": 239,
47
+ "protein": 27,
48
+ "fat": 14,
49
+ "carbohydrates": 0,
50
+ "sodium": 62
51
+ },
52
+ {
53
+ "class": "beef",
54
+ "amount": 100,
55
+ "unit": "gram",
56
+ "calories": 250,
57
+ "protein": 26,
58
+ "fat": 15,
59
+ "carbohydrates": 0,
60
+ "sodium": 72
61
+ },
62
+ {
63
+ "class": "fried egg",
64
+ "amount": 100,
65
+ "unit": "gram",
66
+ "calories": 196,
67
+ "protein": 13,
68
+ "fat": 15,
69
+ "carbohydrates": 1.1,
70
+ "sodium": 124
71
+ },
72
+ {
73
+ "class": "basil",
74
+ "amount": 100,
75
+ "unit": "gram",
76
+ "calories": 23,
77
+ "protein": 3.2,
78
+ "fat": 0.6,
79
+ "carbohydrates": 2.7,
80
+ "sodium": 4
81
+ },
82
+ {
83
+ "class": "chili",
84
+ "amount": 100,
85
+ "unit": "gram",
86
+ "calories": 40,
87
+ "protein": 1.9,
88
+ "fat": 0.4,
89
+ "carbohydrates": 8.8,
90
+ "sodium": 7
91
+ }
92
+ ]
senior-demo/presetdata.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "class": "pad kaprao",
4
+ "portion": 200,
5
+ "unit": "gram",
6
+ "center_depth": "0.45",
7
+ "mask_region": "4.7",
8
+ "ingredients": [
9
+ {
10
+ "name": "basil",
11
+ "amount": 40,
12
+ "unit": "gram"
13
+ },
14
+ {
15
+ "name": "chili",
16
+ "amount": 10,
17
+ "unit": "gram"
18
+ },
19
+ {
20
+ "name": "pork",
21
+ "amount": 150,
22
+ "unit": "gram"
23
+ }
24
+ ]
25
+ },
26
+ {
27
+ "class": "rice",
28
+ "portion": 200,
29
+ "unit": "gram",
30
+ "center_depth": "0.45",
31
+ "mask_region": "5.03",
32
+ "ingredients": [{
33
+ "name" : "jasmine rice",
34
+ "amount": 200,
35
+ "unit": "gram"
36
+ }]
37
+ },
38
+ {
39
+ "class": "cucumber",
40
+ "portion": 80,
41
+ "unit": "gram",
42
+ "center_depth": "0.45",
43
+ "mask_region": "2.26",
44
+ "ingredients": [{
45
+ "name" : "cucumber",
46
+ "amount": 80,
47
+ "unit": "gram"
48
+ }]
49
+ }
50
+ ]
senior-demo/src/depth_pro.egg-info/PKG-INFO ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: depth_pro
3
+ Version: 0.1
4
+ Summary: Inference/Network/Model code for Apple Depth Pro monocular depth estimation.
5
+ Project-URL: Homepage, https://github.com/apple/ml-depth-pro
6
+ Project-URL: Repository, https://github.com/apple/ml-depth-pro
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: torch
10
+ Requires-Dist: torchvision
11
+ Requires-Dist: timm
12
+ Requires-Dist: numpy<2
13
+ Requires-Dist: pillow_heif
14
+ Requires-Dist: matplotlib
15
+ Dynamic: license-file
16
+
17
+ ## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
18
+
19
+ This software project accompanies the research paper:
20
+ **[Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://arxiv.org/abs/2410.02073)**,
21
+ *Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, and Vladlen Koltun*.
22
+
23
+ ![](data/depth-pro-teaser.jpg)
24
+
25
+ We present a foundation model for zero-shot metric monocular depth estimation. Our model, Depth Pro, synthesizes high-resolution depth maps with unparalleled sharpness and high-frequency details. The predictions are metric, with absolute scale, without relying on the availability of metadata such as camera intrinsics. And the model is fast, producing a 2.25-megapixel depth map in 0.3 seconds on a standard GPU. These characteristics are enabled by a number of technical contributions, including an efficient multi-scale vision transformer for dense prediction, a training protocol that combines real and synthetic datasets to achieve high metric accuracy alongside fine boundary tracing, dedicated evaluation metrics for boundary accuracy in estimated depth maps, and state-of-the-art focal length estimation from a single image.
26
+
27
+
28
+ The model in this repository is a reference implementation, which has been re-trained. Its performance is close to the model reported in the paper but does not match it exactly.
29
+
30
+ ## Getting Started
31
+
32
+ We recommend setting up a virtual environment. Using e.g. miniconda, the `depth_pro` package can be installed via:
33
+
34
+ ```bash
35
+ conda create -n depth-pro -y python=3.9
36
+ conda activate depth-pro
37
+
38
+ pip install -e .
39
+ ```
40
+
41
+ To download pretrained checkpoints follow the code snippet below:
42
+ ```bash
43
+ source get_pretrained_models.sh # Files will be downloaded to `checkpoints` directory.
44
+ ```
45
+
46
+ ### Running from commandline
47
+
48
+ We provide a helper script to directly run the model on a single image:
49
+ ```bash
50
+ # Run prediction on a single image:
51
+ depth-pro-run -i ./data/example.jpg
52
+ # Run `depth-pro-run -h` for available options.
53
+ ```
54
+
55
+ ### Running from python
56
+
57
+ ```python
58
+ from PIL import Image
59
+ import depth_pro
60
+
61
+ # Load model and preprocessing transform
62
+ model, transform = depth_pro.create_model_and_transforms()
63
+ model.eval()
64
+
65
+ # Load and preprocess an image.
66
+ image, _, f_px = depth_pro.load_rgb(image_path)
67
+ image = transform(image)
68
+
69
+ # Run inference.
70
+ prediction = model.infer(image, f_px=f_px)
71
+ depth = prediction["depth"] # Depth in [m].
72
+ focallength_px = prediction["focallength_px"] # Focal length in pixels.
73
+ ```
74
+
75
+
76
+ ### Evaluation (boundary metrics)
77
+
78
+ Our boundary metrics can be found under `eval/boundary_metrics.py` and used as follows:
79
+
80
+ ```python
81
+ # for a depth-based dataset
82
+ boundary_f1 = SI_boundary_F1(predicted_depth, target_depth)
83
+
84
+ # for a mask-based dataset (image matting / segmentation)
85
+ boundary_recall = SI_boundary_Recall(predicted_depth, target_mask)
86
+ ```
87
+
88
+
89
+ ## Citation
90
+
91
+ If you find our work useful, please cite the following paper:
92
+
93
+ ```bibtex
94
+ @inproceedings{Bochkovskii2024:arxiv,
95
+ author = {Aleksei Bochkovskii and Ama\"{e}l Delaunoy and Hugo Germain and Marcel Santos and
96
+ Yichao Zhou and Stephan R. Richter and Vladlen Koltun},
97
+ title = {Depth Pro: Sharp Monocular Metric Depth in Less Than a Second},
98
+ booktitle = {International Conference on Learning Representations},
99
+ year = {2025},
100
+ url = {https://arxiv.org/abs/2410.02073},
101
+ }
102
+ ```
103
+
104
+ ## License
105
+ This sample code is released under the [LICENSE](LICENSE) terms.
106
+
107
+ The model weights are released under the [LICENSE](LICENSE) terms.
108
+
109
+ ## Acknowledgements
110
+
111
+ Our codebase is built using multiple opensource contributions, please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details.
112
+
113
+ Please check the paper for a complete list of references and datasets used in this work.
senior-demo/src/depth_pro.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ACKNOWLEDGEMENTS.md
2
+ CODE_OF_CONDUCT.md
3
+ CONTRIBUTING.md
4
+ LICENSE
5
+ README.md
6
+ get_pretrained_models.sh
7
+ pyproject.toml
8
+ data/depth-pro-teaser.jpg
9
+ data/example.jpg
10
+ src/depth_pro/__init__.py
11
+ src/depth_pro/depth_pro.py
12
+ src/depth_pro/utils.py
13
+ src/depth_pro.egg-info/PKG-INFO
14
+ src/depth_pro.egg-info/SOURCES.txt
15
+ src/depth_pro.egg-info/dependency_links.txt
16
+ src/depth_pro.egg-info/entry_points.txt
17
+ src/depth_pro.egg-info/requires.txt
18
+ src/depth_pro.egg-info/top_level.txt
19
+ src/depth_pro/cli/__init__.py
20
+ src/depth_pro/cli/run.py
21
+ src/depth_pro/eval/boundary_metrics.py
22
+ src/depth_pro/eval/dis5k_sample_list.txt
23
+ src/depth_pro/network/__init__.py
24
+ src/depth_pro/network/decoder.py
25
+ src/depth_pro/network/encoder.py
26
+ src/depth_pro/network/fov.py
27
+ src/depth_pro/network/vit.py
28
+ src/depth_pro/network/vit_factory.py
senior-demo/src/depth_pro.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
senior-demo/src/depth_pro.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ depth-pro-run = depth_pro.cli:run_main
senior-demo/src/depth_pro.egg-info/requires.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ numpy<2
5
+ pillow_heif
6
+ matplotlib
senior-demo/src/depth_pro.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ depth_pro
senior-demo/src/depth_pro/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ """Depth Pro package."""
3
+
4
+ from .depth_pro import create_model_and_transforms # noqa
5
+ from .utils import load_rgb # noqa
senior-demo/src/depth_pro/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (300 Bytes). View file
 
senior-demo/src/depth_pro/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (298 Bytes). View file
 
senior-demo/src/depth_pro/__pycache__/depth_pro.cpython-310.pyc ADDED
Binary file (7.93 kB). View file
 
senior-demo/src/depth_pro/__pycache__/depth_pro.cpython-39.pyc ADDED
Binary file (7.78 kB). View file
 
senior-demo/src/depth_pro/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.28 kB). View file
 
senior-demo/src/depth_pro/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.26 kB). View file
 
senior-demo/src/depth_pro/cli/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ """Depth Pro CLI and tools."""
3
+
4
+ from .run import main as run_main # noqa
senior-demo/src/depth_pro/cli/run.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Sample script to run DepthPro.
3
+
4
+ Copyright (C) 2024 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+
8
+ import argparse
9
+ import logging
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import PIL.Image
14
+ import torch
15
+ from matplotlib import pyplot as plt
16
+ from tqdm import tqdm
17
+
18
+ from depth_pro import create_model_and_transforms, load_rgb
19
+
20
+ LOGGER = logging.getLogger(__name__)
21
+
22
+
23
+ def get_torch_device() -> torch.device:
24
+ """Get the Torch device."""
25
+ device = torch.device("cpu")
26
+ if torch.cuda.is_available():
27
+ device = torch.device("cuda:0")
28
+ elif torch.backends.mps.is_available():
29
+ device = torch.device("mps")
30
+ return device
31
+
32
+
33
+ def run(args):
34
+ """Run Depth Pro on a sample image."""
35
+ if args.verbose:
36
+ logging.basicConfig(level=logging.INFO)
37
+
38
+ # Load model.
39
+ model, transform = create_model_and_transforms(
40
+ device=get_torch_device(),
41
+ precision=torch.half,
42
+ )
43
+ model.eval()
44
+
45
+ image_paths = [args.image_path]
46
+ if args.image_path.is_dir():
47
+ image_paths = args.image_path.glob("**/*")
48
+ relative_path = args.image_path
49
+ else:
50
+ relative_path = args.image_path.parent
51
+
52
+ if not args.skip_display:
53
+ plt.ion()
54
+ fig = plt.figure()
55
+ ax_rgb = fig.add_subplot(121)
56
+ ax_disp = fig.add_subplot(122)
57
+
58
+ for image_path in tqdm(image_paths):
59
+ # Load image and focal length from exif info (if found.).
60
+ try:
61
+ LOGGER.info(f"Loading image {image_path} ...")
62
+ image, _, f_px = load_rgb(image_path)
63
+ except Exception as e:
64
+ LOGGER.error(str(e))
65
+ continue
66
+ # Run prediction. If `f_px` is provided, it is used to estimate the final metric depth,
67
+ # otherwise the model estimates `f_px` to compute the depth metricness.
68
+ prediction = model.infer(transform(image), f_px=f_px)
69
+
70
+ # Extract the depth and focal length.
71
+ depth = prediction["depth"].detach().cpu().numpy().squeeze()
72
+ if f_px is not None:
73
+ LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}")
74
+ elif prediction["focallength_px"] is not None:
75
+ focallength_px = prediction["focallength_px"].detach().cpu().item()
76
+ LOGGER.info(f"Estimated focal length: {focallength_px}")
77
+
78
+ inverse_depth = 1 / depth
79
+ # Visualize inverse depth instead of depth, clipped to [0.1m;250m] range for better visualization.
80
+ max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
81
+ min_invdepth_vizu = max(1 / 250, inverse_depth.min())
82
+ inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
83
+ max_invdepth_vizu - min_invdepth_vizu
84
+ )
85
+
86
+ # Save Depth as npz file.
87
+ if args.output_path is not None:
88
+ output_file = (
89
+ args.output_path
90
+ / image_path.relative_to(relative_path).parent
91
+ / image_path.stem
92
+ )
93
+ LOGGER.info(f"Saving depth map to: {str(output_file)}")
94
+ output_file.parent.mkdir(parents=True, exist_ok=True)
95
+ np.savez_compressed(output_file, depth=depth)
96
+
97
+ # Save as color-mapped "turbo" jpg image.
98
+ cmap = plt.get_cmap("turbo")
99
+ color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(
100
+ np.uint8
101
+ )
102
+ color_map_output_file = str(output_file) + ".jpg"
103
+ LOGGER.info(f"Saving color-mapped depth to: : {color_map_output_file}")
104
+ PIL.Image.fromarray(color_depth).save(
105
+ color_map_output_file, format="JPEG", quality=90
106
+ )
107
+
108
+ # Display the image and estimated depth map.
109
+ if not args.skip_display:
110
+ ax_rgb.imshow(image)
111
+ ax_disp.imshow(inverse_depth_normalized, cmap="turbo")
112
+ fig.canvas.draw()
113
+ fig.canvas.flush_events()
114
+
115
+ LOGGER.info("Done predicting depth!")
116
+ if not args.skip_display:
117
+ plt.show(block=True)
118
+
119
+
120
+ def main():
121
+ """Run DepthPro inference example."""
122
+ parser = argparse.ArgumentParser(
123
+ description="Inference scripts of DepthPro with PyTorch models."
124
+ )
125
+ parser.add_argument(
126
+ "-i",
127
+ "--image-path",
128
+ type=Path,
129
+ default="./data/example.jpg",
130
+ help="Path to input image.",
131
+ )
132
+ parser.add_argument(
133
+ "-o",
134
+ "--output-path",
135
+ type=Path,
136
+ help="Path to store output files.",
137
+ )
138
+ parser.add_argument(
139
+ "--skip-display",
140
+ action="store_true",
141
+ help="Skip matplotlib display.",
142
+ )
143
+ parser.add_argument(
144
+ "-v",
145
+ "--verbose",
146
+ action="store_true",
147
+ help="Show verbose output."
148
+ )
149
+
150
+ run(parser.parse_args())
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
senior-demo/src/depth_pro/depth_pro.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Mapping, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torchvision.transforms import (
13
+ Compose,
14
+ ConvertImageDtype,
15
+ Lambda,
16
+ Normalize,
17
+ ToTensor,
18
+ )
19
+
20
+ from .network.decoder import MultiresConvDecoder
21
+ from .network.encoder import DepthProEncoder
22
+ from .network.fov import FOVNetwork
23
+ from .network.vit_factory import VIT_CONFIG_DICT, ViTPreset, create_vit
24
+
25
+
26
+ @dataclass
27
+ class DepthProConfig:
28
+ """Configuration for DepthPro."""
29
+
30
+ patch_encoder_preset: ViTPreset
31
+ image_encoder_preset: ViTPreset
32
+ decoder_features: int
33
+
34
+ checkpoint_uri: Optional[str] = None
35
+ fov_encoder_preset: Optional[ViTPreset] = None
36
+ use_fov_head: bool = True
37
+
38
+
39
+ DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
40
+ patch_encoder_preset="dinov2l16_384",
41
+ image_encoder_preset="dinov2l16_384",
42
+ checkpoint_uri="./checkpoints/depth_pro.pt",
43
+ decoder_features=256,
44
+ use_fov_head=True,
45
+ fov_encoder_preset="dinov2l16_384",
46
+ )
47
+
48
+
49
+ def create_backbone_model(
50
+ preset: ViTPreset
51
+ ) -> Tuple[nn.Module, ViTPreset]:
52
+ """Create and load a backbone model given a config.
53
+
54
+ Args:
55
+ ----
56
+ preset: A backbone preset to load pre-defind configs.
57
+
58
+ Returns:
59
+ -------
60
+ A Torch module and the associated config.
61
+
62
+ """
63
+ if preset in VIT_CONFIG_DICT:
64
+ config = VIT_CONFIG_DICT[preset]
65
+ model = create_vit(preset=preset, use_pretrained=False)
66
+ else:
67
+ raise KeyError(f"Preset {preset} not found.")
68
+
69
+ return model, config
70
+
71
+
72
+ def create_model_and_transforms(
73
+ config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
74
+ device: torch.device = torch.device("cpu"),
75
+ precision: torch.dtype = torch.float32,
76
+ ) -> Tuple[DepthPro, Compose]:
77
+ """Create a DepthPro model and load weights from `config.checkpoint_uri`.
78
+
79
+ Args:
80
+ ----
81
+ config: The configuration for the DPT model architecture.
82
+ device: The optional Torch device to load the model onto, default runs on "cpu".
83
+ precision: The optional precision used for the model, default is FP32.
84
+
85
+ Returns:
86
+ -------
87
+ The Torch DepthPro model and associated Transform.
88
+
89
+ """
90
+ patch_encoder, patch_encoder_config = create_backbone_model(
91
+ preset=config.patch_encoder_preset
92
+ )
93
+ image_encoder, _ = create_backbone_model(
94
+ preset=config.image_encoder_preset
95
+ )
96
+
97
+ fov_encoder = None
98
+ if config.use_fov_head and config.fov_encoder_preset is not None:
99
+ fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)
100
+
101
+ dims_encoder = patch_encoder_config.encoder_feature_dims
102
+ hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
103
+ encoder = DepthProEncoder(
104
+ dims_encoder=dims_encoder,
105
+ patch_encoder=patch_encoder,
106
+ image_encoder=image_encoder,
107
+ hook_block_ids=hook_block_ids,
108
+ decoder_features=config.decoder_features,
109
+ )
110
+ decoder = MultiresConvDecoder(
111
+ dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
112
+ dim_decoder=config.decoder_features,
113
+ )
114
+ model = DepthPro(
115
+ encoder=encoder,
116
+ decoder=decoder,
117
+ last_dims=(32, 1),
118
+ use_fov_head=config.use_fov_head,
119
+ fov_encoder=fov_encoder,
120
+ ).to(device)
121
+
122
+ if precision == torch.half:
123
+ model.half()
124
+
125
+ transform = Compose(
126
+ [
127
+ ToTensor(),
128
+ Lambda(lambda x: x.to(device)),
129
+ Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
130
+ ConvertImageDtype(precision),
131
+ ]
132
+ )
133
+
134
+ if config.checkpoint_uri is not None:
135
+ state_dict = torch.load(config.checkpoint_uri, map_location="cpu")
136
+ missing_keys, unexpected_keys = model.load_state_dict(
137
+ state_dict=state_dict, strict=True
138
+ )
139
+
140
+ if len(unexpected_keys) != 0:
141
+ raise KeyError(
142
+ f"Found unexpected keys when loading monodepth: {unexpected_keys}"
143
+ )
144
+
145
+ # fc_norm is only for the classification head,
146
+ # which we would not use. We only use the encoding.
147
+ missing_keys = [key for key in missing_keys if "fc_norm" not in key]
148
+ if len(missing_keys) != 0:
149
+ raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")
150
+
151
+ return model, transform
152
+
153
+
154
+ class DepthPro(nn.Module):
155
+ """DepthPro network."""
156
+
157
+ def __init__(
158
+ self,
159
+ encoder: DepthProEncoder,
160
+ decoder: MultiresConvDecoder,
161
+ last_dims: tuple[int, int],
162
+ use_fov_head: bool = True,
163
+ fov_encoder: Optional[nn.Module] = None,
164
+ ):
165
+ """Initialize DepthPro.
166
+
167
+ Args:
168
+ ----
169
+ encoder: The DepthProEncoder backbone.
170
+ decoder: The MultiresConvDecoder decoder.
171
+ last_dims: The dimension for the last convolution layers.
172
+ use_fov_head: Whether to use the field-of-view head.
173
+ fov_encoder: A separate encoder for the field of view.
174
+
175
+ """
176
+ super().__init__()
177
+
178
+ self.encoder = encoder
179
+ self.decoder = decoder
180
+
181
+ dim_decoder = decoder.dim_decoder
182
+ self.head = nn.Sequential(
183
+ nn.Conv2d(
184
+ dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
185
+ ),
186
+ nn.ConvTranspose2d(
187
+ in_channels=dim_decoder // 2,
188
+ out_channels=dim_decoder // 2,
189
+ kernel_size=2,
190
+ stride=2,
191
+ padding=0,
192
+ bias=True,
193
+ ),
194
+ nn.Conv2d(
195
+ dim_decoder // 2,
196
+ last_dims[0],
197
+ kernel_size=3,
198
+ stride=1,
199
+ padding=1,
200
+ ),
201
+ nn.ReLU(True),
202
+ nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
203
+ nn.ReLU(),
204
+ )
205
+
206
+ # Set the final convolution layer's bias to be 0.
207
+ self.head[4].bias.data.fill_(0)
208
+
209
+ # Set the FOV estimation head.
210
+ if use_fov_head:
211
+ self.fov = FOVNetwork(num_features=dim_decoder, fov_encoder=fov_encoder)
212
+
213
+ @property
214
+ def img_size(self) -> int:
215
+ """Return the internal image size of the network."""
216
+ return self.encoder.img_size
217
+
218
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
219
+ """Decode by projection and fusion of multi-resolution encodings.
220
+
221
+ Args:
222
+ ----
223
+ x (torch.Tensor): Input image.
224
+
225
+ Returns:
226
+ -------
227
+ The canonical inverse depth map [m] and the optional estimated field of view [deg].
228
+
229
+ """
230
+ _, _, H, W = x.shape
231
+ assert H == self.img_size and W == self.img_size
232
+
233
+ encodings = self.encoder(x)
234
+ features, features_0 = self.decoder(encodings)
235
+ canonical_inverse_depth = self.head(features)
236
+
237
+ fov_deg = None
238
+ if hasattr(self, "fov"):
239
+ fov_deg = self.fov.forward(x, features_0.detach())
240
+
241
+ return canonical_inverse_depth, fov_deg
242
+
243
+ @torch.no_grad()
244
+ def infer(
245
+ self,
246
+ x: torch.Tensor,
247
+ f_px: Optional[Union[float, torch.Tensor]] = None,
248
+ interpolation_mode="bilinear",
249
+ ) -> Mapping[str, torch.Tensor]:
250
+ """Infer depth and fov for a given image.
251
+
252
+ If the image is not at network resolution, it is resized to 1536x1536 and
253
+ the estimated depth is resized to the original image resolution.
254
+ Note: if the focal length is given, the estimated value is ignored and the provided
255
+ focal length is use to generate the metric depth values.
256
+
257
+ Args:
258
+ ----
259
+ x (torch.Tensor): Input image
260
+ f_px (torch.Tensor): Optional focal length in pixels corresponding to `x`.
261
+ interpolation_mode (str): Interpolation function for downsampling/upsampling.
262
+
263
+ Returns:
264
+ -------
265
+ Tensor dictionary (torch.Tensor): depth [m], focallength [pixels].
266
+
267
+ """
268
+ if len(x.shape) == 3:
269
+ x = x.unsqueeze(0)
270
+ _, _, H, W = x.shape
271
+ resize = H != self.img_size or W != self.img_size
272
+
273
+ if resize:
274
+ x = nn.functional.interpolate(
275
+ x,
276
+ size=(self.img_size, self.img_size),
277
+ mode=interpolation_mode,
278
+ align_corners=False,
279
+ )
280
+
281
+ canonical_inverse_depth, fov_deg = self.forward(x)
282
+ if f_px is None:
283
+ f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg.to(torch.float)))
284
+
285
+ inverse_depth = canonical_inverse_depth * (W / f_px)
286
+ f_px = f_px.squeeze()
287
+
288
+ if resize:
289
+ inverse_depth = nn.functional.interpolate(
290
+ inverse_depth, size=(H, W), mode=interpolation_mode, align_corners=False
291
+ )
292
+
293
+ depth = 1.0 / torch.clamp(inverse_depth, min=1e-4, max=1e4)
294
+
295
+ return {
296
+ "depth": depth.squeeze(),
297
+ "focallength_px": f_px,
298
+ }
senior-demo/src/depth_pro/eval/boundary_metrics.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+
5
+
6
+ def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:
7
+ """Find connected components in the given row and column indices.
8
+
9
+ Args:
10
+ ----
11
+ r (np.ndarray): Row indices.
12
+ c (np.ndarray): Column indices.
13
+
14
+ Yields:
15
+ ------
16
+ List[int]: Indices of connected components.
17
+
18
+ """
19
+ indices = [0]
20
+ for i in range(1, r.size):
21
+ if r[i] == r[indices[-1]] and c[i] == c[indices[-1]] + 1:
22
+ indices.append(i)
23
+ else:
24
+ yield indices
25
+ indices = [i]
26
+ yield indices
27
+
28
+
29
+ def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
30
+ """Apply Non-Maximum Suppression (NMS) horizontally on the given ratio matrix.
31
+
32
+ Args:
33
+ ----
34
+ ratio (np.ndarray): Input ratio matrix.
35
+ threshold (float): Threshold for NMS.
36
+
37
+ Returns:
38
+ -------
39
+ np.ndarray: Binary mask after applying NMS.
40
+
41
+ """
42
+ mask = np.zeros_like(ratio, dtype=bool)
43
+ r, c = np.nonzero(ratio > threshold)
44
+ if len(r) == 0:
45
+ return mask
46
+ for ids in connected_component(r, c):
47
+ values = [ratio[r[i], c[i]] for i in ids]
48
+ mi = np.argmax(values)
49
+ mask[r[ids[mi]], c[ids[mi]]] = True
50
+ return mask
51
+
52
+
53
+ def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
54
+ """Apply Non-Maximum Suppression (NMS) vertically on the given ratio matrix.
55
+
56
+ Args:
57
+ ----
58
+ ratio (np.ndarray): Input ratio matrix.
59
+ threshold (float): Threshold for NMS.
60
+
61
+ Returns:
62
+ -------
63
+ np.ndarray: Binary mask after applying NMS.
64
+
65
+ """
66
+ return np.transpose(nms_horizontal(np.transpose(ratio), threshold))
67
+
68
+
69
+ def fgbg_depth(
70
+ d: np.ndarray, t: float
71
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
72
+ """Find foreground-background relations between neighboring pixels.
73
+
74
+ Args:
75
+ ----
76
+ d (np.ndarray): Depth matrix.
77
+ t (float): Threshold for comparison.
78
+
79
+ Returns:
80
+ -------
81
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
82
+ left, top, right, and bottom foreground-background relations.
83
+
84
+ """
85
+ right_is_big_enough = (d[..., :, 1:] / d[..., :, :-1]) > t
86
+ left_is_big_enough = (d[..., :, :-1] / d[..., :, 1:]) > t
87
+ bottom_is_big_enough = (d[..., 1:, :] / d[..., :-1, :]) > t
88
+ top_is_big_enough = (d[..., :-1, :] / d[..., 1:, :]) > t
89
+ return (
90
+ left_is_big_enough,
91
+ top_is_big_enough,
92
+ right_is_big_enough,
93
+ bottom_is_big_enough,
94
+ )
95
+
96
+
97
+ def fgbg_depth_thinned(
98
+ d: np.ndarray, t: float
99
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
100
+ """Find foreground-background relations between neighboring pixels with Non-Maximum Suppression.
101
+
102
+ Args:
103
+ ----
104
+ d (np.ndarray): Depth matrix.
105
+ t (float): Threshold for NMS.
106
+
107
+ Returns:
108
+ -------
109
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
110
+ left, top, right, and bottom foreground-background relations with NMS applied.
111
+
112
+ """
113
+ right_is_big_enough = nms_horizontal(d[..., :, 1:] / d[..., :, :-1], t)
114
+ left_is_big_enough = nms_horizontal(d[..., :, :-1] / d[..., :, 1:], t)
115
+ bottom_is_big_enough = nms_vertical(d[..., 1:, :] / d[..., :-1, :], t)
116
+ top_is_big_enough = nms_vertical(d[..., :-1, :] / d[..., 1:, :], t)
117
+ return (
118
+ left_is_big_enough,
119
+ top_is_big_enough,
120
+ right_is_big_enough,
121
+ bottom_is_big_enough,
122
+ )
123
+
124
+
125
+ def fgbg_binary_mask(
126
+ d: np.ndarray,
127
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
128
+ """Find foreground-background relations between neighboring pixels in binary masks.
129
+
130
+ Args:
131
+ ----
132
+ d (np.ndarray): Binary depth matrix.
133
+
134
+ Returns:
135
+ -------
136
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
137
+ left, top, right, and bottom foreground-background relations in binary masks.
138
+
139
+ """
140
+ assert d.dtype == bool
141
+ right_is_big_enough = d[..., :, 1:] & ~d[..., :, :-1]
142
+ left_is_big_enough = d[..., :, :-1] & ~d[..., :, 1:]
143
+ bottom_is_big_enough = d[..., 1:, :] & ~d[..., :-1, :]
144
+ top_is_big_enough = d[..., :-1, :] & ~d[..., 1:, :]
145
+ return (
146
+ left_is_big_enough,
147
+ top_is_big_enough,
148
+ right_is_big_enough,
149
+ bottom_is_big_enough,
150
+ )
151
+
152
+
153
+ def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
154
+ """Calculate edge recall for image matting.
155
+
156
+ Args:
157
+ ----
158
+ pr (np.ndarray): Predicted depth matrix.
159
+ gt (np.ndarray): Ground truth binary mask.
160
+ t (float): Threshold for NMS.
161
+
162
+ Returns:
163
+ -------
164
+ float: Edge recall value.
165
+
166
+ """
167
+ assert gt.dtype == bool
168
+ ap, bp, cp, dp = fgbg_depth_thinned(pr, t)
169
+ ag, bg, cg, dg = fgbg_binary_mask(gt)
170
+ return 0.25 * (
171
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
172
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
173
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
174
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
175
+ )
176
+
177
+
178
+ def boundary_f1(
179
+ pr: np.ndarray,
180
+ gt: np.ndarray,
181
+ t: float,
182
+ return_p: bool = False,
183
+ return_r: bool = False,
184
+ ) -> float:
185
+ """Calculate Boundary F1 score.
186
+
187
+ Args:
188
+ ----
189
+ pr (np.ndarray): Predicted depth matrix.
190
+ gt (np.ndarray): Ground truth depth matrix.
191
+ t (float): Threshold for comparison.
192
+ return_p (bool, optional): If True, return precision. Defaults to False.
193
+ return_r (bool, optional): If True, return recall. Defaults to False.
194
+
195
+ Returns:
196
+ -------
197
+ float: Boundary F1 score, or precision, or recall depending on the flags.
198
+
199
+ """
200
+ ap, bp, cp, dp = fgbg_depth(pr, t)
201
+ ag, bg, cg, dg = fgbg_depth(gt, t)
202
+
203
+ r = 0.25 * (
204
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
205
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
206
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
207
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
208
+ )
209
+ p = 0.25 * (
210
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ap), 1)
211
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bp), 1)
212
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cp), 1)
213
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dp), 1)
214
+ )
215
+ if r + p == 0:
216
+ return 0.0
217
+ if return_p:
218
+ return p
219
+ if return_r:
220
+ return r
221
+ return 2 * (r * p) / (r + p)
222
+
223
+
224
+ def get_thresholds_and_weights(
225
+ t_min: float, t_max: float, N: int
226
+ ) -> Tuple[np.ndarray, np.ndarray]:
227
+ """Generate thresholds and weights for the given range.
228
+
229
+ Args:
230
+ ----
231
+ t_min (float): Minimum threshold.
232
+ t_max (float): Maximum threshold.
233
+ N (int): Number of thresholds.
234
+
235
+ Returns:
236
+ -------
237
+ Tuple[np.ndarray, np.ndarray]: Array of thresholds and corresponding weights.
238
+
239
+ """
240
+ thresholds = np.linspace(t_min, t_max, N)
241
+ weights = thresholds / thresholds.sum()
242
+ return thresholds, weights
243
+
244
+
245
+ def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
246
+ """Inverts a depth map with numerical stability.
247
+
248
+ Args:
249
+ ----
250
+ depth (np.ndarray): Depth map to be inverted.
251
+ eps (float): Minimum value to avoid division by zero (default is 1e-6).
252
+
253
+ Returns:
254
+ -------
255
+ np.ndarray: Inverted depth map.
256
+
257
+ """
258
+ inverse_depth = 1.0 / depth.clip(min=eps)
259
+ return inverse_depth
260
+
261
+
262
+ def SI_boundary_F1(
263
+ predicted_depth: np.ndarray,
264
+ target_depth: np.ndarray,
265
+ t_min: float = 1.05,
266
+ t_max: float = 1.25,
267
+ N: int = 10,
268
+ ) -> float:
269
+ """Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.
270
+
271
+ Args:
272
+ ----
273
+ predicted_depth (np.ndarray): Predicted depth matrix.
274
+ target_depth (np.ndarray): Ground truth depth matrix.
275
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
276
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
277
+ N (int, optional): Number of thresholds. Defaults to 10.
278
+
279
+ Returns:
280
+ -------
281
+ float: Scale-Invariant Boundary F1 Score.
282
+
283
+ """
284
+ assert predicted_depth.ndim == target_depth.ndim == 2
285
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
286
+ f1_scores = np.array(
287
+ [
288
+ boundary_f1(invert_depth(predicted_depth), invert_depth(target_depth), t)
289
+ for t in thresholds
290
+ ]
291
+ )
292
+ return np.sum(f1_scores * weights)
293
+
294
+
295
+ def SI_boundary_Recall(
296
+ predicted_depth: np.ndarray,
297
+ target_mask: np.ndarray,
298
+ t_min: float = 1.05,
299
+ t_max: float = 1.25,
300
+ N: int = 10,
301
+ alpha_threshold: float = 0.1,
302
+ ) -> float:
303
+ """Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.
304
+
305
+ Args:
306
+ ----
307
+ predicted_depth (np.ndarray): Predicted depth matrix.
308
+ target_mask (np.ndarray): Ground truth binary mask.
309
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
310
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
311
+ N (int, optional): Number of thresholds. Defaults to 10.
312
+ alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.
313
+
314
+ Returns:
315
+ -------
316
+ float: Scale-Invariant Boundary Recall Score.
317
+
318
+ """
319
+ assert predicted_depth.ndim == target_mask.ndim == 2
320
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
321
+ thresholded_target = target_mask > alpha_threshold
322
+
323
+ recall_scores = np.array(
324
+ [
325
+ edge_recall_matting(
326
+ invert_depth(predicted_depth), thresholded_target, t=float(t)
327
+ )
328
+ for t in thresholds
329
+ ]
330
+ )
331
+ weighted_recall = np.sum(recall_scores * weights)
332
+ return weighted_recall
senior-demo/src/depth_pro/eval/dis5k_sample_list.txt ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DIS5K/DIS-TE1/im/12#Graphics#4#TrafficSign#8245751856_821be14f86_o.jpg
2
+ DIS5K/DIS-TE1/im/13#Insect#4#Butterfly#16023994688_7ff8cdccb1_o.jpg
3
+ DIS5K/DIS-TE1/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205538.jpg
4
+ DIS5K/DIS-TE1/im/14#Kitchenware#8#SweetStand#4848284981_fc90f54b50_o.jpg
5
+ DIS5K/DIS-TE1/im/17#Non-motor Vehicle#4#Cart#15012855035_d10b57014f_o.jpg
6
+ DIS5K/DIS-TE1/im/2#Aircraft#5#Kite#13104545564_5afceec9bd_o.jpg
7
+ DIS5K/DIS-TE1/im/20#Sports#10#Skateboarding#8472763540_bb2390e928_o.jpg
8
+ DIS5K/DIS-TE1/im/21#Tool#14#Sword#32473146960_dcc6b77848_o.jpg
9
+ DIS5K/DIS-TE1/im/21#Tool#15#Tapeline#9680492386_2d2020f282_o.jpg
10
+ DIS5K/DIS-TE1/im/21#Tool#4#Flag#507752845_ef852100f0_o.jpg
11
+ DIS5K/DIS-TE1/im/21#Tool#6#Key#11966089533_3becd78b44_o.jpg
12
+ DIS5K/DIS-TE1/im/21#Tool#8#Scale#31946428472_d28def471b_o.jpg
13
+ DIS5K/DIS-TE1/im/22#Weapon#4#Rifle#8472656430_3eb908b211_o.jpg
14
+ DIS5K/DIS-TE1/im/8#Electronics#3#Earphone#1177468301_641df8c267_o.jpg
15
+ DIS5K/DIS-TE1/im/8#Electronics#9#MusicPlayer#2235782872_7d47847bb4_o.jpg
16
+ DIS5K/DIS-TE2/im/11#Furniture#13#Ladder#3878434417_2ed740586e_o.jpg
17
+ DIS5K/DIS-TE2/im/13#Insect#1#Ant#27047700955_3b3a1271f8_o.jpg
18
+ DIS5K/DIS-TE2/im/13#Insect#11#Spider#5567179191_38d1f65589_o.jpg
19
+ DIS5K/DIS-TE2/im/13#Insect#8#Locust#5237933769_e6687c05e4_o.jpg
20
+ DIS5K/DIS-TE2/im/14#Kitchenware#2#DishRack#70838854_40cf689da7_o.jpg
21
+ DIS5K/DIS-TE2/im/14#Kitchenware#8#SweetStand#8467929412_fef7f4275d_o.jpg
22
+ DIS5K/DIS-TE2/im/16#Music Instrument#2#Harp#28058219806_28e05ff24a_o.jpg
23
+ DIS5K/DIS-TE2/im/17#Non-motor Vehicle#1#BabyCarriage#29794777180_2e1695a0cf_o.jpg
24
+ DIS5K/DIS-TE2/im/19#Ship#3#Sailboat#22442908623_5977e3becf_o.jpg
25
+ DIS5K/DIS-TE2/im/2#Aircraft#5#Kite#44654358051_1400e71cc4_o.jpg
26
+ DIS5K/DIS-TE2/im/21#Tool#11#Stand#IMG_20210520_205442.jpg
27
+ DIS5K/DIS-TE2/im/21#Tool#17#Tripod#9318977876_34615ec9a0_o.jpg
28
+ DIS5K/DIS-TE2/im/5#Artifact#3#Handcraft#50860882577_8482143b1b_o.jpg
29
+ DIS5K/DIS-TE2/im/8#Electronics#10#Robot#3093360210_fee54dc5c5_o.jpg
30
+ DIS5K/DIS-TE2/im/8#Electronics#6#Microphone#47411477652_6da66cbc10_o.jpg
31
+ DIS5K/DIS-TE3/im/14#Kitchenware#4#Kitchenware#2451122898_ef883175dd_o.jpg
32
+ DIS5K/DIS-TE3/im/15#Machine#4#SewingMachine#9311164128_97ba1d3947_o.jpg
33
+ DIS5K/DIS-TE3/im/16#Music Instrument#2#Harp#7670920550_59e992fd7b_o.jpg
34
+ DIS5K/DIS-TE3/im/17#Non-motor Vehicle#1#BabyCarriage#8389984877_1fddf8715c_o.jpg
35
+ DIS5K/DIS-TE3/im/17#Non-motor Vehicle#3#Carriage#5947122724_98e0fc3d1f_o.jpg
36
+ DIS5K/DIS-TE3/im/2#Aircraft#2#Balloon#2487168092_641505883f_o.jpg
37
+ DIS5K/DIS-TE3/im/2#Aircraft#4#Helicopter#8401177591_06c71c8df2_o.jpg
38
+ DIS5K/DIS-TE3/im/20#Sports#1#Archery#12520003103_faa43ea3e0_o.jpg
39
+ DIS5K/DIS-TE3/im/21#Tool#11#Stand#IMG_20210709_221507.jpg
40
+ DIS5K/DIS-TE3/im/21#Tool#2#Clip#5656649687_63d0c6696d_o.jpg
41
+ DIS5K/DIS-TE3/im/21#Tool#6#Key#12878459244_6387a140ea_o.jpg
42
+ DIS5K/DIS-TE3/im/3#Aquatic#1#Lobster#109214461_f52b4b6093_o.jpg
43
+ DIS5K/DIS-TE3/im/4#Architecture#19#Windmill#20195851863_2627117e0e_o.jpg
44
+ DIS5K/DIS-TE3/im/5#Artifact#2#Cage#5821476369_ea23927487_o.jpg
45
+ DIS5K/DIS-TE3/im/8#Electronics#7#MobileHolder#49732997896_7f53c290b5_o.jpg
46
+ DIS5K/DIS-TE4/im/13#Insect#6#Centipede#15302179708_a267850881_o.jpg
47
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#11#Tricycle#5771069105_a3aef6f665_o.jpg
48
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#2#Bicycle#4245936196_fdf812dcb7_o.jpg
49
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#9#ShoppingCart#4674052920_a5b7a2b236_o.jpg
50
+ DIS5K/DIS-TE4/im/18#Plant#1#Bonsai#3539420884_ca8973e2c0_o.jpg
51
+ DIS5K/DIS-TE4/im/2#Aircraft#6#Parachute#33590416634_9d6f2325e7_o.jpg
52
+ DIS5K/DIS-TE4/im/20#Sports#1#Archery#46924476515_0be1caa684_o.jpg
53
+ DIS5K/DIS-TE4/im/20#Sports#8#Racket#19337607166_dd1985fb59_o.jpg
54
+ DIS5K/DIS-TE4/im/21#Tool#6#Key#3193329588_839b0c74ce_o.jpg
55
+ DIS5K/DIS-TE4/im/5#Artifact#2#Cage#5821886526_0573ba2d0d_o.jpg
56
+ DIS5K/DIS-TE4/im/5#Artifact#3#Handcraft#50105138282_3c1d02c968_o.jpg
57
+ DIS5K/DIS-TE4/im/8#Electronics#1#Antenna#4305034305_874f21a701_o.jpg
58
+ DIS5K/DIS-TR/im/1#Accessories#1#Bag#15554964549_3105e51b6f_o.jpg
59
+ DIS5K/DIS-TR/im/1#Accessories#1#Bag#41104261980_098a6c4a56_o.jpg
60
+ DIS5K/DIS-TR/im/1#Accessories#2#Clothes#2284764037_871b2e8ca4_o.jpg
61
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#1824643784_70d0134156_o.jpg
62
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#3590020230_37b09a29b3_o.jpg
63
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#4809652879_4da8a69f3b_o.jpg
64
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#792204934_f9b28f99b4_o.jpg
65
+ DIS5K/DIS-TR/im/1#Accessories#5#Jewelry#13909132974_c4750c5fb7_o.jpg
66
+ DIS5K/DIS-TR/im/1#Accessories#7#Shoe#2483391615_9199ece8d6_o.jpg
67
+ DIS5K/DIS-TR/im/1#Accessories#8#Watch#4343266960_f6633b029b_o.jpg
68
+ DIS5K/DIS-TR/im/10#Frame#2#BicycleFrame#17897573_42964dd104_o.jpg
69
+ DIS5K/DIS-TR/im/10#Frame#5#Rack#15898634812_64807069ff_o.jpg
70
+ DIS5K/DIS-TR/im/10#Frame#5#Rack#23928546819_c184cb0b60_o.jpg
71
+ DIS5K/DIS-TR/im/11#Furniture#19#Shower#6189119596_77bcfe80ee_o.jpg
72
+ DIS5K/DIS-TR/im/11#Furniture#2#Bench#3263647075_9306e280b5_o.jpg
73
+ DIS5K/DIS-TR/im/11#Furniture#5#CoatHanger#12774091054_cd5ff520ef_o.jpg
74
+ DIS5K/DIS-TR/im/11#Furniture#6#DentalChair#13878156865_d0439dcb32_o.jpg
75
+ DIS5K/DIS-TR/im/11#Furniture#9#Easel#5861024714_2070cd480c_o.jpg
76
+ DIS5K/DIS-TR/im/12#Graphics#4#TrafficSign#40621867334_f3c32ec189_o.jpg
77
+ DIS5K/DIS-TR/im/13#Insect#1#Ant#3295038190_db5dd0d4f4_o.jpg
78
+ DIS5K/DIS-TR/im/13#Insect#10#Mosquito#24341339_a88a1dad4c_o.jpg
79
+ DIS5K/DIS-TR/im/13#Insect#11#Spider#27171518270_63b78069ff_o.jpg
80
+ DIS5K/DIS-TR/im/13#Insect#11#Spider#49925050281_fa727c154e_o.jpg
81
+ DIS5K/DIS-TR/im/13#Insect#2#Beatle#279616486_2f1e64f591_o.jpg
82
+ DIS5K/DIS-TR/im/13#Insect#3#Bee#43892067695_82cf3e536b_o.jpg
83
+ DIS5K/DIS-TR/im/13#Insect#6#Centipede#20874281788_3e15c90a1c_o.jpg
84
+ DIS5K/DIS-TR/im/13#Insect#7#Dragonfly#14106671120_1b824d77e4_o.jpg
85
+ DIS5K/DIS-TR/im/13#Insect#8#Locust#21637491048_676ef7c9f7_o.jpg
86
+ DIS5K/DIS-TR/im/13#Insect#9#Mantis#1381120202_9dff6987b2_o.jpg
87
+ DIS5K/DIS-TR/im/14#Kitchenware#1#Cup#12812517473_327d6474b8_o.jpg
88
+ DIS5K/DIS-TR/im/14#Kitchenware#10#WineGlass#6402491641_389275d4d1_o.jpg
89
+ DIS5K/DIS-TR/im/14#Kitchenware#3#Hydrovalve#3129932040_8c05825004_o.jpg
90
+ DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#2881934780_87d5218ebb_o.jpg
91
+ DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205527.jpg
92
+ DIS5K/DIS-TR/im/14#Kitchenware#6#Spoon#32989113501_b69eccf0df_o.jpg
93
+ DIS5K/DIS-TR/im/14#Kitchenware#8#SweetStand#2867322189_c56d1e0b87_o.jpg
94
+ DIS5K/DIS-TR/im/15#Machine#1#Gear#19217846720_f5f2807475_o.jpg
95
+ DIS5K/DIS-TR/im/15#Machine#2#Machine#1620160659_9571b7a7ab_o.jpg
96
+ DIS5K/DIS-TR/im/16#Music Instrument#2#Harp#6012801603_1a6e2c16a6_o.jpg
97
+ DIS5K/DIS-TR/im/16#Music Instrument#5#Trombone#8683292118_d223c17ccb_o.jpg
98
+ DIS5K/DIS-TR/im/16#Music Instrument#6#Trumpet#8393262740_b8c216142c_o.jpg
99
+ DIS5K/DIS-TR/im/16#Music Instrument#8#Violin#1511267391_40e4949d68_o.jpg
100
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#1#BabyCarriage#6989512997_38b3dbc88b_o.jpg
101
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#14627183228_b2d68cf501_o.jpg
102
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#2932226475_1b2403e549_o.jpg
103
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#5420155648_86459905b8_o.jpg
104
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#2#Bicycle#IMG_20210513_134904.jpg
105
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#3#Carriage#3311962551_6f211b7bd6_o.jpg
106
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#4#Cart#2609732026_baf7fff3a1_o.jpg
107
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#5#Handcart#5821282211_201cefeaf2_o.jpg
108
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#7#Mower#5779003232_3bb3ae531a_o.jpg
109
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#10051622843_ace07e32b8_o.jpg
110
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#8075259294_f23e243849_o.jpg
111
+ DIS5K/DIS-TR/im/18#Plant#2#Tree#44800999741_e377e16dbb_o.jpg
112
+ DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#2631761913_3ac67d0223_o.jpg
113
+ DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#37707911566_e908a261b6_o.jpg
114
+ DIS5K/DIS-TR/im/2#Aircraft#3#HangGlider#2557220131_b8506920c5_o.jpg
115
+ DIS5K/DIS-TR/im/2#Aircraft#4#Helicopter#6215659280_5dbd9b4546_o.jpg
116
+ DIS5K/DIS-TR/im/2#Aircraft#6#Parachute#20185790493_e56fcaf8c6_o.jpg
117
+ DIS5K/DIS-TR/im/20#Sports#1#Archery#3871269982_ae4c59a7eb_o.jpg
118
+ DIS5K/DIS-TR/im/20#Sports#9#RockClimbing#9662433268_51299bc50e_o.jpg
119
+ DIS5K/DIS-TR/im/21#Tool#14#Sword#26258479365_2950d7fa37_o.jpg
120
+ DIS5K/DIS-TR/im/21#Tool#15#Tapeline#15505703447_e0fdeaa5a6_o.jpg
121
+ DIS5K/DIS-TR/im/21#Tool#4#Flag#26678602024_9b665742de_o.jpg
122
+ DIS5K/DIS-TR/im/21#Tool#4#Flag#5774823110_d603ce3cc8_o.jpg
123
+ DIS5K/DIS-TR/im/21#Tool#5#Hook#6867989814_dba18d673c_o.jpg
124
+ DIS5K/DIS-TR/im/22#Weapon#4#Rifle#4451713125_cd91719189_o.jpg
125
+ DIS5K/DIS-TR/im/3#Aquatic#2#Seadragon#4910944581_913139b238_o.jpg
126
+ DIS5K/DIS-TR/im/4#Architecture#12#Scaffold#3661448960_8aff24cc4d_o.jpg
127
+ DIS5K/DIS-TR/im/4#Architecture#13#Sculpture#6385318715_9a88d4eba7_o.jpg
128
+ DIS5K/DIS-TR/im/4#Architecture#17#Well#5011603479_75cf42808a_o.jpg
129
+ DIS5K/DIS-TR/im/5#Artifact#2#Cage#4892828841_7f1bc05682_o.jpg
130
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#15404211628_9e9ff2ce2e_o.jpg
131
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#3200169865_7c84cfcccf_o.jpg
132
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#5859295071_c217e7c22f_o.jpg
133
+ DIS5K/DIS-TR/im/6#Automobile#10#SteeringWheel#17200338026_f1e2122d8e_o.jpg
134
+ DIS5K/DIS-TR/im/6#Automobile#3#Car#3780893425_1a7d275e09_o.jpg
135
+ DIS5K/DIS-TR/im/6#Automobile#5#Crane#15282506502_1b1132a7c3_o.jpg
136
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#16767791875_8e6df41752_o.jpg
137
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#3291433361_38747324c4_o.jpg
138
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#4195104238_12a754c61a_o.jpg
139
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#49645415132_61e5664ecf_o.jpg
140
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#IMG_20210521_232406.jpg
141
+ DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#3298312021_92f431e3e9_o.jpg
142
+ DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#47950134773_fbfff63f4e_o.jpg
143
+ DIS5K/DIS-TR/im/7#Electrical#11#VacuumCleaner#5448403677_6a29e21881_o.jpg
144
+ DIS5K/DIS-TR/im/7#Electrical#2#CeilingLamp#611568868_680ed5d39f_o.jpg
145
+ DIS5K/DIS-TR/im/7#Electrical#3#Fan#3391683115_990525a693_o.jpg
146
+ DIS5K/DIS-TR/im/7#Electrical#6#StreetLamp#150049122_0692266618_o.jpg
147
+ DIS5K/DIS-TR/im/7#Electrical#9#TransmissionTower#31433908671_7e7e277dfe_o.jpg
148
+ DIS5K/DIS-TR/im/8#Electronics#1#Antenna#8727884873_e0622ee5c4_o.jpg
149
+ DIS5K/DIS-TR/im/8#Electronics#2#Camcorder#4172690390_7e5f280ace_o.jpg
150
+ DIS5K/DIS-TR/im/8#Electronics#3#Earphone#413984555_f290febdf5_o.jpg
151
+ DIS5K/DIS-TR/im/8#Electronics#5#Headset#30574225373_3717ed9fa4_o.jpg
152
+ DIS5K/DIS-TR/im/8#Electronics#6#Microphone#538006482_4aae4f5bd6_o.jpg
153
+ DIS5K/DIS-TR/im/8#Electronics#9#MusicPlayer#1306012480_2ea80d2afd_o.jpg
154
+ DIS5K/DIS-TR/im/9#Entertainment#1#GymEquipment#33071754135_8f3195cbd1_o.jpg
155
+ DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#2305807849_be53d724ea_o.jpg
156
+ DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#3862040422_5bbf903204_o.jpg
157
+ DIS5K/DIS-TR/im/9#Entertainment#3#OutdoorFitnessEquipment#10814507005_3dacaa28b3_o.jpg
158
+ DIS5K/DIS-TR/im/9#Entertainment#4#FerrisWheel#81640293_4b0ee62040_o.jpg
159
+ DIS5K/DIS-TR/im/9#Entertainment#5#Swing#49867339188_08073f4b76_o.jpg
160
+ DIS5K/DIS-VD/im/1#Accessories#1#Bag#6815402415_e01c1a41e6_o.jpg
161
+ DIS5K/DIS-VD/im/1#Accessories#5#Jewelry#2744070193_1486582e8d_o.jpg
162
+ DIS5K/DIS-VD/im/10#Frame#1#BasketballHoop#IMG_20210521_232650.jpg
163
+ DIS5K/DIS-VD/im/10#Frame#5#Rack#6156611713_49ebf12b1e_o.jpg
164
+ DIS5K/DIS-VD/im/11#Furniture#11#Handrail#3276641240_1b84b5af85_o.jpg
165
+ DIS5K/DIS-VD/im/11#Furniture#13#Ladder#33423266_5391cf47e9_o.jpg
166
+ DIS5K/DIS-VD/im/11#Furniture#17#Table#3725111755_4fc101e7ab_o.jpg
167
+ DIS5K/DIS-VD/im/11#Furniture#2#Bench#35556410400_7235b58070_o.jpg
168
+ DIS5K/DIS-VD/im/11#Furniture#4#Chair#3301769985_e49de6739f_o.jpg
169
+ DIS5K/DIS-VD/im/11#Furniture#6#DentalChair#23811071619_2a95c3a688_o.jpg
170
+ DIS5K/DIS-VD/im/11#Furniture#9#Easel#8322807354_df6d56542e_o.jpg
171
+ DIS5K/DIS-VD/im/13#Insect#10#Mosquito#12391674863_0cdf430d3f_o.jpg
172
+ DIS5K/DIS-VD/im/13#Insect#7#Dragonfly#14693028899_344ea118f2_o.jpg
173
+ DIS5K/DIS-VD/im/14#Kitchenware#10#WineGlass#4450148455_8f460f541a_o.jpg
174
+ DIS5K/DIS-VD/im/14#Kitchenware#3#Hydrovalve#IMG_20210520_203410.jpg
175
+ DIS5K/DIS-VD/im/15#Machine#3#PlowHarrow#34521712846_df4babb024_o.jpg
176
+ DIS5K/DIS-VD/im/16#Music Instrument#5#Trombone#6222242743_e7189405cd_o.jpg
177
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#12#Wheel#25677578797_ea47e1d9e8_o.jpg
178
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#2#Bicycle#5153474856_21560b081b_o.jpg
179
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#7#Mower#16992510572_8a6ff27398_o.jpg
180
+ DIS5K/DIS-VD/im/19#Ship#2#Canoe#40571458163_7faf8b73d9_o.jpg
181
+ DIS5K/DIS-VD/im/2#Aircraft#1#Airplane#4270588164_66a619e834_o.jpg
182
+ DIS5K/DIS-VD/im/2#Aircraft#4#Helicopter#86789665_650b94b2ee_o.jpg
183
+ DIS5K/DIS-VD/im/20#Sports#14#Wakesurfing#5589577652_5061c168d2_o.jpg
184
+ DIS5K/DIS-VD/im/21#Tool#10#Spade#37018312543_63b21b0784_o.jpg
185
+ DIS5K/DIS-VD/im/21#Tool#14#Sword#24789047250_42df9bf422_o.jpg
186
+ DIS5K/DIS-VD/im/21#Tool#18#Umbrella#IMG_20210513_140445.jpg
187
+ DIS5K/DIS-VD/im/21#Tool#6#Key#43939732715_5a6e28b518_o.jpg
188
+ DIS5K/DIS-VD/im/22#Weapon#1#Cannon#12758066705_90b54295e7_o.jpg
189
+ DIS5K/DIS-VD/im/22#Weapon#4#Rifle#8019368790_fb6dc469a7_o.jpg
190
+ DIS5K/DIS-VD/im/3#Aquatic#5#Shrimp#2582833427_7a99e7356e_o.jpg
191
+ DIS5K/DIS-VD/im/4#Architecture#12#Scaffold#1013402687_590750354e_o.jpg
192
+ DIS5K/DIS-VD/im/4#Architecture#13#Sculpture#17176841759_272a3ed6e3_o.jpg
193
+ DIS5K/DIS-VD/im/4#Architecture#14#Stair#15079108505_0d11281624_o.jpg
194
+ DIS5K/DIS-VD/im/4#Architecture#19#Windmill#2928111082_ceb3051c04_o.jpg
195
+ DIS5K/DIS-VD/im/4#Architecture#3#Crack#3551574032_17dd106d31_o.jpg
196
+ DIS5K/DIS-VD/im/4#Architecture#5#GasStation#4564307581_c3069bdc62_o.jpg
197
+ DIS5K/DIS-VD/im/4#Architecture#8#ObservationTower#2704526950_d4f0ddc807_o.jpg
198
+ DIS5K/DIS-VD/im/5#Artifact#3#Handcraft#10873642323_1bafce3aa5_o.jpg
199
+ DIS5K/DIS-VD/im/6#Automobile#11#Tractor#8594504006_0c2c557d85_o.jpg
200
+ DIS5K/DIS-VD/im/8#Electronics#3#Earphone#8106454803_1178d867cc_o.jpg
senior-demo/src/depth_pro/network/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ """Depth Pro network blocks."""
senior-demo/src/depth_pro/network/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (211 Bytes). View file
 
senior-demo/src/depth_pro/network/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (209 Bytes). View file
 
senior-demo/src/depth_pro/network/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (5.33 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/decoder.cpython-39.pyc ADDED
Binary file (5.18 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (7.44 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/encoder.cpython-39.pyc ADDED
Binary file (7.29 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/fov.cpython-310.pyc ADDED
Binary file (2.11 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/fov.cpython-39.pyc ADDED
Binary file (2.1 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/vit.cpython-310.pyc ADDED
Binary file (2.83 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/vit.cpython-39.pyc ADDED
Binary file (2.83 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/vit_factory.cpython-310.pyc ADDED
Binary file (3 kB). View file
 
senior-demo/src/depth_pro/network/__pycache__/vit_factory.cpython-39.pyc ADDED
Binary file (2.95 kB). View file
 
senior-demo/src/depth_pro/network/decoder.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+
3
+ Dense Prediction Transformer Decoder architecture.
4
+
5
+ Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Iterable
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class MultiresConvDecoder(nn.Module):
17
+ """Decoder for multi-resolution encodings."""
18
+
19
+ def __init__(
20
+ self,
21
+ dims_encoder: Iterable[int],
22
+ dim_decoder: int,
23
+ ):
24
+ """Initialize multiresolution convolutional decoder.
25
+
26
+ Args:
27
+ ----
28
+ dims_encoder: Expected dims at each level from the encoder.
29
+ dim_decoder: Dim of decoder features.
30
+
31
+ """
32
+ super().__init__()
33
+ self.dims_encoder = list(dims_encoder)
34
+ self.dim_decoder = dim_decoder
35
+ self.dim_out = dim_decoder
36
+
37
+ num_encoders = len(self.dims_encoder)
38
+
39
+ # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
40
+ # when the dimensions mismatch. Otherwise we do not do anything, which is
41
+ # the default behavior of monodepth.
42
+ conv0 = (
43
+ nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
44
+ if self.dims_encoder[0] != dim_decoder
45
+ else nn.Identity()
46
+ )
47
+
48
+ convs = [conv0]
49
+ for i in range(1, num_encoders):
50
+ convs.append(
51
+ nn.Conv2d(
52
+ self.dims_encoder[i],
53
+ dim_decoder,
54
+ kernel_size=3,
55
+ stride=1,
56
+ padding=1,
57
+ bias=False,
58
+ )
59
+ )
60
+
61
+ self.convs = nn.ModuleList(convs)
62
+
63
+ fusions = []
64
+ for i in range(num_encoders):
65
+ fusions.append(
66
+ FeatureFusionBlock2d(
67
+ num_features=dim_decoder,
68
+ deconv=(i != 0),
69
+ batch_norm=False,
70
+ )
71
+ )
72
+ self.fusions = nn.ModuleList(fusions)
73
+
74
+ def forward(self, encodings: torch.Tensor) -> torch.Tensor:
75
+ """Decode the multi-resolution encodings."""
76
+ num_levels = len(encodings)
77
+ num_encoders = len(self.dims_encoder)
78
+
79
+ if num_levels != num_encoders:
80
+ raise ValueError(
81
+ f"Got encoder output levels={num_levels}, expected levels={num_encoders+1}."
82
+ )
83
+
84
+ # Project features of different encoder dims to the same decoder dim.
85
+ # Fuse features from the lowest resolution (num_levels-1)
86
+ # to the highest (0).
87
+ features = self.convs[-1](encodings[-1])
88
+ lowres_features = features
89
+ features = self.fusions[-1](features)
90
+ for i in range(num_levels - 2, -1, -1):
91
+ features_i = self.convs[i](encodings[i])
92
+ features = self.fusions[i](features, features_i)
93
+ return features, lowres_features
94
+
95
+
96
+ class ResidualBlock(nn.Module):
97
+ """Generic implementation of residual blocks.
98
+
99
+ This implements a generic residual block from
100
+ He et al. - Identity Mappings in Deep Residual Networks (2016),
101
+ https://arxiv.org/abs/1603.05027
102
+ which can be further customized via factory functions.
103
+ """
104
+
105
+ def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
106
+ """Initialize ResidualBlock."""
107
+ super().__init__()
108
+ self.residual = residual
109
+ self.shortcut = shortcut
110
+
111
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
112
+ """Apply residual block."""
113
+ delta_x = self.residual(x)
114
+
115
+ if self.shortcut is not None:
116
+ x = self.shortcut(x)
117
+
118
+ return x + delta_x
119
+
120
+
121
+ class FeatureFusionBlock2d(nn.Module):
122
+ """Feature fusion for DPT."""
123
+
124
+ def __init__(
125
+ self,
126
+ num_features: int,
127
+ deconv: bool = False,
128
+ batch_norm: bool = False,
129
+ ):
130
+ """Initialize feature fusion block.
131
+
132
+ Args:
133
+ ----
134
+ num_features: Input and output dimensions.
135
+ deconv: Whether to use deconv before the final output conv.
136
+ batch_norm: Whether to use batch normalization in resnet blocks.
137
+
138
+ """
139
+ super().__init__()
140
+
141
+ self.resnet1 = self._residual_block(num_features, batch_norm)
142
+ self.resnet2 = self._residual_block(num_features, batch_norm)
143
+
144
+ self.use_deconv = deconv
145
+ if deconv:
146
+ self.deconv = nn.ConvTranspose2d(
147
+ in_channels=num_features,
148
+ out_channels=num_features,
149
+ kernel_size=2,
150
+ stride=2,
151
+ padding=0,
152
+ bias=False,
153
+ )
154
+
155
+ self.out_conv = nn.Conv2d(
156
+ num_features,
157
+ num_features,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0,
161
+ bias=True,
162
+ )
163
+
164
+ self.skip_add = nn.quantized.FloatFunctional()
165
+
166
+ def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
167
+ """Process and fuse input features."""
168
+ x = x0
169
+
170
+ if x1 is not None:
171
+ res = self.resnet1(x1)
172
+ x = self.skip_add.add(x, res)
173
+
174
+ x = self.resnet2(x)
175
+
176
+ if self.use_deconv:
177
+ x = self.deconv(x)
178
+ x = self.out_conv(x)
179
+
180
+ return x
181
+
182
+ @staticmethod
183
+ def _residual_block(num_features: int, batch_norm: bool):
184
+ """Create a residual block."""
185
+
186
+ def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
187
+ layers = [
188
+ nn.ReLU(False),
189
+ nn.Conv2d(
190
+ num_features,
191
+ num_features,
192
+ kernel_size=3,
193
+ stride=1,
194
+ padding=1,
195
+ bias=not batch_norm,
196
+ ),
197
+ ]
198
+ if batch_norm:
199
+ layers.append(nn.BatchNorm2d(dim))
200
+ return layers
201
+
202
+ residual = nn.Sequential(
203
+ *_create_block(dim=num_features, batch_norm=batch_norm),
204
+ *_create_block(dim=num_features, batch_norm=batch_norm),
205
+ )
206
+ return ResidualBlock(residual)
senior-demo/src/depth_pro/network/encoder.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # DepthProEncoder combining patch and image encoders.
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import Iterable, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class DepthProEncoder(nn.Module):
15
+ """DepthPro Encoder.
16
+
17
+ An encoder aimed at creating multi-resolution encodings from Vision Transformers.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ dims_encoder: Iterable[int],
23
+ patch_encoder: nn.Module,
24
+ image_encoder: nn.Module,
25
+ hook_block_ids: Iterable[int],
26
+ decoder_features: int,
27
+ ):
28
+ """Initialize DepthProEncoder.
29
+
30
+ The framework
31
+ 1. creates an image pyramid,
32
+ 2. generates overlapping patches with a sliding window at each pyramid level,
33
+ 3. creates batched encodings via vision transformer backbones,
34
+ 4. produces multi-resolution encodings.
35
+
36
+ Args:
37
+ ----
38
+ img_size: Backbone image resolution.
39
+ dims_encoder: Dimensions of the encoder at different layers.
40
+ patch_encoder: Backbone used for patches.
41
+ image_encoder: Backbone used for global image encoder.
42
+ hook_block_ids: Hooks to obtain intermediate features for the patch encoder model.
43
+ decoder_features: Number of feature output in the decoder.
44
+
45
+ """
46
+ super().__init__()
47
+
48
+ self.dims_encoder = list(dims_encoder)
49
+ self.patch_encoder = patch_encoder
50
+ self.image_encoder = image_encoder
51
+ self.hook_block_ids = list(hook_block_ids)
52
+
53
+ patch_encoder_embed_dim = patch_encoder.embed_dim
54
+ image_encoder_embed_dim = image_encoder.embed_dim
55
+
56
+ self.out_size = int(
57
+ patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
58
+ )
59
+
60
+ def _create_project_upsample_block(
61
+ dim_in: int,
62
+ dim_out: int,
63
+ upsample_layers: int,
64
+ dim_int: Optional[int] = None,
65
+ ) -> nn.Module:
66
+ if dim_int is None:
67
+ dim_int = dim_out
68
+ # Projection.
69
+ blocks = [
70
+ nn.Conv2d(
71
+ in_channels=dim_in,
72
+ out_channels=dim_int,
73
+ kernel_size=1,
74
+ stride=1,
75
+ padding=0,
76
+ bias=False,
77
+ )
78
+ ]
79
+
80
+ # Upsampling.
81
+ blocks += [
82
+ nn.ConvTranspose2d(
83
+ in_channels=dim_int if i == 0 else dim_out,
84
+ out_channels=dim_out,
85
+ kernel_size=2,
86
+ stride=2,
87
+ padding=0,
88
+ bias=False,
89
+ )
90
+ for i in range(upsample_layers)
91
+ ]
92
+
93
+ return nn.Sequential(*blocks)
94
+
95
+ self.upsample_latent0 = _create_project_upsample_block(
96
+ dim_in=patch_encoder_embed_dim,
97
+ dim_int=self.dims_encoder[0],
98
+ dim_out=decoder_features,
99
+ upsample_layers=3,
100
+ )
101
+ self.upsample_latent1 = _create_project_upsample_block(
102
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
103
+ )
104
+
105
+ self.upsample0 = _create_project_upsample_block(
106
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
107
+ )
108
+ self.upsample1 = _create_project_upsample_block(
109
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
110
+ )
111
+ self.upsample2 = _create_project_upsample_block(
112
+ dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
113
+ )
114
+
115
+ self.upsample_lowres = nn.ConvTranspose2d(
116
+ in_channels=image_encoder_embed_dim,
117
+ out_channels=self.dims_encoder[3],
118
+ kernel_size=2,
119
+ stride=2,
120
+ padding=0,
121
+ bias=True,
122
+ )
123
+ self.fuse_lowres = nn.Conv2d(
124
+ in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
125
+ out_channels=self.dims_encoder[3],
126
+ kernel_size=1,
127
+ stride=1,
128
+ padding=0,
129
+ bias=True,
130
+ )
131
+
132
+ # Obtain intermediate outputs of the blocks.
133
+ self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
134
+ self._hook0
135
+ )
136
+ self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
137
+ self._hook1
138
+ )
139
+
140
+ def _hook0(self, model, input, output):
141
+ self.backbone_highres_hook0 = output
142
+
143
+ def _hook1(self, model, input, output):
144
+ self.backbone_highres_hook1 = output
145
+
146
+ @property
147
+ def img_size(self) -> int:
148
+ """Return the full image size of the SPN network."""
149
+ return self.patch_encoder.patch_embed.img_size[0] * 4
150
+
151
+ def _create_pyramid(
152
+ self, x: torch.Tensor
153
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
154
+ """Create a 3-level image pyramid."""
155
+ # Original resolution: 1536 by default.
156
+ x0 = x
157
+
158
+ # Middle resolution: 768 by default.
159
+ x1 = F.interpolate(
160
+ x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
161
+ )
162
+
163
+ # Low resolution: 384 by default, corresponding to the backbone resolution.
164
+ x2 = F.interpolate(
165
+ x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
166
+ )
167
+
168
+ return x0, x1, x2
169
+
170
+ def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
171
+ """Split the input into small patches with sliding window."""
172
+ patch_size = 384
173
+ patch_stride = int(patch_size * (1 - overlap_ratio))
174
+
175
+ image_size = x.shape[-1]
176
+ steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
177
+
178
+ x_patch_list = []
179
+ for j in range(steps):
180
+ j0 = j * patch_stride
181
+ j1 = j0 + patch_size
182
+
183
+ for i in range(steps):
184
+ i0 = i * patch_stride
185
+ i1 = i0 + patch_size
186
+ x_patch_list.append(x[..., j0:j1, i0:i1])
187
+
188
+ return torch.cat(x_patch_list, dim=0)
189
+
190
+ def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
191
+ """Merge the patched input into a image with sliding window."""
192
+ steps = int(math.sqrt(x.shape[0] // batch_size))
193
+
194
+ idx = 0
195
+
196
+ output_list = []
197
+ for j in range(steps):
198
+ output_row_list = []
199
+ for i in range(steps):
200
+ output = x[batch_size * idx : batch_size * (idx + 1)]
201
+
202
+ if j != 0:
203
+ output = output[..., padding:, :]
204
+ if i != 0:
205
+ output = output[..., :, padding:]
206
+ if j != steps - 1:
207
+ output = output[..., :-padding, :]
208
+ if i != steps - 1:
209
+ output = output[..., :, :-padding]
210
+
211
+ output_row_list.append(output)
212
+ idx += 1
213
+
214
+ output_row = torch.cat(output_row_list, dim=-1)
215
+ output_list.append(output_row)
216
+ output = torch.cat(output_list, dim=-2)
217
+ return output
218
+
219
+ def reshape_feature(
220
+ self, embeddings: torch.Tensor, width, height, cls_token_offset=1
221
+ ):
222
+ """Discard class token and reshape 1D feature map to a 2D grid."""
223
+ b, hw, c = embeddings.shape
224
+
225
+ # Remove class token.
226
+ if cls_token_offset > 0:
227
+ embeddings = embeddings[:, cls_token_offset:, :]
228
+
229
+ # Shape: (batch, height, width, dim) -> (batch, dim, height, width)
230
+ embeddings = embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)
231
+ return embeddings
232
+
233
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
234
+ """Encode input at multiple resolutions.
235
+
236
+ Args:
237
+ ----
238
+ x (torch.Tensor): Input image.
239
+
240
+ Returns:
241
+ -------
242
+ Multi resolution encoded features.
243
+
244
+ """
245
+ batch_size = x.shape[0]
246
+
247
+ # Step 0: create a 3-level image pyramid.
248
+ x0, x1, x2 = self._create_pyramid(x)
249
+
250
+ # Step 1: split to create batched overlapped mini-images at the backbone (BeiT/ViT/Dino)
251
+ # resolution.
252
+ # 5x5 @ 384x384 at the highest resolution (1536x1536).
253
+ x0_patches = self.split(x0, overlap_ratio=0.25)
254
+ # 3x3 @ 384x384 at the middle resolution (768x768).
255
+ x1_patches = self.split(x1, overlap_ratio=0.5)
256
+ # 1x1 # 384x384 at the lowest resolution (384x384).
257
+ x2_patches = x2
258
+
259
+ # Concatenate all the sliding window patches and form a batch of size (35=5x5+3x3+1x1).
260
+ x_pyramid_patches = torch.cat(
261
+ (x0_patches, x1_patches, x2_patches),
262
+ dim=0,
263
+ )
264
+
265
+ # Step 2: Run the backbone (BeiT) model and get the result of large batch size.
266
+ x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
267
+ x_pyramid_encodings = self.reshape_feature(
268
+ x_pyramid_encodings, self.out_size, self.out_size
269
+ )
270
+
271
+ # Step 3: merging.
272
+ # Merge highres latent encoding.
273
+ x_latent0_encodings = self.reshape_feature(
274
+ self.backbone_highres_hook0,
275
+ self.out_size,
276
+ self.out_size,
277
+ )
278
+ x_latent0_features = self.merge(
279
+ x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
280
+ )
281
+
282
+ x_latent1_encodings = self.reshape_feature(
283
+ self.backbone_highres_hook1,
284
+ self.out_size,
285
+ self.out_size,
286
+ )
287
+ x_latent1_features = self.merge(
288
+ x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
289
+ )
290
+
291
+ # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
292
+ x0_encodings, x1_encodings, x2_encodings = torch.split(
293
+ x_pyramid_encodings,
294
+ [len(x0_patches), len(x1_patches), len(x2_patches)],
295
+ dim=0,
296
+ )
297
+
298
+ # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
299
+ x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
300
+
301
+ # 48x84 feature maps by merging 3x3 @ 24x24 patches with overlaps.
302
+ x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
303
+
304
+ # 24x24 feature maps.
305
+ x2_features = x2_encodings
306
+
307
+ # Apply the image encoder model.
308
+ x_global_features = self.image_encoder(x2_patches)
309
+ x_global_features = self.reshape_feature(
310
+ x_global_features, self.out_size, self.out_size
311
+ )
312
+
313
+ # Upsample feature maps.
314
+ x_latent0_features = self.upsample_latent0(x_latent0_features)
315
+ x_latent1_features = self.upsample_latent1(x_latent1_features)
316
+
317
+ x0_features = self.upsample0(x0_features)
318
+ x1_features = self.upsample1(x1_features)
319
+ x2_features = self.upsample2(x2_features)
320
+
321
+ x_global_features = self.upsample_lowres(x_global_features)
322
+ x_global_features = self.fuse_lowres(
323
+ torch.cat((x2_features, x_global_features), dim=1)
324
+ )
325
+
326
+ return [
327
+ x_latent0_features,
328
+ x_latent1_features,
329
+ x0_features,
330
+ x1_features,
331
+ x_global_features,
332
+ ]
senior-demo/src/depth_pro/network/fov.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # Field of View network architecture.
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ class FOVNetwork(nn.Module):
12
+ """Field of View estimation network."""
13
+
14
+ def __init__(
15
+ self,
16
+ num_features: int,
17
+ fov_encoder: Optional[nn.Module] = None,
18
+ ):
19
+ """Initialize the Field of View estimation block.
20
+
21
+ Args:
22
+ ----
23
+ num_features: Number of features used.
24
+ fov_encoder: Optional encoder to bring additional network capacity.
25
+
26
+ """
27
+ super().__init__()
28
+
29
+ # Create FOV head.
30
+ fov_head0 = [
31
+ nn.Conv2d(
32
+ num_features, num_features // 2, kernel_size=3, stride=2, padding=1
33
+ ), # 128 x 24 x 24
34
+ nn.ReLU(True),
35
+ ]
36
+ fov_head = [
37
+ nn.Conv2d(
38
+ num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=1
39
+ ), # 64 x 12 x 12
40
+ nn.ReLU(True),
41
+ nn.Conv2d(
42
+ num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=1
43
+ ), # 32 x 6 x 6
44
+ nn.ReLU(True),
45
+ nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
46
+ ]
47
+ if fov_encoder is not None:
48
+ self.encoder = nn.Sequential(
49
+ fov_encoder, nn.Linear(fov_encoder.embed_dim, num_features // 2)
50
+ )
51
+ self.downsample = nn.Sequential(*fov_head0)
52
+ else:
53
+ fov_head = fov_head0 + fov_head
54
+ self.head = nn.Sequential(*fov_head)
55
+
56
+ def forward(self, x: torch.Tensor, lowres_feature: torch.Tensor) -> torch.Tensor:
57
+ """Forward the fov network.
58
+
59
+ Args:
60
+ ----
61
+ x (torch.Tensor): Input image.
62
+ lowres_feature (torch.Tensor): Low resolution feature.
63
+
64
+ Returns:
65
+ -------
66
+ The field of view tensor.
67
+
68
+ """
69
+ if hasattr(self, "encoder"):
70
+ x = F.interpolate(
71
+ x,
72
+ size=None,
73
+ scale_factor=0.25,
74
+ mode="bilinear",
75
+ align_corners=False,
76
+ )
77
+ x = self.encoder(x)[:, 1:].permute(0, 2, 1)
78
+ lowres_feature = self.downsample(lowres_feature)
79
+ x = x.reshape_as(lowres_feature) + lowres_feature
80
+ else:
81
+ x = lowres_feature
82
+ return self.head(x)
senior-demo/src/depth_pro/network/vit.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+
3
+
4
+ try:
5
+ from timm.layers import resample_abs_pos_embed
6
+ except ImportError as err:
7
+ print("ImportError: {0}".format(err))
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+
13
+ def make_vit_b16_backbone(
14
+ model,
15
+ encoder_feature_dims,
16
+ encoder_feature_layer_ids,
17
+ vit_features,
18
+ start_index=1,
19
+ use_grad_checkpointing=False,
20
+ ) -> nn.Module:
21
+ """Make a ViTb16 backbone for the DPT model."""
22
+ if use_grad_checkpointing:
23
+ model.set_grad_checkpointing()
24
+
25
+ vit_model = nn.Module()
26
+ vit_model.hooks = encoder_feature_layer_ids
27
+ vit_model.model = model
28
+ vit_model.features = encoder_feature_dims
29
+ vit_model.vit_features = vit_features
30
+ vit_model.model.start_index = start_index
31
+ vit_model.model.patch_size = vit_model.model.patch_embed.patch_size
32
+ vit_model.model.is_vit = True
33
+ vit_model.model.forward = vit_model.model.forward_features
34
+
35
+ return vit_model
36
+
37
+
38
+ def forward_features_eva_fixed(self, x):
39
+ """Encode features."""
40
+ x = self.patch_embed(x)
41
+ x, rot_pos_embed = self._pos_embed(x)
42
+ for blk in self.blocks:
43
+ if self.grad_checkpointing:
44
+ x = checkpoint(blk, x, rot_pos_embed)
45
+ else:
46
+ x = blk(x, rot_pos_embed)
47
+ x = self.norm(x)
48
+ return x
49
+
50
+
51
+ def resize_vit(model: nn.Module, img_size) -> nn.Module:
52
+ """Resample the ViT module to the given size."""
53
+ patch_size = model.patch_embed.patch_size
54
+ model.patch_embed.img_size = img_size
55
+ grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
56
+ model.patch_embed.grid_size = grid_size
57
+
58
+ pos_embed = resample_abs_pos_embed(
59
+ model.pos_embed,
60
+ grid_size, # img_size
61
+ num_prefix_tokens=(
62
+ 0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
63
+ ),
64
+ )
65
+ model.pos_embed = torch.nn.Parameter(pos_embed)
66
+
67
+ return model
68
+
69
+
70
+ def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
71
+ """Resample the ViT patch size to the given one."""
72
+ # interpolate patch embedding
73
+ if hasattr(model, "patch_embed"):
74
+ old_patch_size = model.patch_embed.patch_size
75
+
76
+ if (
77
+ new_patch_size[0] != old_patch_size[0]
78
+ or new_patch_size[1] != old_patch_size[1]
79
+ ):
80
+ patch_embed_proj = model.patch_embed.proj.weight
81
+ patch_embed_proj_bias = model.patch_embed.proj.bias
82
+ use_bias = True if patch_embed_proj_bias is not None else False
83
+ _, _, h, w = patch_embed_proj.shape
84
+
85
+ new_patch_embed_proj = torch.nn.functional.interpolate(
86
+ patch_embed_proj,
87
+ size=[new_patch_size[0], new_patch_size[1]],
88
+ mode="bicubic",
89
+ align_corners=False,
90
+ )
91
+ new_patch_embed_proj = (
92
+ new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
93
+ )
94
+
95
+ model.patch_embed.proj = nn.Conv2d(
96
+ in_channels=model.patch_embed.proj.in_channels,
97
+ out_channels=model.patch_embed.proj.out_channels,
98
+ kernel_size=new_patch_size,
99
+ stride=new_patch_size,
100
+ bias=use_bias,
101
+ )
102
+
103
+ if use_bias:
104
+ model.patch_embed.proj.bias = patch_embed_proj_bias
105
+
106
+ model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)
107
+
108
+ model.patch_size = new_patch_size
109
+ model.patch_embed.patch_size = new_patch_size
110
+ model.patch_embed.img_size = (
111
+ int(
112
+ model.patch_embed.img_size[0]
113
+ * new_patch_size[0]
114
+ / old_patch_size[0]
115
+ ),
116
+ int(
117
+ model.patch_embed.img_size[1]
118
+ * new_patch_size[1]
119
+ / old_patch_size[1]
120
+ ),
121
+ )
122
+
123
+ return model
senior-demo/src/depth_pro/network/vit_factory.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # Factory functions to build and load ViT models.
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import types
9
+ from dataclasses import dataclass
10
+ from typing import Dict, List, Literal, Optional
11
+
12
+ import timm
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from .vit import (
17
+ forward_features_eva_fixed,
18
+ make_vit_b16_backbone,
19
+ resize_patch_embed,
20
+ resize_vit,
21
+ )
22
+
23
+ LOGGER = logging.getLogger(__name__)
24
+
25
+
26
+ ViTPreset = Literal[
27
+ "dinov2l16_384",
28
+ ]
29
+
30
+
31
+ @dataclass
32
+ class ViTConfig:
33
+ """Configuration for ViT."""
34
+
35
+ in_chans: int
36
+ embed_dim: int
37
+
38
+ img_size: int = 384
39
+ patch_size: int = 16
40
+
41
+ # In case we need to rescale the backbone when loading from timm.
42
+ timm_preset: Optional[str] = None
43
+ timm_img_size: int = 384
44
+ timm_patch_size: int = 16
45
+
46
+ # The following 2 parameters are only used by DPT. See dpt_factory.py.
47
+ encoder_feature_layer_ids: List[int] = None
48
+ """The layers in the Beit/ViT used to constructs encoder features for DPT."""
49
+ encoder_feature_dims: List[int] = None
50
+ """The dimension of features of encoder layers from Beit/ViT features for DPT."""
51
+
52
+
53
+ VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
54
+ "dinov2l16_384": ViTConfig(
55
+ in_chans=3,
56
+ embed_dim=1024,
57
+ encoder_feature_layer_ids=[5, 11, 17, 23],
58
+ encoder_feature_dims=[256, 512, 1024, 1024],
59
+ img_size=384,
60
+ patch_size=16,
61
+ timm_preset="vit_large_patch14_dinov2",
62
+ timm_img_size=518,
63
+ timm_patch_size=14,
64
+ ),
65
+ }
66
+
67
+
68
+ def create_vit(
69
+ preset: ViTPreset,
70
+ use_pretrained: bool = False,
71
+ checkpoint_uri: str | None = None,
72
+ use_grad_checkpointing: bool = False,
73
+ ) -> nn.Module:
74
+ """Create and load a VIT backbone module.
75
+
76
+ Args:
77
+ ----
78
+ preset: The VIT preset to load the pre-defined config.
79
+ use_pretrained: Load pretrained weights if True, default is False.
80
+ checkpoint_uri: Checkpoint to load the wights from.
81
+ use_grad_checkpointing: Use grandient checkpointing.
82
+
83
+ Returns:
84
+ -------
85
+ A Torch ViT backbone module.
86
+
87
+ """
88
+ config = VIT_CONFIG_DICT[preset]
89
+
90
+ img_size = (config.img_size, config.img_size)
91
+ patch_size = (config.patch_size, config.patch_size)
92
+
93
+ if "eva02" in preset:
94
+ model = timm.create_model(config.timm_preset, pretrained=use_pretrained)
95
+ model.forward_features = types.MethodType(forward_features_eva_fixed, model)
96
+ else:
97
+ model = timm.create_model(
98
+ config.timm_preset, pretrained=use_pretrained, dynamic_img_size=True
99
+ )
100
+ model = make_vit_b16_backbone(
101
+ model,
102
+ encoder_feature_dims=config.encoder_feature_dims,
103
+ encoder_feature_layer_ids=config.encoder_feature_layer_ids,
104
+ vit_features=config.embed_dim,
105
+ use_grad_checkpointing=use_grad_checkpointing,
106
+ )
107
+ if config.patch_size != config.timm_patch_size:
108
+ model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
109
+ if config.img_size != config.timm_img_size:
110
+ model.model = resize_vit(model.model, img_size=img_size)
111
+
112
+ if checkpoint_uri is not None:
113
+ state_dict = torch.load(checkpoint_uri, map_location="cpu")
114
+ missing_keys, unexpected_keys = model.load_state_dict(
115
+ state_dict=state_dict, strict=False
116
+ )
117
+
118
+ if len(unexpected_keys) != 0:
119
+ raise KeyError(f"Found unexpected keys when loading vit: {unexpected_keys}")
120
+ if len(missing_keys) != 0:
121
+ raise KeyError(f"Keys are missing when loading vit: {missing_keys}")
122
+
123
+ LOGGER.info(model)
124
+ return model.model
senior-demo/src/depth_pro/utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Tuple, Union
6
+
7
+ import numpy as np
8
+ import pillow_heif
9
+ from PIL import ExifTags, Image, TiffTags
10
+ from pillow_heif import register_heif_opener
11
+
12
+ register_heif_opener()
13
+ LOGGER = logging.getLogger(__name__)
14
+
15
+
16
+ def extract_exif(img_pil: Image) -> Dict[str, Any]:
17
+ """Return exif information as a dictionary.
18
+
19
+ Args:
20
+ ----
21
+ img_pil: A Pillow image.
22
+
23
+ Returns:
24
+ -------
25
+ A dictionary with extracted EXIF information.
26
+
27
+ """
28
+ # Get full exif description from get_ifd(0x8769):
29
+ # cf https://pillow.readthedocs.io/en/stable/releasenotes/8.2.0.html#image-getexif-exif-and-gps-ifd
30
+ img_exif = img_pil.getexif().get_ifd(0x8769)
31
+ exif_dict = {ExifTags.TAGS[k]: v for k, v in img_exif.items() if k in ExifTags.TAGS}
32
+
33
+ tiff_tags = img_pil.getexif()
34
+ tiff_dict = {
35
+ TiffTags.TAGS_V2[k].name: v
36
+ for k, v in tiff_tags.items()
37
+ if k in TiffTags.TAGS_V2
38
+ }
39
+ return {**exif_dict, **tiff_dict}
40
+
41
+
42
+ def fpx_from_f35(width: float, height: float, f_mm: float = 50) -> float:
43
+ """Convert a focal length given in mm (35mm film equivalent) to pixels."""
44
+ return f_mm * np.sqrt(width**2.0 + height**2.0) / np.sqrt(36**2 + 24**2)
45
+
46
+
47
+ def load_rgb(
48
+ path: Union[Path, str], auto_rotate: bool = True, remove_alpha: bool = True
49
+ ) -> Tuple[np.ndarray, List[bytes], float]:
50
+ """Load an RGB image.
51
+
52
+ Args:
53
+ ----
54
+ path: The url to the image to load.
55
+ auto_rotate: Rotate the image based on the EXIF data, default is True.
56
+ remove_alpha: Remove the alpha channel, default is True.
57
+
58
+ Returns:
59
+ -------
60
+ img: The image loaded as a numpy array.
61
+ icc_profile: The color profile of the image.
62
+ f_px: The optional focal length in pixels, extracting from the exif data.
63
+
64
+ """
65
+ LOGGER.debug(f"Loading image {path} ...")
66
+
67
+ path = Path(path)
68
+ if path.suffix.lower() in [".heic"]:
69
+ heif_file = pillow_heif.open_heif(path, convert_hdr_to_8bit=True)
70
+ img_pil = heif_file.to_pillow()
71
+ else:
72
+ img_pil = Image.open(path)
73
+
74
+ img_exif = extract_exif(img_pil)
75
+ icc_profile = img_pil.info.get("icc_profile", None)
76
+
77
+ # Rotate the image.
78
+ if auto_rotate:
79
+ exif_orientation = img_exif.get("Orientation", 1)
80
+ if exif_orientation == 3:
81
+ img_pil = img_pil.transpose(Image.ROTATE_180)
82
+ elif exif_orientation == 6:
83
+ img_pil = img_pil.transpose(Image.ROTATE_270)
84
+ elif exif_orientation == 8:
85
+ img_pil = img_pil.transpose(Image.ROTATE_90)
86
+ elif exif_orientation != 1:
87
+ LOGGER.warning(f"Ignoring image orientation {exif_orientation}.")
88
+
89
+ img = np.array(img_pil)
90
+ # Convert to RGB if single channel.
91
+ if img.ndim < 3 or img.shape[2] == 1:
92
+ img = np.dstack((img, img, img))
93
+
94
+ if remove_alpha:
95
+ img = img[:, :, :3]
96
+
97
+ LOGGER.debug(f"\tHxW: {img.shape[0]}x{img.shape[1]}")
98
+
99
+ # Extract the focal length from exif data.
100
+ f_35mm = img_exif.get(
101
+ "FocalLengthIn35mmFilm",
102
+ img_exif.get(
103
+ "FocalLenIn35mmFilm", img_exif.get("FocalLengthIn35mmFormat", None)
104
+ ),
105
+ )
106
+ if f_35mm is not None and f_35mm > 0:
107
+ LOGGER.debug(f"\tfocal length @ 35mm film: {f_35mm}mm")
108
+ f_px = fpx_from_f35(img.shape[1], img.shape[0], f_35mm)
109
+ else:
110
+ f_px = None
111
+
112
+ return img, icc_profile, f_px
senior-demo/test.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ print("PyTorch version:", torch.__version__)
4
+ print("CUDA available:", torch.cuda.is_available())
5
+ print("CUDA compiled version:", torch.version.cuda)
6
+ print("GPU count:", torch.cuda.device_count())