qizhangslam committed on
Commit
f9347d2
·
verified ·
1 Parent(s): ad313b8

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. outdoor_v48_4gpu_v2/code/05_02-14:21:58/mytrain.py +601 -0
  2. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/camera_head.py +175 -0
  3. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/dpt_head.py +471 -0
  4. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/head_act.py +116 -0
  5. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_head.py +102 -0
  6. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/__init__.py +0 -0
  7. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/base_track_predictor.py +195 -0
  8. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/blocks.py +237 -0
  9. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/modules.py +211 -0
  10. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/utils.py +216 -0
  11. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/utils.py +99 -0
  12. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/__init__.py +5 -0
  13. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/attention.py +129 -0
  14. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/block.py +263 -0
  15. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/drop_path.py +24 -0
  16. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/layer_scale.py +20 -0
  17. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/mlp.py +30 -0
  18. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/patch_embed.py +79 -0
  19. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/rope.py +172 -0
  20. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/swiglu_ffn.py +67 -0
  21. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/vision_transformer.py +398 -0
  22. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/models/aggregator.py +394 -0
  23. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/models/streamvggt.py +248 -0
  24. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/geometry.py +166 -0
  25. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/load_fn.py +146 -0
  26. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/pose_enc.py +130 -0
  27. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/rotation.py +138 -0
  28. outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/visual_track.py +239 -0
  29. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/camera_head.py +162 -0
  30. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/dpt_head.py +497 -0
  31. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/head_act.py +125 -0
  32. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_head.py +108 -0
  33. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/__init__.py +5 -0
  34. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/base_track_predictor.py +209 -0
  35. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/blocks.py +246 -0
  36. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/modules.py +218 -0
  37. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/utils.py +226 -0
  38. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/utils.py +108 -0
  39. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/layers/patch_embed.py +88 -0
  40. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/models/aggregator.py +332 -0
  41. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/models/vggt.py +228 -0
  42. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/augmentation.py +72 -0
  43. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/general.py +369 -0
  44. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/normalization.py +130 -0
  45. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/normalization_v37.py +130 -0
  46. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/geometry.py +166 -0
  47. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/load_fn.py +147 -0
  48. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/pose_enc.py +130 -0
  49. outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/visual_track.py +239 -0
  50. outdoor_v48_4gpu_v2/mytrain.log +0 -0
outdoor_v48_4gpu_v2/code/05_02-14:21:58/mytrain.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # training code for CUT3R
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # DUSt3R: https://github.com/naver/dust3r
6
+ # --------------------------------------------------------
7
+ import argparse
8
+ import datetime
9
+ import json
10
+ import numpy as np
11
+ import os
12
+ import sys
13
+ import time
14
+ import math
15
+ from collections import defaultdict
16
+ from pathlib import Path
17
+ from typing import Sized
18
+ from itertools import islice
19
+
20
+ import torch
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.nn.functional as F
23
+ from torch.utils.tensorboard import SummaryWriter
24
+
25
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
26
+
27
+ from dust3r.model import (
28
+ PreTrainedModel,
29
+ ARCroco3DStereo,
30
+ ARCroco3DStereoConfig,
31
+ inf,
32
+ strip_module,
33
+ ) # noqa: F401, needed when loading the model
34
+ from dust3r.datasets import get_data_loader
35
+ from dust3r.losses_noteacher import * # noqa: F401, needed when loading the model
36
+ from dust3r.inference import loss_of_one_batch # noqa
37
+ from dust3r.viz import colorize
38
+ from dust3r.utils.render import get_render_results
39
+ import dust3r.utils.path_to_croco # noqa: F401
40
+ import croco.utils.misc as misc # noqa
41
+ from croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler # noqa
42
+
43
+ import hydra
44
+ from omegaconf import OmegaConf
45
+ import logging
46
+ import pathlib
47
+ from tqdm import tqdm
48
+ import random
49
+ import builtins
50
+ import shutil
51
+
52
+ from accelerate import Accelerator
53
+ from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs
54
+ from accelerate.logging import get_logger
55
+ from datetime import timedelta
56
+ import torch.multiprocessing
57
+
58
+ from slamformer.models.slamformer import SLAMFormer # upstream typo: pi3 → slamformer
59
+
60
+
61
# Select the "file_system" strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# with many DataLoader workers — confirm against the PyTorch docs.
torch.multiprocessing.set_sharing_strategy("file_system")

# Module-level logger (from accelerate.logging.get_logger) used by every
# function in this file; created at DEBUG level.
printer = get_logger(__name__, log_level="DEBUG")
64
+
65
+
66
def setup_for_distributed(accelerator: Accelerator):
    """
    Replace ``builtins.print`` with a timestamped, rank-aware variant.

    After this call, ``print`` only emits output when at least one of the
    following holds:

    * ``accelerator.is_main_process`` is true,
    * the caller passes ``force=True`` (the keyword is consumed here and is
      not forwarded to the real ``print``),
    * ``accelerator.num_processes`` is greater than 8.

    Every emitted message is prefixed with ``[<current time>] ``.

    Args:
        accelerator: object exposing ``is_main_process`` and ``num_processes``.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        force = force or (accelerator.num_processes > 8)
        if accelerator.is_main_process or force:
            now = datetime.datetime.now().time()
            # Write the timestamp to the same stream as the message; otherwise
            # print(..., file=sys.stderr) would put the prefix on stdout and
            # the text on stderr. file=None means sys.stdout, as before.
            builtin_print("[{}] ".format(now), end="", file=kwargs.get("file"))
            builtin_print(*args, **kwargs)

    builtins.print = print
81
+
82
+
83
def save_current_code(outdir):
    """
    Snapshot the current working directory into ``<outdir>/code/<timestamp>``.

    The timestamp is formatted as ``%m_%d-%H:%M:%S``. Editor folders, assets,
    checkpoints, logs, previous outputs, caches, VCS data and common
    image/video/archive files are excluded from the copy. An already existing
    destination directory is reused rather than treated as an error.

    Args:
        outdir: root directory under which the ``code/`` snapshot is created.

    Returns:
        The path of the directory the code was copied to.
    """
    stamp = datetime.datetime.now().strftime("%m_%d-%H:%M:%S")
    dst_dir = os.path.join(outdir, "code", stamp)

    # Glob patterns (matched per directory entry) that must not be copied.
    excluded = (
        ".vscode*",
        "assets*",
        "example*",
        "checkpoints*",
        "OLD*",
        "logs*",
        "out*",
        "runs*",
        "*.png",
        "*.mp4",
        "*__pycache__*",
        "*.git*",
        "*.idea*",
        "*.zip",
        "*.jpg",
    )
    shutil.copytree(
        ".",
        dst_dir,
        ignore=shutil.ignore_patterns(*excluded),
        dirs_exist_ok=True,
    )
    return dst_dir
111
+
112
+
113
def train(args):
    """
    Run the full training job described by ``args``.

    Steps, in order: create the ``Accelerator`` (bf16 mixed precision),
    install the rank-aware ``print``, snapshot the code into the output
    directory, pick up ``checkpoint-last.pth`` automatically when no resume
    path is given, seed the RNGs per process, build the train/test data
    loaders, instantiate ``SLAMFormer`` and the criteria, optionally load
    pretrained weights, freeze ``model.encoder`` and ``model.register_token``
    (when present), build the AdamW optimizer, restore optimizer/scaler state
    through ``misc.load_model``, then loop over epochs calling
    ``train_one_epoch`` and writing checkpoints. A final checkpoint is written
    with ``save_final_model``.

    Args:
        args: configuration object (the hydra config passed by ``run``). This
            function reads, among others: ``accum_iter``, ``output_dir``,
            ``resume``, ``seed``, ``benchmark``, ``train_dataset``,
            ``test_dataset``, ``batch_size``, ``num_workers``,
            ``fixed_length``, ``train_criterion``, ``test_criterion``,
            ``gradient_checkpointing``, ``long_context``, ``pretrained``,
            ``weight_decay``, ``lr``, ``epochs``, ``start_epoch``,
            ``start_step``, ``save_freq`` and ``keep_freq``.
            ``args.resume`` is overwritten by the auto-resume logic.

    NOTE(review): the indentation of this function was reconstructed from a
    flattened listing; in particular the freezing section is assumed to run
    unconditionally (not only when pretrained weights are loaded) — confirm.
    """
    accelerator = Accelerator(
        gradient_accumulation_steps=args.accum_iter,
        mixed_precision="bf16",
        kwargs_handlers=[
            DistributedDataParallelKwargs(find_unused_parameters=True),
            InitProcessGroupKwargs(timeout=timedelta(seconds=6000)),
        ],
    )
    device = accelerator.device

    setup_for_distributed(accelerator)

    printer.info("output_dir: " + args.output_dir)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    if accelerator.is_main_process:
        dst_dir = save_current_code(outdir=args.output_dir)
        printer.info(f"Saving current code to {dst_dir}")

    # auto resume: fall back to the last checkpoint of this output dir, if any
    if not args.resume:
        last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-last.pth")
        # last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-7.pth")

        args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None

    printer.info("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))

    # fix the seed (offset by the process index so every rank differs)
    seed = args.seed + accelerator.state.process_index
    printer.info(
        f"Setting seed to {seed} for process {accelerator.state.process_index}"
    )
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = args.benchmark

    # training dataset and loader
    printer.info("Building train dataset %s", args.train_dataset)
    data_loader_train = build_dataset(
        args.train_dataset,
        args.batch_size,
        args.num_workers,
        accelerator=accelerator,
        test=False,
        fixed_length=args.fixed_length
    )
    printer.info("Building test dataset %s", args.test_dataset)
    # One loader per "+"-separated test dataset, keyed by the text before "(".
    data_loader_test = {
        dataset.split("(")[0]: build_dataset(
            dataset,
            args.batch_size,
            args.num_workers,
            accelerator=accelerator,
            test=True,
            fixed_length=True
        )
        for dataset in args.test_dataset.split("+")
    }

    # model
    printer.info("Loading model")
    model = SLAMFormer()
    teacher = None  # no teacher model is used; forwarded as-is to train_one_epoch

    # model: PreTrainedModel = eval(args.model)
    printer.info(f"All model parameters: {sum(p.numel() for p in model.parameters())}")

    # NOTE(security): the criteria are built by eval() on configuration
    # strings; only run this with trusted configuration files.
    printer.info(f">> Creating train criterion = {args.train_criterion}")
    train_criterion = eval(args.train_criterion).to(device)
    printer.info(
        f">> Creating test criterion = {args.test_criterion or args.train_criterion}"
    )
    # Fall back to the train criterion, matching the log line above. The
    # previous fallback read ``args.criterion``, which this file never defines.
    test_criterion = eval(args.test_criterion or args.train_criterion).to(device)

    model.to(device)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    if args.long_context:
        model.fixed_input_length = False

    freeze_keys = None  # only referenced by the disabled experiments below
    print('NOTE:', args.pretrained, args.resume)
    if args.pretrained and not args.resume:
        printer.info(f"Loading pretrained: {args.pretrained}")
        ckpt = torch.load(args.pretrained, map_location=device)
        # Disabled experiment: drop the first 7 characters of every key.
        #   ckpt_ = dict()
        #   for key, v in ckpt.items():
        #       ckpt_[key[7:]] = v
        #
        # Disabled experiment: duplicate 'aggregator' weights under
        # 'backend_transformer' and remember the loaded keys.
        #   freeze_keys = list(ckpt.keys())
        #
        #   ls = dict()
        #   for key, v in ckpt.items():
        #       if 'aggregator' in key:
        #           key_ = key.replace('aggregator', 'backend_transformer')
        #           ls[key_] = key
        #   for key_ in ls.keys():
        #       key = ls[key_]
        #       ckpt[key_] = ckpt[key]
        printer.info(
            model.load_state_dict(ckpt, strict=False)
        )
        del ckpt  # in case it occupies memory
    # Disabled experiment: freeze everything outside 'backend_transformer'.
    #   if freeze_keys is None:
    #       freeze_keys = []
    #
    #   for name, param in model.named_parameters():
    #       if 'backend_transformer' not in name:
    #           freeze_keys.append(name)
    #
    # Disabled experiment: load a frozen teacher model.
    #   printer.info("Loading teacher model")
    #   ckpt_teacher = torch.load(args.teacher, map_location=device)
    #   teacher.load_state_dict(ckpt_teacher, strict=True)
    #   teacher = teacher.to("cuda")
    #   for p in teacher.parameters():
    #       p.requires_grad = False
    #   teacher.eval()
    #   del ckpt_teacher

    # freeze
    # NOTE(review): the message mentions patch embedding / positional encoding,
    # but the code below freezes model.encoder and model.register_token.
    printer.info("Freezing patch embedding and positional encoding parameters...")
    frozen_params = 0
    total_params = 0

    frozen_param_names = []

    # Start from a fully trainable model while counting all parameters.
    for name, param in model.named_parameters():
        total_params += param.numel()
        param.requires_grad = True

    if hasattr(model, 'encoder'):  # and hasattr(model.aggregator, 'patch_embed'):
        for param in model.encoder.parameters():  # aggregator.patch_embed.parameters():
            if param.requires_grad:
                param.requires_grad = False

    if hasattr(model, 'register_token'):
        model.register_token.requires_grad = False

    # YIJUN: Skip the freezekeys
    # Disabled experiment: freeze the camera decoder / head.
    #   for name, param in model.named_parameters():
    #       if 'camera_decoder' in name or 'camera_head' in name:
    #           print(name)
    #           param.requires_grad = False

    for name, p in model.named_parameters():
        if not p.requires_grad:
            frozen_params += p.numel()
            frozen_param_names.append(name)

    printer.info(
        f"Frozen {frozen_params:,} parameters out of {total_params:,} total parameters. ({frozen_params / total_params:.2%})")
    printer.info(
        f"Trainable parameters: {total_params - frozen_params:,} ({(total_params - frozen_params) / total_params:.2%})")
    if frozen_param_names:
        printer.info(
            f"Example frozen parameters: {', '.join(frozen_param_names[:5])}{'...' if len(frozen_param_names) > 5 else ''}")

    # following timm: set wd as 0 for bias and norm layers
    param_groups = misc.get_parameter_groups(model, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    # print(optimizer)
    loss_scaler = NativeScaler(accelerator=accelerator)

    best_so_far = misc.load_model(
        args=args, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler
    )
    if best_so_far is None:
        best_so_far = float("inf")

    accelerator.even_batches = False
    optimizer, model, data_loader_train = accelerator.prepare(
        optimizer, model, data_loader_train
    )

    # NOTE(review): this helper is defined but never called in this function.
    def write_log_stats(epoch, train_stats, test_stats):
        if accelerator.is_main_process:
            if log_writer is not None:
                log_writer.flush()

            log_stats = dict(
                epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()}
            )
            for test_name in data_loader_test:
                if test_name not in test_stats:
                    continue
                log_stats.update(
                    {test_name + "_" + k: v for k, v in test_stats[test_name].items()}
                )

            with open(
                os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
            ) as f:
                f.write(json.dumps(log_stats) + "\n")

    # NOTE(review): unlike the helper of the same name in train_one_epoch, this
    # one passes the accelerator-prepared model without unwrapping it — confirm
    # both produce compatible "checkpoint-last.pth" files.
    def save_model(epoch, fname, best_so_far, data_iter_step):
        misc.save_model(
            accelerator=accelerator,
            args=args,
            model_without_ddp=model,
            optimizer=optimizer,
            loss_scaler=loss_scaler,
            epoch=epoch,
            step=data_iter_step,
            fname=fname,
            best_so_far=best_so_far,
        )

    log_writer = (
        SummaryWriter(log_dir=args.output_dir) if accelerator.is_main_process else None
    )

    printer.info(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    train_stats = test_stats = {}

    # The range runs one past the last epoch so the checkpoints of the final
    # epoch are written before the loop breaks.
    for epoch in range(args.start_epoch, args.epochs + 1):

        # Save immediately the last checkpoint
        if epoch > args.start_epoch:
            if (
                args.save_freq
                and np.allclose(epoch / args.save_freq, int(epoch / args.save_freq))
                or epoch == args.epochs
            ):
                save_model(epoch - 1, "last", best_so_far, args.start_step)

        # No evaluation is run here, so there is never a new best model.
        new_best = False

        if epoch > args.start_epoch:
            if args.keep_freq and epoch % args.keep_freq == 0:
                save_model(epoch - 1, str(epoch), best_so_far, args.start_step)
            if new_best:
                save_model(epoch - 1, "best", best_so_far, args.start_step)
        if epoch >= args.epochs:
            break  # exit after writing last test to disk

        # Train
        train_stats = train_one_epoch(
            model,
            teacher,
            train_criterion,
            data_loader_train,
            optimizer,
            accelerator,
            epoch,
            loss_scaler,
            log_writer=log_writer,
            args=args
        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    printer.info("Training time {}".format(total_time_str))

    save_final_model(accelerator, args, args.epochs, model, best_so_far=best_so_far)
388
+
389
+
390
def save_final_model(accelerator, args, epoch, model_without_ddp, best_so_far=None):
    """
    Write ``checkpoint-final.pth`` into ``args.output_dir``.

    Args:
        accelerator: forwarded to ``misc.save_on_master``.
        args: run configuration; stored in the checkpoint and used for
            ``output_dir``.
        epoch: epoch number stored in the checkpoint.
        model_without_ddp: either a ready-made state dict, or a module whose
            state dict is taken after moving it to the CPU.
        best_so_far: optional value stored under ``"best_so_far"``; omitted
            from the checkpoint when ``None``.
    """
    checkpoint_path = Path(args.output_dir) / "checkpoint-final.pth"

    # Accept both a plain state dict and a module.
    if isinstance(model_without_ddp, dict):
        weights = model_without_ddp
    else:
        weights = model_without_ddp.cpu().state_dict()

    to_save = {"args": args, "model": weights, "epoch": epoch}
    if best_so_far is not None:
        to_save["best_so_far"] = best_so_far

    printer.info(f">> Saving model to {checkpoint_path} ...")
    misc.save_on_master(accelerator, to_save, checkpoint_path)
406
+
407
+
408
def build_dataset(dataset, batch_size, num_workers, accelerator, test=False, fixed_length=False):
    """
    Create a data loader for ``dataset`` through ``get_data_loader``.

    Training loaders (``test`` false) shuffle and drop the last incomplete
    batch; test loaders do neither. Memory pinning is always enabled.

    Args:
        dataset: dataset specification forwarded to ``get_data_loader``.
        batch_size: batch size forwarded to the loader.
        num_workers: number of loader workers.
        accelerator: forwarded to ``get_data_loader``.
        test: whether this is a test loader.
        fixed_length: forwarded to ``get_data_loader``.

    Returns:
        The loader returned by ``get_data_loader``.
    """
    split_name = ("Train", "Test")[test]
    printer.info(f"Building {split_name} Data loader for dataset: {dataset}")

    is_training = not test
    return get_data_loader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_mem=True,
        shuffle=is_training,
        drop_last=is_training,
        accelerator=accelerator,
        fixed_length=fixed_length,
    )
422
+
423
+
424
def train_one_epoch(
    model: torch.nn.Module,
    teacher: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Sized,
    optimizer: torch.optim.Optimizer,
    accelerator: Accelerator,
    epoch: int,
    loss_scaler,
    args,
    log_writer=None,
):
    """
    Train ``model`` for one epoch and return the averaged metrics.

    For every batch: images are rescaled with ``(img + 1) / 2``, the learning
    rate is adjusted once per accumulation window, the loss is computed with
    ``loss_of_one_batch`` and back-propagated through ``loss_scaler`` (unless
    the result reports ``already_backprop``), metrics are recorded, scalars
    are written to ``log_writer`` every ``accum_iter * print_freq`` steps, and
    a ``"last"`` checkpoint is written every ``int(save_freq * len(loader))``
    steps (except on the first and last step of the epoch).

    Args:
        model: model to train (as returned by ``accelerator.prepare``).
        teacher: forwarded to ``loss_of_one_batch``; may be ``None``.
        criterion: loss module; its result is unpacked as ``(loss, details)``.
        data_loader: sized iterable of batches.
        optimizer: optimizer to step.
        accelerator: the run's ``Accelerator``.
        epoch: index of the current epoch.
        loss_scaler: callable performing backward / clipping / optimizer step.
        args: run configuration; reads ``accum_iter``, ``print_freq``,
            ``print_img_freq``, ``save_freq`` and ``amp``.
        log_writer: optional TensorBoard-style writer; ``None`` disables
            scalar logging on this process.

    Returns:
        dict mapping each metric name to its ``global_avg``.

    Note:
        The process exits with status 1 when the loss is not finite.
        NOTE(review): the indentation of this function was reconstructed from
        a flattened listing — confirm block boundaries against the original.
    """
    assert torch.backends.cuda.matmul.allow_tf32 == True

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    accum_iter = args.accum_iter

    def save_model(epoch, fname, best_so_far, data_iter_step):
        unwrapped_model = accelerator.unwrap_model(model)
        misc.save_model(
            accelerator=accelerator,
            args=args,
            model_without_ddp=unwrapped_model,
            optimizer=optimizer,
            loss_scaler=loss_scaler,
            epoch=epoch,
            step=data_iter_step,
            fname=fname,
            best_so_far=best_so_far,
        )

    if log_writer is not None:
        printer.info("log_dir: {}".format(log_writer.log_dir))

    # Propagate the epoch to the dataset / sampler when they support it.
    if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
        data_loader.dataset.set_epoch(epoch)
    if (
        hasattr(data_loader, "batch_sampler")
        and hasattr(data_loader.batch_sampler, "batch_sampler")
        and hasattr(data_loader.batch_sampler.batch_sampler, "set_epoch")
    ):
        data_loader.batch_sampler.batch_sampler.set_epoch(epoch)

    optimizer.zero_grad()

    steps_per_epoch = len(data_loader)
    # Mid-epoch checkpoint interval in steps. A falsy save_freq, or a product
    # that truncates to 0, disables mid-epoch saving instead of raising
    # ZeroDivisionError in the modulo below.
    save_every = int(args.save_freq * steps_per_epoch) if args.save_freq else 0

    data_iter = metric_logger.log_every(data_loader, args.print_freq, accelerator, header)

    for data_iter_step, batch in enumerate(data_iter):

        with accelerator.accumulate(model):
            # change the range of the image to [0, 1]
            if isinstance(batch, dict) and "img" in batch:
                batch["img"] = (batch["img"] + 1.0) / 2.0
            elif isinstance(batch, list) and all(isinstance(v, dict) and "img" in v for v in batch):
                for view in batch:
                    view["img"] = (view["img"] + 1.0) / 2.0

            epoch_f = epoch + data_iter_step / steps_per_epoch
            step = int(epoch_f * steps_per_epoch)
            # we use a per iteration (instead of per epoch) lr scheduler
            if data_iter_step % accum_iter == 0:
                misc.adjust_learning_rate(optimizer, epoch_f, args)

            result = loss_of_one_batch(
                batch,
                model,
                criterion,
                accelerator,
                teacher=teacher,
                inference=False,
                symmetrize_batch=False,
                use_amp=bool(args.amp),
            )

            loss, loss_details = result["loss"]  # criterion returns two values

            loss_value = float(loss)

            if not math.isfinite(loss_value):
                print(
                    f"Loss is {loss_value}, stopping training, loss details: {loss_details}"
                )
                sys.exit(1)
            if not result.get("already_backprop", False):
                loss_scaler(
                    loss,
                    optimizer,
                    parameters=model.parameters(),
                    update_grad=True,
                    clip_grad=1.0,
                )
                optimizer.zero_grad()

            del loss

            tb_vis_img = (data_iter_step + 1) % accum_iter == 0 and (
                (step + 1) % (args.print_img_freq)
            ) == 0
            if not tb_vis_img:
                del batch
            else:
                torch.cuda.empty_cache()

            lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(epoch=epoch_f)
            metric_logger.update(lr=lr)
            metric_logger.update(step=step)
            metric_logger.update(loss=loss_value, **loss_details)

            if (data_iter_step + 1) % accum_iter == 0 and (
                (data_iter_step + 1) % (accum_iter * args.print_freq)
            ) == 0:
                loss_value_reduce = accelerator.gather(
                    torch.tensor(loss_value).to(accelerator.device)
                ).mean()  # MUST BE EXECUTED BY ALL NODES

                # Only processes that own a writer log scalars. This is a plain
                # conditional (not `continue`) so every process still reaches
                # the checkpoint check below on logging steps.
                if log_writer is not None:
                    # We use epoch_1000x as the x-axis in tensorboard.
                    # This calibrates different curves when batch size changes.
                    epoch_1000x = int(epoch_f * 1000)
                    log_writer.add_scalar("train_loss", loss_value_reduce, step)
                    log_writer.add_scalar("train_lr", lr, step)
                    log_writer.add_scalar("train_iter", epoch_1000x, step)
                    for name, val in loss_details.items():
                        # Skip non-scalar tensors and nested dicts.
                        if isinstance(val, torch.Tensor):
                            if val.ndim > 0:
                                continue
                        if isinstance(val, dict):
                            continue
                        log_writer.add_scalar("train_" + name, val, step)

        if (
            save_every > 0
            and data_iter_step % save_every == 0
            and data_iter_step != 0
            and data_iter_step != steps_per_epoch - 1
        ):
            print("saving at step", data_iter_step)
            save_model(epoch - 1, "last", float("inf"), data_iter_step)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes(accelerator)
    printer.info("Averaged stats: %s", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
581
+
582
def batch_append(original_list, new_list):
    """
    Append the i-th element of ``new_list`` to the i-th list of ``original_list``.

    The inner lists are modified in place. Pairing stops at the shorter of
    the two inputs.

    Returns:
        ``original_list`` itself (the same object that was passed in).
    """
    pairs = zip(original_list, new_list)
    for bucket, item in pairs:
        bucket.append(item)
    return original_list
586
+
587
+
588
@hydra.main(
    version_base=None,
    config_path=f"{os.path.dirname(os.path.abspath(__file__))}/../config",
    config_name="mytrain.yaml",
)
def run(cfg: OmegaConf):
    """
    Hydra entry point: resolve the config, create the log directory, train.

    Args:
        cfg: configuration loaded by hydra from ``../config/mytrain.yaml``
            (relative to this file); must provide ``logdir``.
    """
    OmegaConf.resolve(cfg)
    pathlib.Path(cfg.logdir).mkdir(parents=True, exist_ok=True)
    train(cfg)


if __name__ == "__main__":
    run()
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/camera_head.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from streamvggt.layers import Mlp
9
+ from streamvggt.layers.block import Block
10
+ from streamvggt.heads.head_act import activate_pose
11
+
12
+
13
class CameraHead(nn.Module):
    """
    Iterative camera-pose prediction head.

    The head takes the token at index 0 of every frame from the last entry of
    the aggregated token list, and refines a pose encoding over several
    iterations: the previous estimate is embedded, turned into shift / scale /
    gate modulation parameters, applied to the normalized tokens, passed
    through a trunk of transformer blocks, and mapped to a delta that is added
    to the running estimate.

    Args:
        dim_in: channel dimension of the input tokens.
        trunk_depth: number of transformer blocks in the trunk.
        pose_encoding_type: only ``"absT_quaR_FoV"`` (9 values) is supported.
        num_heads: attention heads per trunk block.
        mlp_ratio: MLP expansion ratio of the trunk blocks.
        init_values: forwarded to every trunk ``Block``.
        trans_act: activation name for the translation part.
        quat_act: activation name for the quaternion part.
        fl_act: activation name for the field-of-view part.

    Raises:
        ValueError: if ``pose_encoding_type`` is not supported.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, past_key_values_camera = None, use_cache: bool = False) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
            past_key_values_camera (list, optional): Per-trunk-block key/value
                cache, only used when ``use_cache`` is true. ``None`` starts
                from an empty cache.
            use_cache (bool, optional): Whether to run the trunk with a
                key/value cache. Defaults to False.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
            When ``use_cache`` is true, a tuple ``(that list, updated cache)``
            is returned instead.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (index 0 along the third axis).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        if use_cache:
            pred_pose_enc_list, past_key_values_camera = self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera, use_cache)
            return pred_pose_enc_list, past_key_values_camera
        else:
            pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera=None, use_cache=use_cache)
            return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int, past_key_values_camera, use_cache: bool) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.
            past_key_values_camera (list or None): Per-trunk-block key/value
                cache; updated in place when ``use_cache`` is true. ``None``
                is treated as an empty cache.
            use_cache (bool): Whether to run the trunk blocks with the cache.

        Returns:
            list: List of activated camera encodings from each iteration.
            When ``use_cache`` is true, a tuple ``(that list, updated cache)``
            is returned instead.
        """
        # NOTE(review): the mask built below is causal over S, so S looks like
        # the number of frames rather than a constant 1 — confirm.
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # A missing cache would otherwise fail on the indexing below.
        if use_cache and past_key_values_camera is None:
            past_key_values_camera = [None] * self.trunk_depth

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            if not use_cache:
                # Additive mask: position i may not attend to positions j > i.
                frame_ids = torch.arange(S, device=pose_tokens_modulated.device)
                future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
                attn_mask = future_frame.to(pose_tokens_modulated.dtype) * torch.finfo(pose_tokens_modulated.dtype).min
            else:
                attn_mask = None

            if use_cache:
                for idx in range(self.trunk_depth):
                    pose_tokens_modulated, block_kv = self.trunk[idx](
                        pose_tokens_modulated,
                        attn_mask=attn_mask,
                        past_key_values=past_key_values_camera[idx],
                        use_cache=True
                    )
                    past_key_values_camera[idx] = block_kv
            else:
                for idx in range(self.trunk_depth):
                    pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, attn_mask=attn_mask)

            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        if use_cache:
            return pred_pose_enc_list, past_key_values_camera
        return pred_pose_enc_list
169
+
170
+
171
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply an affine modulation to ``x``: multiply by ``1 + scale``, then add ``shift``.
    """
    gain = 1 + scale
    return x * gain + shift
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/dpt_head.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from .head_act import activate_head
8
+ from .utils import create_uv_grid, position_grid_to_embed
9
+
10
+
11
class DPTHead(nn.Module):
    """
    DPT-style dense prediction head working on aggregated transformer tokens.

    Patch tokens taken from several intermediate layers are normalised, projected
    to 2-D feature maps, resized into a four-level pyramid (x4, x2, x1, x0.5 of the
    patch grid), fused top-down by the ``scratch`` refinement blocks and finally
    resampled to the image resolution divided by ``down_ratio``.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers
            and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        # Copy the index list: the signature default is a single list object shared
        # by every call, so storing it directly would alias it across instances.
        self.intermediate_layer_idx = list(intermediate_layer_idx)

        self.norm = nn.LayerNorm(dim_in)

        # One 1x1 projection per tapped layer, mapping tokens to that level's width.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers turning the single-resolution maps into a pyramid:
        # x4 and x2 upsampling, identity, and a stride-2 downsampling.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest level has no skip input, hence no residual branch.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or not smaller than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: feature maps with shape [B, S, C, H', W'], where H' and W'
                  are the patch-aligned image size divided by ``down_ratio``.
                - Otherwise: tuple ``(predictions, confidence)`` in the channel-last layout produced
                  by ``activate_head``: predictions [B, S, H', W', output_dim - 1] and
                  confidence [B, S, H', W'].

        Raises:
            ValueError: If chunking is required and ``frames_chunk_size`` is not positive.
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or not smaller than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Explicit check instead of ``assert`` so the validation survives ``python -O``.
        if frames_chunk_size <= 0:
            raise ValueError(f"frames_chunk_size must be a positive integer, got {frames_chunk_size}")

        # Otherwise, process frames in chunks to manage memory usage
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the head on all frames, or on the frame slice [frames_start_idx, frames_end_idx).

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index (exclusive) for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        chunked = frames_start_idx is not None and frames_end_idx is not None
        if chunked:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        # From here on S is the number of frames actually processed.
        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []

        for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
            # Drop the leading non-patch tokens of every frame.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if chunked:
                x = x[:, frames_start_idx:frames_end_idx]

            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)
            # (B*S, tokens, C) -> (B*S, C, patch_h, patch_w)
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.reshape(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        # Restore the separate batch and frame axes.
        preds = preds.reshape(B, S, *preds.shape[1:])
        conf = conf.reshape(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Add a UV-grid positional embedding, scaled by ``ratio``, to the feature map ``x``.

        The grid is built at the spatial size of ``x`` using the image aspect ratio ``W / H``.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # Move channels first and broadcast over the batch dimension of x.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): The four feature maps, ordered from the finest to the coarsest level.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion; intermediates are deleted eagerly to lower peak memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
284
+
285
+
286
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """
    Build a ``FeatureFusionBlock`` with the fixed configuration used by this head.

    An in-place ReLU is used as activation; deconvolution, batch norm and channel
    expansion are disabled and bilinear resizing uses ``align_corners=True``.
    Only ``size``, ``has_residual`` and ``groups`` are forwarded from the caller.
    """
    fixed_options = dict(deconv=False, bn=False, expand=False, align_corners=True)
    caller_options = dict(size=size, has_residual=has_residual, groups=groups)
    relu = nn.ReLU(inplace=True)
    return FeatureFusionBlock(features, relu, **fixed_options, **caller_options)
298
+
299
+
300
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
301
+ scratch = nn.Module()
302
+ out_shape1 = out_shape
303
+ out_shape2 = out_shape
304
+ out_shape3 = out_shape
305
+ if len(in_shape) >= 4:
306
+ out_shape4 = out_shape
307
+
308
+ if expand:
309
+ out_shape1 = out_shape
310
+ out_shape2 = out_shape * 2
311
+ out_shape3 = out_shape * 4
312
+ if len(in_shape) >= 4:
313
+ out_shape4 = out_shape * 8
314
+
315
+ scratch.layer1_rn = nn.Conv2d(
316
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
317
+ )
318
+ scratch.layer2_rn = nn.Conv2d(
319
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
320
+ )
321
+ scratch.layer3_rn = nn.Conv2d(
322
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
323
+ )
324
+ if len(in_shape) >= 4:
325
+ scratch.layer4_rn = nn.Conv2d(
326
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
327
+ )
328
+ return scratch
329
+
330
+
331
class ResidualConvUnit(nn.Module):
    """Pre-activation residual unit: two 3x3 convolutions plus an identity skip."""

    def __init__(self, features, activation, bn, groups=1):
        """Set up the unit.

        Args:
            features (int): number of input and output channels
            activation: callable applied before each convolution
            bn: stored on the instance; no normalisation layer is created here
            groups (int): convolution groups
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_options)
        self.conv2 = nn.Conv2d(features, features, **conv_options)

        # Normalisation slots stay empty; forward() skips them while they are None.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run activation -> conv (-> norm) twice and add the input back.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        hidden = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            hidden = conv(self.activation(hidden))
            if norm is not None:
                hidden = norm(hidden)

        return self.skip_add.add(hidden, x)
374
+
375
+
376
class FeatureFusionBlock(nn.Module):
    """Fuses a coarse feature map with an optional skip input, then resizes and projects it."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Set up the block.

        Args:
            features (int): number of features
            activation: activation handed to the residual units
            deconv: stored on the instance only
            bn: forwarded to the residual units
            expand: when equal to True the 1x1 output convolution halves the channels
            align_corners (bool): passed to the bilinear resize
            size: default target size used when forward() gets no ``size``
            has_residual (bool): whether a second (skip) input is consumed
            groups (int): convolution groups
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        # Equality with True (not plain truthiness) is kept on purpose.
        out_features = features // 2 if self.expand == True else features

        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Fuse the inputs, resize and apply the output convolution.

        ``xs[0]`` is the main input; ``xs[1]`` is the skip input and is read only
        when the block was built with ``has_residual=True``.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Target resolution: explicit argument first, then the configured size,
        # and finally a plain x2 upsampling.
        if size is not None:
            resize_args = {"size": size}
        elif self.size is not None:
            resize_args = {"size": self.size}
        else:
            resize_args = {"scale_factor": 2}

        fused = custom_interpolate(fused, **resize_args, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)
444
+
445
+
446
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Wrapper around ``nn.functional.interpolate`` that splits very large jobs along the batch axis.

    When ``size`` is omitted it is derived from the last two dimensions of ``x``
    and ``scale_factor``. If the number of output elements
    (``size[0] * size[1] * batch * channels``) exceeds 1610612736, the input is
    chunked along dim 0, each chunk is interpolated separately and the results are
    concatenated into a contiguous tensor; otherwise a single call is made.
    """
    if size is None:
        in_h, in_w = x.shape[-2], x.shape[-1]
        size = (int(in_h * scale_factor), int(in_w * scale_factor))

    max_elements = 1610612736

    total_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if total_elements <= max_elements:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    num_chunks = (total_elements // max_elements) + 1
    resized = []
    for part in torch.chunk(x, chunks=num_chunks, dim=0):
        resized.append(nn.functional.interpolate(part, size=size, mode=mode, align_corners=align_corners))
    return torch.cat(resized, dim=0).contiguous()
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/head_act.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Activate each component of an encoded pose independently.

    The last dimension is split into translation (``[:3]``), quaternion (``[3:7]``)
    and the remaining focal-length / field-of-view entries (``[7:]``); each slice
    goes through ``base_pose_act`` with its own activation name and the results
    are concatenated back in the same order.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor
    """
    components = (
        (pred_pose_enc[..., :3], trans_act),
        (pred_pose_enc[..., 3:7], quat_act),
        (pred_pose_enc[..., 7:], fl_act),  # focal length or fov
    )
    activated = [base_pose_act(values, act_name) for values, act_name in components]
    return torch.cat(activated, dim=-1)
27
+
28
+
29
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply one named element-wise activation to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters ("linear" returns the input object unchanged)

    Raises:
        ValueError: If ``act_type`` is not one of the supported names.
    """
    if act_type == "linear":
        return pose_enc
    if act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    if act_type == "exp":
        return torch.exp(pose_enc)
    if act_type == "relu":
        return F.relu(pose_enc)
    raise ValueError(f"Unknown act_type: {act_type}")
50
+
51
+
52
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Split a dense head output into point values and a confidence map and activate both.

    Args:
        out: Network output tensor (B, C, H, W)
        activation: Activation type for the first C-1 channels
        conf_activation: Activation type for the last (confidence) channel

    Returns:
        Tuple of (points tensor of shape (B, H, W, C-1), confidence tensor of shape (B, H, W))

    Raises:
        ValueError: If ``activation`` or ``conf_activation`` is not a supported name.
    """
    # Channels-last view: (B, C, H, W) -> (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # All but the last channel carry the point values; the last one is confidence.
    xyz = fmap[..., :-1]
    conf = fmap[..., -1]

    if activation == "norm_exp":
        # Keep the direction, map the (clamped) length d to expm1(d).
        length = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        direction = xyz / length
        pts3d = direction * torch.expm1(length)
    elif activation == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # z is activated first and then scales the two xy channels.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
104
+
105
+
106
def inverse_log_transform(y):
    """
    Sign-preserving exponential map: ``sign(y) * (exp(|y|) - 1)``.

    Args:
        y: Input tensor

    Returns:
        Transformed tensor (zero stays zero, the sign of every entry is kept)
    """
    magnitude = torch.expm1(torch.abs(y))
    return torch.sign(y) * magnitude
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_head.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .dpt_head import DPTHead
3
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
4
+
5
+
6
class TrackHead(nn.Module):
    """
    Point-tracking head: a feature-only DPT head followed by an iterative tracker.

    The DPT head turns backbone tokens into per-frame feature maps and
    ``BaseTrackerPredictor`` refines track coordinates on those maps over
    several iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Build the feature extractor and the tracker.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
            features (int): Number of feature channels in the feature extractor output.
            iters (int): Default number of refinement iterations for tracking predictions.
            predict_conf (bool): Whether to predict confidence scores for tracked points.
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels.
            corr_radius (int): Radius for correlation computation, controlling the search area.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size

        # DPT head used purely as a feature extractor: no prediction layers,
        # no positional embedding, output at half the image resolution.
        extractor_config = dict(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,
            down_ratio=2,
            pos_embed=False,
        )
        self.feature_extractor = DPTHead(**extractor_config)

        # Tracker consuming the feature maps; its latent width is set to the
        # extractor's feature width.
        tracker_config = dict(
            latent_dim=features,
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )
        self.tracker = BaseTrackerPredictor(**tracker_config)

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Extract feature maps and track the query points through the sequence.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
                                   B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Query points to track.
                NOTE(review): the default is None, but the tracker in this package
                unpacks ``query_points.shape`` right away, so a tensor presumably
                has to be supplied - confirm.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.

        Returns:
            tuple:
                - coord_preds: Predicted coordinates for tracked points (one entry per iteration).
                - vis_scores (torch.Tensor): Visibility scores for tracked points.
                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
        """
        # Unpacked only for its side effect: a non-5-D input fails here.
        _batch, _frames, _channels, _height, _width = images.shape

        # With down_ratio=2 the maps come out as (B, S, C, H//2, W//2).
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        # Fall back to the configured iteration count.
        num_iters = self.iters if iters is None else iters

        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=num_iters,
        )

        return coord_preds, vis_scores, conf_scores
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/__init__.py ADDED
File without changes
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange, repeat
4
+
5
+
6
+ from .blocks import EfficientUpdateFormer, CorrBlock
7
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
8
+ from .modules import Mlp
9
+
10
+
11
class BaseTrackerPredictor(nn.Module):
    """
    Iterative point tracker operating on per-frame feature maps.

    Starting from the query locations, every refinement iteration samples local
    correlations around the current estimates, feeds them together with flow
    embeddings and track features to an ``EfficientUpdateFormer`` and applies the
    predicted coordinate and feature deltas. Visibility (and optionally
    confidence) is predicted from the final track features.
    """

    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        """
        Args:
            stride: factor by which query coordinates are divided on input and
                predictions multiplied on output.
            corr_levels: number of levels handed to ``CorrBlock``.
            corr_radius: radius handed to ``CorrBlock``; each level contributes
                ``(2 * corr_radius + 1) ** 2`` correlation values.
            latent_dim: channel width of the feature maps and track features.
            hidden_size: hidden width of the correlation MLP and the update transformer.
            use_spaceatt: enables the space-attention blocks of the transformer.
            depth: number of time blocks (and of space blocks when enabled).
            max_scale: divisor used to normalise raw flows before concatenation.
            predict_conf: whether a confidence predictor is created.
        """
        super(BaseTrackerPredictor, self).__init__()
        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        self.flows_emb_dim = latent_dim // 2

        # Compresses the flattened correlation volume of one point to latent_dim.
        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )

        # Width of the concatenated transformer input built in forward():
        # flow embedding + two copies of the scaled 2-D flow (the "+ 4"),
        # correlation features and track features.
        # NOTE(review): this presumes get_2d_embedding returns latent_dim channels
        # for flows_emb_dim = latent_dim // 2 - confirm against its definition.
        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4

        # Two learned tokens: index 0 is added to the first frame, index 1 to all others.
        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth

        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())

        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
        """
        Track ``query_points`` across all frames of ``fmaps``.

        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
        note HH and WW is the size of feature maps instead of original images
        iters: number of refinement iterations
        return_feat: also return the final track features and the query features
        down_ratio: extra factor (applied when > 1) by which coordinates are scaled down/up
        apply_sigmoid: squash visibility / confidence logits with a sigmoid

        Returns ``(coord_preds, vis_e, conf_e)`` or, with ``return_feat=True``,
        ``(coord_preds, vis_e, track_feats, query_track_feat, conf_e)``.
        ``coord_preds`` is a list with one B x S x N x 2 tensor per iteration;
        ``vis_e`` and ``conf_e`` are B x S x N, and ``conf_e`` is None when
        ``predict_conf`` is False.
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here
        # (channels are moved last for LayerNorm and moved back afterwards)
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            # Fold the track axis into the batch: (B, S, N, corr) -> (B*N, S, corr)
            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embed
            # sampled at the query-frame coordinates, so one embedding per track,
            # broadcast over the frame axis below
            pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
            sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])

            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats
            # (token 0 for the first frame, token 1 repeated for the other S - 1 frames)
            query_ref_token = torch.cat(
                [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            # first two channels are the coordinate update, the rest the feature update
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            # (residual update: normalised and transformed delta added to the old features)
            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            # (undoing the division by stride, and by down_ratio when it was applied)
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        # visibility is predicted once, from the track features of the last iteration
        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .utils import bilinear_sampler
7
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
8
+
9
+
10
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    Tokens are laid out as (B, N, T, C): batch, tracks, time steps, channels.
    Temporal attention runs over T for every track; optionally, spatial
    attention exchanges information between the point tracks and a set of
    learned "virtual" tracks.

    Args:
        space_depth: number of spatial attention stages.
        time_depth: number of temporal attention blocks.
        input_dim: channel size of the incoming tokens.
        hidden_size: internal token width.
        num_heads: attention heads used by every block.
        output_dim: channel size produced by the final projection.
        mlp_ratio: hidden expansion ratio of the block MLPs.
        add_space_attn: enable the spatial (virtual track) attention stages.
        num_virtual_tracks: number of learned virtual tracks.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn

        # Add input LayerNorm before linear projection
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)

        # Add output LayerNorm before final projection
        self.output_norm = nn.LayerNorm(hidden_size)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        # NOTE: the attribute name (including its spelling) is part of the
        # state_dict keys, so it is intentionally left untouched.
        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=nn.MultiheadAttention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=nn.MultiheadAttention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-initialise every linear layer and give the flow head tiny weights."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            # Re-drawn on every visited module; what matters is that the flow
            # head ends up with near-zero weights once `apply` has finished.
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None):
        """
        Args:
            input_tensor: tokens of shape (B, N, T, input_dim).
            mask: optional attention mask forwarded to the cross-attention blocks.

        Returns:
            Tuple ``(flow, None)`` where ``flow`` has shape (B, N, T, output_dim).
        """
        # Apply input LayerNorm
        input_tensor = self.input_norm(input_tensor)
        tokens = self.input_transform(input_tensor)

        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        # Spatial stages are interleaved every `space_every` temporal blocks.
        # The `j < num_space_blocks` guard keeps the schedule inside the space
        # block lists when time_depth is not a multiple of space_depth, and the
        # zero check avoids a division by zero when space_depth == 0.
        num_space_blocks = len(self.space_virtual_blocks) if self.add_space_attn else 0
        space_every = max(len(self.time_blocks) // num_space_blocks, 1) if num_space_blocks else 1

        j = 0
        for i, time_block in enumerate(self.time_blocks):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C

            time_tokens = time_block(time_tokens)

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if j < num_space_blocks and i % space_every == 0:
                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                # virtual <- points, virtual self-attention, points <- virtual
                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            # Drop the virtual tracks again; only real tracks are returned.
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        # Apply output LayerNorm before final projection
        tokens = self.output_norm(tokens)
        flow = self.flow_head(tokens)

        return flow, None
136
+
137
+
138
class CorrBlock:
    """Feature pyramid plus on-the-fly correlation sampling around track coordinates."""

    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        """
        Build a pyramid of feature maps from the input.

        fmaps: Tensor (B, S, C, H, W)
        num_levels: number of pyramid levels (each downsampled by factor 2)
        radius: search radius for sampling correlation
        multiple_track_feats: if True, split the target features per pyramid level
        padding_mode: passed to grid_sample / bilinear_sampler
        """
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.num_levels = num_levels
        self.radius = radius
        self.padding_mode = padding_mode
        self.multiple_track_feats = multiple_track_feats

        # Build pyramid: each level is half the spatial resolution of the previous
        self.fmaps_pyramid = [fmaps]  # level 0 is full resolution
        current_fmaps = fmaps
        for _ in range(num_levels - 1):
            B, S, C, H, W = current_fmaps.shape
            # Merge batch & sequence dimensions so avg_pool2d sees a 4D tensor
            current_fmaps = current_fmaps.reshape(B * S, C, H, W)
            # Avg pool down by factor 2
            current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
            _, _, H_new, W_new = current_fmaps.shape
            current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
            self.fmaps_pyramid.append(current_fmaps)

        # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
        # This grid is added to the (scaled) coordinate centroids.
        r = self.radius
        dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        # Stacked in (dy, dx) order from an "ij" meshgrid.
        self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)  # shape: (2r+1, 2r+1, 2)

    def corr_sample(self, targets, coords):
        """
        Instead of storing the entire correlation pyramid, we compute each level's correlation
        volume, sample it immediately, then discard it. This saves GPU memory.

        Args:
            targets: Tensor (B, S, N, C) — features for the current targets.
            coords: Tensor (B, S, N, 2) — coordinates at full resolution.

        Returns:
            Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
        """
        B, S, N, C = targets.shape

        # If you have multiple track features, split them per level.
        if self.multiple_track_feats:
            targets_split = torch.split(targets, C // self.num_levels, dim=-1)

        # The window offsets do not depend on the pyramid level, so move them to
        # the coordinates' device/dtype once instead of once per level.
        window = 2 * self.radius + 1
        delta_lvl = self.delta.to(coords.device).to(coords.dtype).view(1, window, window, 2)

        out_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            # Get current spatial resolution H, W for this pyramid level.
            B, S, C, H, W = fmaps.shape
            # fmap2s: (B, S, C, H*W). reshape (not view) so that non-contiguous
            # feature maps supplied by the caller are accepted as well.
            fmap2s = fmaps.reshape(B, S, C, H * W)
            # Choose appropriate target features.
            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)

            # Compute correlation directly
            corrs = compute_corr_level(fmap1, fmap2s, C)
            corrs = corrs.view(B, S, N, H, W)

            # Scale down the coordinates for the current level and add the window.
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
            coords_lvl = centroid_lvl + delta_lvl

            # Sample from the correlation volume using bilinear interpolation.
            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
            corrs_sampled = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
            )
            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)
            out_pyramid.append(corrs_sampled)

        # Concatenate all levels along the last dimension.
        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out
230
+
231
+
232
def compute_corr_level(fmap1, fmap2s, C):
    """
    Scaled dot-product correlation between target features and a feature map.

    Args:
        fmap1: target features, (B, S, N, C).
        fmap2s: flattened feature map, (B, S, C, H*W).
        C: channel count used for the 1/sqrt(C) normalisation.

    Returns:
        Correlation tensor of shape (B, S, N, H*W).
    """
    batch, frames, points = fmap1.shape[:3]
    raw = torch.matmul(fmap1, fmap2s).view(batch, frames, points, -1)
    return raw / math.sqrt(C)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+ from typing import Callable
6
+ import collections
7
+ from torch import Tensor
8
+ from itertools import repeat
9
+
10
+
11
+ # From PyTorch internals
12
+ def _ntuple(n):
13
+ def parse(x):
14
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
15
+ return tuple(x)
16
+ return tuple(repeat(x, n))
17
+
18
+ return parse
19
+
20
+
21
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
23
+
24
+
25
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d
27
+
28
+
29
+ to_2tuple = _ntuple(2)
30
+
31
+
32
class ResidualBlock(nn.Module):
    """
    ResidualBlock: construct a block of two conv layers with residual connections

    Args:
        in_planes: input channels.
        planes: output channels.
        norm_fn: one of "group", "batch", "instance" or "none".
        stride: stride of the first convolution; when it is not 1 the skip
            path is a strided 1x1 convolution followed by a norm layer.
        kernel_size: kernel size of both convolutions. Odd sizes keep the
            spatial size of the main path aligned with the skip path.

    Raises:
        NotImplementedError: for an unknown ``norm_fn``.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()

        # "Same" padding for odd kernels. The previous hard-coded padding=1 is
        # only correct for kernel_size=3 and made the residual add fail for any
        # other kernel size.
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        # At least one group so that fewer than 8 channels is still valid.
        num_groups = max(planes // 8, 1)

        if norm_fn == "group":

            def make_norm():
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":

            def make_norm():
                return nn.BatchNorm2d(planes)

        elif norm_fn == "instance":

            def make_norm():
                return nn.InstanceNorm2d(planes)

        elif norm_fn == "none":

            def make_norm():
                return nn.Sequential()

        else:
            raise NotImplementedError

        self.norm1 = make_norm()
        self.norm2 = make_norm()
        if stride != 1:
            # Only the strided skip path needs its own norm layer.
            self.norm3 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        """Apply conv-norm-relu twice and add the (optionally downsampled) input."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
102
+
103
+
104
class Mlp(nn.Module):
    """
    Two-layer feed-forward network (fc1 -> activation -> dropout -> fc2 -> dropout).

    ``bias`` and ``drop`` may be given as a single value or as a pair, one
    entry per layer. With ``use_conv=True`` the layers are 1x1 convolutions
    instead of ``nn.Linear``. ``norm_layer`` is accepted but not used.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        # Missing sizes fall back to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        layer_bias = to_2tuple(bias)
        layer_drop = to_2tuple(drop)

        if use_conv:
            make_layer = partial(nn.Conv2d, kernel_size=1)
        else:
            make_layer = nn.Linear

        self.fc1 = make_layer(in_features, hidden_features, bias=layer_bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(layer_drop[0])
        self.fc2 = make_layer(hidden_features, out_features, bias=layer_bias[1])
        self.drop2 = nn.Dropout(layer_drop[1])

    def forward(self, x):
        """Run the input through both layers and return the result."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
138
+
139
+
140
class AttnBlock(nn.Module):
    """
    Self attention block

    The input is layer-normalised first; both residual additions are made on
    top of that normalised tensor (not on the raw input).
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        # Extra keyword arguments are handed straight to the attention class.
        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), drop=0)

    def forward(self, x, mask=None):
        """
        Args:
            x: input tokens, batch-first.
            mask: accepted for interface compatibility; it is not forwarded to
                the attention layer.
        """
        normed = self.norm1(x)

        # The attention layer returns (output, weights); only the output is used.
        attended, _ = self.attn(normed, normed, normed)

        residual = normed + attended
        return residual + self.mlp(self.norm2(residual))
178
+
179
+
180
class CrossAttnBlock(nn.Module):
    """
    Cross attention block

    Queries come from ``x`` and keys/values from ``context``. ``context_dim``
    is accepted but unused: the context is normalised with a LayerNorm of
    width ``hidden_size``.
    """

    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), drop=0)

    def forward(self, x, context, mask=None):
        """
        Args:
            x: query tokens, batch-first.
            context: key/value tokens, batch-first.
            mask: optional mask passed to the attention layer as ``attn_mask``.
        """
        queries = self.norm1(x)
        memory = self.norm_context(context)

        # nn.MultiheadAttention returns (output, weights); the weights are dropped.
        attended, _ = self.cross_attn(queries, memory, memory, attn_mask=mask)

        # Residuals are added to the normalised queries.
        residual = queries + attended
        return residual + self.mlp(self.norm2(residual))
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Optional, Tuple, Union
6
+
7
+
8
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
    """
    Build a 2D sine/cosine positional embedding for a regular grid.

    It is a wrapper of get_2d_sincos_pos_embed_from_grid.
    Args:
    - embed_dim: The embedding dimension.
    - grid_size: Either one int (square grid) or a ``(height, width)`` tuple.
    - return_grid: When True, also return the coordinate grid.
    Returns:
    - pos_embed: Tensor of shape (1, embed_dim, height, width); when
      ``return_grid`` is True a ``(pos_embed, grid)`` pair is returned, with
      ``grid`` of shape (2, 1, height, width).
    """
    if isinstance(grid_size, tuple):
        rows, cols = grid_size
    else:
        rows, cols = grid_size, grid_size

    ys = torch.arange(rows, dtype=torch.float)
    xs = torch.arange(cols, dtype=torch.float)

    # "xy" indexing yields two (rows, cols) tensors: x coordinates first, then y.
    grid = torch.stack(torch.meshgrid(xs, ys, indexing="xy"), dim=0).reshape([2, 1, rows, cols])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    pos_embed = pos_embed.reshape(1, rows, cols, -1).permute(0, 3, 1, 2)

    if return_grid:
        return (pos_embed, grid)
    return pos_embed
34
+
35
+
36
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    """
    Generate a 2D sine/cosine positional embedding from a coordinate grid.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - grid: The grid to generate the embedding from; ``grid[0]`` and
      ``grid[1]`` hold the two coordinate planes.

    Returns:
    - emb: Tensor of shape (1, M, embed_dim), where M is the number of grid
      positions. The first half of the channels encodes ``grid[0]`` and the
      second half ``grid[1]``.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Each call returns a (1, M, embed_dim / 2) tensor.
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]

    return torch.cat(per_axis, dim=2)
55
+
56
+
57
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The positions to encode; any shape, flattened to M values.

    Returns:
    - emb: float32 tensor of shape (1, M, embed_dim) on the device of ``pos``;
      the first half of the channels holds the sines, the second half the cosines.
    """
    assert embed_dim % 2 == 0
    # Create the frequencies on the same device as the positions so the
    # function also works for non-CPU inputs.
    omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # Do the outer product in double precision explicitly instead of relying
    # on einsum to promote mixed dtypes.
    pos = pos.reshape(-1).to(torch.double)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()
81
+
82
+
83
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    Encode 2D coordinates with interleaved sine/cosine features.

    Args:
    - xy: Coordinates of shape (B, N, 2).
    - C: Number of embedding channels per coordinate (the frequency table is
      reshaped to ``int(C / 2)`` entries, so an even value is expected).
    - cat_coords: When True, the raw coordinates are prepended to the embedding.

    Returns:
    - pe: Tensor of shape (B, N, 2 * C), or (B, N, 2 * C + 2) when
      ``cat_coords`` is True. Per coordinate, even channels hold sines and odd
      channels hold cosines.
    """
    B, N, D = xy.shape
    assert D == 2

    freqs = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))

    def _interleaved_sincos(coord):
        # (B, N, C/2) phases -> (B, N, C) laid out as sin0, cos0, sin1, cos1, ...
        phase = coord * freqs
        pair = torch.stack((torch.sin(phase), torch.cos(phase)), dim=-1)
        return pair.flatten(-2).to(torch.float32)

    pe = torch.cat([_interleaved_sincos(xy[:, :, 0:1]), _interleaved_sincos(xy[:, :, 1:2])], dim=2)  # (B, N, 2C)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, 2C + 2)
    return pe
115
+
116
+
117
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    Thin wrapper around `torch.nn.functional.grid_sample()` that takes
    coordinates in pixel units instead of the normalised :math:`[-1, 1]`
    range.

    For a 4D :attr:`input` of shape :math:`(B, C, H, W)`, :attr:`coords` has
    shape :math:`(B, H_o, W_o, 2)` and holds points :math:`(x_i, y_i)`.

    For a 5D :attr:`input` of shape :math:`(B, C, T, H, W)`, the points are
    triplets :math:`(t_i, x_i, y_i)`; they are reordered internally to the
    :math:`(x_i, y_i, t_i)` order that `grid_sample()` expects.

    With `align_corners=True`, :math:`x` is expected in :math:`[0, W-1]`
    (pixel centers); with `align_corners=False`, in :math:`[0, W]` (pixel
    edges). The same holds for :math:`y` with :math:`H` and :math:`t` with
    :math:`T`.

    The coordinates are detached and copied before use, so the caller's
    tensor is not modified and no gradient flows back through
    :attr:`coords`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    # Work on a private copy that lives on the input's device and dtype.
    grid = coords.detach().clone().to(input.device).to(input.dtype)

    spatial = input.shape[2:]

    assert len(spatial) in [2, 3]

    if len(spatial) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        grid = grid[..., [1, 2, 0]]

    # One normalisation factor per coordinate, ordered like the last axis of
    # the grid (width first).
    if align_corners:
        factors = [2 / max(extent - 1, 1) for extent in reversed(spatial)]
    else:
        factors = [2 / extent for extent in reversed(spatial)]
    scale = torch.tensor(factors, device=grid.device, dtype=grid.dtype)

    # Map pixel units to the [-1, 1] range used by grid_sample.
    grid = grid * scale - 1

    return F.grid_sample(input, grid, align_corners=align_corners, padding_mode=padding_mode)
184
+
185
+
186
def sample_features4d(input, coords):
    r"""Sample spatial features

    Reads one feature vector per point from the 4D feature map
    :attr:`input` of shape :math:`(B, C, H, W)` using
    :func:`bilinear_sampler` with its default arguments.

    :attr:`coords` has shape :math:`(B, R, 2)` with points in
    :math:`(x_i, y_i)` format.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features of shape :math:`(B, R, C)`.
    """
    batch = input.shape[0]

    # B R 2 -> B R 1 2, so the sampler sees an (R x 1) output grid.
    grid = coords.unsqueeze(2)

    # B C R 1
    sampled = bilinear_sampler(input, grid)

    channels, trailing = sampled.shape[1], sampled.shape[3]
    # B C R 1 -> B R C
    return sampled.permute(0, 2, 1, 3).view(batch, -1, channels * trailing)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/heads/utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings
        omega_0: Base of the frequency schedule, forwarded to make_sincos_pos_embed

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings; the
        first half of the channels encodes coordinate 0, the second half
        coordinate 1.
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    flat_positions = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)

    half_dim = embed_dim // 2
    # One (H*W, embed_dim / 2) embedding per coordinate axis.
    per_axis = [make_sincos_pos_embed(half_dim, flat_positions[:, axis], omega_0=omega_0) for axis in (0, 1)]

    combined = torch.cat(per_axis, dim=-1)  # (H*W, embed_dim)

    return combined.view(H, W, embed_dim)  # (H, W, embed_dim)
28
+
29
+
30
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for the given positions.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The positions to encode; any shape, flattened to M values.
    - omega_0: Base of the geometric frequency schedule.

    Returns:
    - emb: float32 tensor of shape (M, embed_dim); sines fill the first half
      of the channels and cosines the second half.
    """
    assert embed_dim % 2 == 0

    # Frequencies 1 / omega_0 ** (k / (embed_dim / 2)), computed in double precision.
    exponents = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
    exponents /= embed_dim / 2.0
    inv_freq = 1.0 / omega_0**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = torch.einsum("m,d->md", flat_pos, inv_freq)  # (M, D/2), outer product

    features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (M, D)
    return features.float()
54
+
55
+
56
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The horizontal and vertical extents follow the aspect ratio and are
    normalized by the diagonal of the plane. Samples sit at pixel centers, so
    the first entry is at ``(-span_x * (width - 1) / width,
    -span_y * (height - 1) / height)`` and the last one at the mirrored
    positive position.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor whose last axis holds (u, v).
    """
    ratio = float(width) / float(height) if aspect_ratio is None else aspect_ratio

    # Half extents of the plane after normalizing by its diagonal.
    diag_factor = (ratio**2 + 1.0) ** 0.5
    span_x = ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Outermost sample positions (pixel centers), symmetric around zero.
    edge_x = span_x * (width - 1) / width
    edge_y = span_y * (height - 1) / height

    x_coords = torch.linspace(-edge_x, edge_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(-edge_y, edge_y, steps=height, dtype=dtype, device=device)

    # "xy" indexing produces (height, width) planes: u varies along columns, v along rows.
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    return torch.stack((uu, vv), dim=-1)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .mlp import Mlp
2
+ from .patch_embed import PatchEmbed
3
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
4
+ from .block import NestedTensorBlock
5
+ from .attention import MemEffAttention
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/attention.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import warnings
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from typing import Union, Tuple, Dict, Optional
10
+
11
+ from einops import rearrange
12
+
13
+ XFORMERS_AVAILABLE = False
14
+
15
+
16
class Attention(nn.Module):
    """
    Multi-head self-attention with optional QK normalisation, rotary position
    embedding and a key/value cache.

    Args:
        dim: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the fused QKV projection.
        proj_bias: add a bias to the output projection.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        norm_layer: norm applied per head to q and k when ``qk_norm`` is True.
        qk_norm: enable the q/k normalisation.
        fused_attn: use ``F.scaled_dot_product_attention`` instead of the
            explicit softmax path.
        rope: optional callable ``rope(tensor, pos)`` applied to q and k.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(
        self,
        x: torch.Tensor,
        pos=None,
        attn_mask=None,
        past_key_values=None,
        use_cache=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
        """
        Args:
            x: tokens of shape (B, N, C).
            pos: optional positions handed to ``rope``.
            attn_mask: optional mask. A boolean mask marks the entries that may
                be attended to (True = keep); any other mask is added to the
                attention scores. Both attention paths follow this convention.
            past_key_values: optional ``(k, v)`` pair returned by an earlier
                call with ``use_cache=True``.
            use_cache: when True, return ``(output, (k, v))`` where the cached
                tensors have shape (B, heads, chunks, N, head_dim).

        Returns:
            The attended tokens (B, N, C), plus the updated cache when
            ``use_cache`` is True.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        pos_k = pos
        if use_cache:
            # Keep one slot per call along a new "chunk" axis.
            k = k.unsqueeze(2)
            v = v.unsqueeze(2)
            if past_key_values is not None:
                past_k, past_v = past_key_values
                k = torch.cat([past_k, k], dim=2)
                v = torch.cat([past_v, v], dim=2)

            # The cache stores keys before normalisation and rotary embedding.
            new_kv = (k, v)
            a, b, c, d, e = k.shape
            # Flatten (chunks, N) into one key axis for attention.
            k = k.reshape(a, b, c * d, e)
            v = v.reshape(a, b, c * d, e)
            if pos_k is not None:
                # Reuse the same positions for every cached chunk.
                pos_k = pos_k.repeat(1, c, 1)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos_k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )

        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

            # Mask
            if attn_mask is not None:
                # With cached keys the key axis is longer than N, so compare
                # against the actual key length instead of N.
                num_keys = k.shape[-2]
                assert attn_mask.shape[-2:] == (
                    N,
                    num_keys,
                ), f"Expected mask shape [..., {N}, {num_keys}], got {attn_mask.shape}"
                if attn_mask.dtype == torch.bool:
                    # Same convention as scaled_dot_product_attention:
                    # True means "may attend".
                    attn = attn.masked_fill(~attn_mask, float("-inf"))
                else:
                    attn = attn + attn_mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if use_cache:
            return x, new_kv
        return x
108
+
109
+
110
class MemEffAttention(Attention):
    """
    Attention variant with an optional xFormers memory-efficient path.

    When ``XFORMERS_AVAILABLE`` is False the call falls back to the parent
    implementation; ``attn_bias`` is then not supported.
    """

    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None

        if XFORMERS_AVAILABLE:
            # NOTE(review): `unbind` and `memory_efficient_attention` are not
            # imported in the visible part of this module - confirm they are
            # provided before enabling XFORMERS_AVAILABLE.
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

            q, k, v = unbind(qkv, 2)

            x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
            x = x.reshape([B, N, C])

            x = self.proj(x)
            x = self.proj_drop(x)

            return x

        # Fallback: plain attention, which cannot honour an attention bias.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/block.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Callable, List, Any, Tuple, Dict, Union
4
+ import warnings
5
+
6
+ import torch
7
+ from torch import nn, Tensor
8
+
9
+ from .attention import Attention
10
+ from .drop_path import DropPath
11
+ from .layer_scale import LayerScale
12
+ from .mlp import Mlp
13
+
14
+
15
+ XFORMERS_AVAILABLE = False
16
+
17
+
18
class Block(nn.Module):
    """Pre-norm transformer block: ``x + ls1(attn(norm1(x)))`` then ``x + ls2(mlp(norm2(x)))``.

    Supports optional LayerScale (``init_values``), stochastic depth
    (``drop_path``) during training, and a key/value cache pass-through
    (``past_key_values`` / ``use_cache``) that is forwarded to the attention
    module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        """Build the attention and feed-forward sub-layers.

        Args:
            dim: Token embedding dimension.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden width of the FFN as a multiple of ``dim``.
            qkv_bias, proj_bias, ffn_bias: Bias switches forwarded to the sub-layers.
            drop: Dropout rate passed to the attention projection and the FFN.
            attn_drop: Dropout rate passed to the attention module.
            init_values: LayerScale init value; falsy (``None``/``0``) disables LayerScale.
            drop_path: Stochastic-depth rate.
            act_layer, norm_layer, attn_class, ffn_layer: Factories for the sub-modules.
            qk_norm, fused_attn, rope: Forwarded unchanged to ``attn_class``.
        """
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(
        self, x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False
    ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
        """Apply the block.

        Args:
            x: Input tokens.
            pos: Optional positions forwarded to the attention module.
            attn_mask: Optional attention mask. NOTE: it is only forwarded on the
                non-cached paths that call the attention branch directly; the
                ``use_cache`` path and the batch-subset stochastic-depth path
                (training with ``sample_drop_ratio > 0.1``) do not pass it on.
            past_key_values: Cached keys/values forwarded to the attention module
                when ``use_cache`` is true.
            use_cache: When true, return ``(x, new_kv)`` where ``new_kv`` is
                whatever the attention module returns as its updated cache.
                Drop-path is not applied on this path.

        Returns:
            The transformed tokens, or ``(tokens, new_kv)`` when ``use_cache`` is true.
        """

        def attn_residual_func(x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False):
            if use_cache:
                output, new_kv = self.attn(self.norm1(x), pos=pos, past_key_values=past_key_values, use_cache=True)
                return self.ls1(output), new_kv
            # Only pass attn_mask when one is given: attention classes such as
            # MemEffAttention do not accept an ``attn_mask`` keyword.
            if attn_mask is not None:
                return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
            return self.ls1(self.attn(self.norm1(x), pos=pos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if use_cache:
            attn_output, new_kv = attn_residual_func(x, pos=pos, past_key_values=past_key_values, use_cache=True)
            x = x + attn_output
            x = x + ffn_residual_func(x)
            return x, new_kv

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
            # Each residual branch uses its own DropPath module (was drop_path1 for both).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
            x = x + ffn_residual_func(x)
        return x
112
+
113
+
114
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos=None,
) -> Tensor:
    """Add a residual computed on a random subset of the batch (stochastic depth).

    A random subset of ``max(int(b * (1 - sample_drop_ratio)), 1)`` samples is
    drawn, ``residual_func`` is evaluated on that subset only, and the result is
    added back to the selected rows scaled by ``b / subset_size``.

    Args:
        x: Input of shape ``(b, n, d)``.
        residual_func: Callable producing the residual for the selected rows. It
            is called with ``pos=`` as a keyword only when ``pos`` is not ``None``.
        sample_drop_ratio: Fraction of the batch to skip.
        pos: Optional per-sample positions, indexed with the same subset.

    Returns:
        Tensor shaped like ``x``.
    """
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]
    subset = x[kept_idx]

    if pos is None:
        residual = residual_func(subset)
    else:
        # Positions must follow the same sample selection as the tokens.
        residual = residual_func(subset, pos=pos[kept_idx])

    # Rescale so the expected contribution matches the full-batch residual.
    scale = batch / keep
    updated = torch.index_add(
        x.flatten(1), 0, kept_idx, residual.flatten(1).to(dtype=x.dtype), alpha=scale
    )
    return updated.view_as(x)
142
+
143
+
144
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the random batch subset used for stochastic depth.

    Args:
        x: Tensor of shape ``(b, n, d)``.
        sample_drop_ratio: Fraction of the batch to skip.

    Returns:
        ``(indices, scale)``: the selected row indices (a prefix of a random
        permutation, at least one element) and the factor ``b / len(indices)``.
    """
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]
    return kept_idx, batch / keep
150
+
151
+
152
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` (scaled) onto the rows of ``x`` selected by ``brange``.

    Without ``scaling_vector`` both tensors are flattened from dim 1 and combined
    with ``torch.index_add``; the result is returned in that flattened layout.
    With ``scaling_vector`` the work is delegated to ``scaled_index_add``.
    """
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )

    flat_residual = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(x.flatten(1), 0, brange, flat_residual, alpha=residual_scale_factor)
162
+
163
+
164
# Attention-bias objects keyed by the tuple of (batch_size, seq_len) pairs of the inputs.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Concatenate a list of token tensors into one batch-of-one sequence and return
    the matching block-diagonal attention bias (built once per shape signature and
    then served from ``attn_bias_cache``).

    When ``branges`` is given, each tensor is first index-selected with its
    entry of ``branges`` (via ``index_select_cat``).
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]
    cache_key = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))

    if cache_key not in attn_bias_cache:
        # One sequence-length entry per (selected) sample.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        bias._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = bias

    if branges is None:
        as_single_batch = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(as_single_batch, dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
189
+
190
+
191
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Stochastic-depth residual add for a list of tensors processed as one nested batch.

    Args:
        x_list: Tensors of shape ``(b_i, n_i, d)``.
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)`` on
            the concatenated, sub-sampled tokens.
        sample_drop_ratio: Fraction of each tensor's batch to skip.
        scaling_vector: Optional per-channel scale forwarded to ``add_residual``.

    Returns:
        A list with one tensor per input, each shaped like its input.
        (The annotation previously said ``Tensor`` although a list is returned.)
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) scatter each residual back onto the rows it was computed from
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
212
+
213
+
214
class NestedTensorBlock(Block):
    """``Block`` that can additionally process a list of tensors as one nested batch.

    The list path relies on xFormers helpers (``fmha``, ``index_select_cat``,
    ``scaled_index_add``) and is only reachable through ``forward`` when
    ``XFORMERS_AVAILABLE`` is true.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        # Imported here because this module's top-level imports only bring in
        # ``Attention`` from the attention module.
        from .attention import MemEffAttention

        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is applied through ``scaling_vector`` below, so the
            # residual functions here return the unscaled branch outputs.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Guard on ls2 itself (previously checked ls1 by mistake).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: a tensor uses ``Block.forward``, a list uses ``forward_nested``.

        Raises:
            AssertionError: For a list input without xFormers, or for any other input type.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/drop_path.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth) and rescale the rest.

    Args:
        x: Input tensor; the first dimension is treated as the sample dimension.
        drop_prob: Probability of dropping a sample. ``None`` and ``0.0`` both
            mean "never drop" (``DropPath`` defaults to ``None``, which used to
            raise ``TypeError`` in training mode).
        training: Dropping is only active when true.

    Returns:
        ``x`` itself when inactive; otherwise a new tensor where each sample is
        either zeroed or divided by the keep probability.
    """
    if not training or not drop_prob:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2D ConvNet activations).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale survivors so the expected value is unchanged.
        random_tensor.div_(keep_prob)
    return x * random_tensor
14
+
15
+
16
class DropPath(nn.Module):
    """Module wrapper around ``drop_path``: per-sample stochastic depth for residual branches."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Drop probability handed to ``drop_path`` on every forward call.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's probability and training flag."""
        return drop_path(x, self.drop_prob, self.training)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/layer_scale.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch import nn
6
+
7
+
8
class LayerScale(nn.Module):
    """Learnable per-channel scaling: multiplies the input by a vector ``gamma`` of size ``dim``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """
        Args:
            dim: Length of the scaling vector.
            init_values: Initial value(s) for ``gamma``.
            inplace: If true, ``forward`` scales its input in place.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma`` (the same tensor object when ``inplace``)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/mlp.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ from torch import Tensor, nn
4
+
5
+
6
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> dropout -> linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """
        Args:
            in_features: Input width.
            hidden_features: Hidden width; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Factory for the activation module.
            drop: Dropout probability (the same dropout module is used twice).
            bias: Whether both linear layers carry a bias.
        """
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/patch_embed.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple, Union
2
+
3
+ from torch import Tensor
4
+ import torch.nn as nn
5
+
6
+
7
def make_2tuple(x):
    """Return ``x`` as a pair: a 2-tuple is passed through, an int ``n`` becomes ``(n, n)``.

    Any other input (wrong tuple length or non-int scalar) fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
14
+
15
+
16
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If false, ``forward`` returns (B,H',W',D) instead of (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed an image batch.

        The input size need not equal ``img_size``; it only has to be a multiple
        of the patch size in both spatial dimensions.

        Raises:
            AssertionError: If H or W is not a multiple of the patch size.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)  # now the patch-grid size
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate FLOPs for one image at the configured ``img_size``.

        Counts the patch projection, plus one pass over the embeddings when a
        real normalization layer is configured.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # ``self.norm`` is never None (it is nn.Identity when no norm_layer was
        # given), so test for the identity placeholder instead.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/rope.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import Dict, Tuple
6
+
7
+
8
class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    This class efficiently manages the generation of spatial coordinates for patches
    in a 2D grid, caching results to avoid redundant computations.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors for different
            grid dimensions, keyed by ``(height, width)``.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item. The result
            is a fresh copy on ``device``; mutating it does not affect the cache.
        """
        if (height, width) not in self.position_cache:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[height, width] = positions

        # The cache is keyed by grid size only, so the stored tensor may live on
        # the device of an earlier call; ``.to`` is a no-op when it already matches.
        cached_positions = self.position_cache[height, width].to(device)
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
44
+
45
+
46
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. The feature dimension is split in two halves; the first
    half is rotated according to the y coordinate and the second according to x.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Stored as given; not read by any method of this class.
        frequency_cache: Cache for storing precomputed frequency components,
            keyed by ``(dim, seq_len, device, dtype)``.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape ``(seq_len, dim)``.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components
            angles = angles.to(dtype)
            # Duplicate so the table width matches ``dim`` (one angle per rotated pair).
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            ``cat(-x2, x1)`` where ``x1``/``x2`` are the first/second half of the
            last dimension.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up per-token cos/sin rows and add a singleton head axis for broadcasting.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Integer position tensor of shape (batch_size, n_tokens, 2)
                containing the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs. Each spatial half must itself split into rotation
        # pairs, so the full feature dimension has to be a multiple of 4; a
        # merely even dimension leads to mismatched cos/sin table widths.
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components (table sized to the largest coordinate present)
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Callable, Optional
3
+ import warnings
4
+
5
+ from torch import Tensor, nn
6
+ import torch.nn.functional as F
7
+
8
+
9
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: ``w3(silu(a) * b)`` where ``a, b`` are the two halves of ``w12(x)``."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """
        Args:
            in_features: Input width.
            hidden_features: Width of each gate half; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Accepted for signature compatibility; not used.
            drop: Accepted for signature compatibility; not used.
            bias: Whether the linear layers carry a bias.
        """
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        # A single projection produces both the gate and the value halves.
        self.w12 = nn.Linear(in_features, 2 * hidden_width, bias=bias)
        self.w3 = nn.Linear(hidden_width, out_width, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# The optional xFormers SwiGLU import is disabled in this copy of the module,
# so the pure-PyTorch implementation above is always used.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` with the hidden width shrunk to 2/3 and rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        resolved_out = out_features or in_features
        requested_hidden = hidden_features or in_features
        # 2/3 of the requested width, rounded up to the next multiple of 8.
        fused_hidden = (int(requested_hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=fused_hidden,
            out_features=resolved_out,
            bias=bias,
        )
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/layers/vision_transformer.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+ import logging
4
+ from typing import Sequence, Tuple, Union, Callable
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+ from torch.nn.init import trunc_normal_
10
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
11
+
12
+ logger = logging.getLogger("dinov2")
13
+
14
+
15
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on a module tree.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``.
        module: Root of the traversal.
        name: Dotted name prefix for ``module``.
        depth_first: If true, children are visited before their parent;
            otherwise the parent is visited first.
        include_root: Whether ``fn`` is also called on ``module`` itself
            (descendants are always visited).

    Returns:
        ``module``, unchanged, to allow chaining.
    """
    visit_before = include_root and not depth_first
    visit_after = include_root and depth_first

    if visit_before:
        fn(module=module, name=name)

    for child_name, child in module.named_children():
        qualified = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)

    if visit_after:
        fn(module=module, name=name)
    return module
24
+
25
+
26
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called: feeds the input through its members in order."""

    def forward(self, x):
        out = x
        for layer in self:
            out = layer(out)
        return out
31
+
32
+
33
+ class DinoVisionTransformer(nn.Module):
34
+ def __init__(
35
+ self,
36
+ img_size=224,
37
+ patch_size=16,
38
+ in_chans=3,
39
+ embed_dim=768,
40
+ depth=12,
41
+ num_heads=12,
42
+ mlp_ratio=4.0,
43
+ qkv_bias=True,
44
+ ffn_bias=True,
45
+ proj_bias=True,
46
+ drop_path_rate=0.0,
47
+ drop_path_uniform=False,
48
+ init_values=None, # for layerscale: None or 0 => no layerscale
49
+ embed_layer=PatchEmbed,
50
+ act_layer=nn.GELU,
51
+ block_fn=Block,
52
+ ffn_layer="mlp",
53
+ block_chunks=1,
54
+ num_register_tokens=0,
55
+ interpolate_antialias=False,
56
+ interpolate_offset=0.1,
57
+ qk_norm=False,
58
+ ):
59
+ """
60
+ Args:
61
+ img_size (int, tuple): input image size
62
+ patch_size (int, tuple): patch size
63
+ in_chans (int): number of input channels
64
+ embed_dim (int): embedding dimension
65
+ depth (int): depth of transformer
66
+ num_heads (int): number of attention heads
67
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
68
+ qkv_bias (bool): enable bias for qkv if True
69
+ proj_bias (bool): enable bias for proj in attn if True
70
+ ffn_bias (bool): enable bias for ffn if True
71
+ drop_path_rate (float): stochastic depth rate
72
+ drop_path_uniform (bool): apply uniform drop rate across blocks
73
+ weight_init (str): weight init scheme
74
+ init_values (float): layer-scale init values
75
+ embed_layer (nn.Module): patch embedding layer
76
+ act_layer (nn.Module): MLP activation layer
77
+ block_fn (nn.Module): transformer block class
78
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
79
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
80
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
81
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
82
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
83
+ """
84
+ super().__init__()
85
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
86
+
87
+ # tricky but makes it work
88
+ self.use_checkpoint = False
89
+ #
90
+
91
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
92
+ self.num_tokens = 1
93
+ self.n_blocks = depth
94
+ self.num_heads = num_heads
95
+ self.patch_size = patch_size
96
+ self.num_register_tokens = num_register_tokens
97
+ self.interpolate_antialias = interpolate_antialias
98
+ self.interpolate_offset = interpolate_offset
99
+
100
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
101
+ num_patches = self.patch_embed.num_patches
102
+
103
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
104
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
105
+ assert num_register_tokens >= 0
106
+ self.register_tokens = (
107
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
108
+ )
109
+
110
+ if drop_path_uniform is True:
111
+ dpr = [drop_path_rate] * depth
112
+ else:
113
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
114
+
115
+ if ffn_layer == "mlp":
116
+ logger.info("using MLP layer as FFN")
117
+ ffn_layer = Mlp
118
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
119
+ logger.info("using SwiGLU layer as FFN")
120
+ ffn_layer = SwiGLUFFNFused
121
+ elif ffn_layer == "identity":
122
+ logger.info("using Identity layer as FFN")
123
+
124
+ def f(*args, **kwargs):
125
+ return nn.Identity()
126
+
127
+ ffn_layer = f
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ blocks_list = [
132
+ block_fn(
133
+ dim=embed_dim,
134
+ num_heads=num_heads,
135
+ mlp_ratio=mlp_ratio,
136
+ qkv_bias=qkv_bias,
137
+ proj_bias=proj_bias,
138
+ ffn_bias=ffn_bias,
139
+ drop_path=dpr[i],
140
+ norm_layer=norm_layer,
141
+ act_layer=act_layer,
142
+ ffn_layer=ffn_layer,
143
+ init_values=init_values,
144
+ qk_norm=qk_norm,
145
+ )
146
+ for i in range(depth)
147
+ ]
148
+ if block_chunks > 0:
149
+ self.chunked_blocks = True
150
+ chunked_blocks = []
151
+ chunksize = depth // block_chunks
152
+ for i in range(0, depth, chunksize):
153
+ # this is to keep the block index consistent if we chunk the block list
154
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
155
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
156
+ else:
157
+ self.chunked_blocks = False
158
+ self.blocks = nn.ModuleList(blocks_list)
159
+
160
+ self.norm = norm_layer(embed_dim)
161
+ self.head = nn.Identity()
162
+
163
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
164
+
165
+ self.init_weights()
166
+
167
+ def init_weights(self):
168
+ trunc_normal_(self.pos_embed, std=0.02)
169
+ nn.init.normal_(self.cls_token, std=1e-6)
170
+ if self.register_tokens is not None:
171
+ nn.init.normal_(self.register_tokens, std=1e-6)
172
+ named_apply(init_weights_vit_timm, self)
173
+
174
def interpolate_pos_encoding(self, x, w, h):
    """Resize the learned positional embedding to match the current token grid.

    Args:
        x: Token tensor whose second axis holds one class token followed by
            the patch tokens; only its shape and dtype are read.
        w: First spatial size of the input image as unpacked by the caller.
        h: Second spatial size of the input image as unpacked by the caller.

    Returns:
        ``self.pos_embed`` itself when the token count already matches and the
        image is square; otherwise a bicubically resampled embedding cast back
        to ``x``'s dtype.
    """
    input_dtype = x.dtype
    num_patches = x.shape[1] - 1
    num_pretrained = self.pos_embed.shape[1] - 1
    if w == h and num_patches == num_pretrained:
        return self.pos_embed

    # Interpolate in float32, then cast back at the very end.
    embed = self.pos_embed.float()
    cls_embed, grid_embed = embed[:, 0], embed[:, 1:]
    channels = x.shape[-1]
    grid_w = w // self.patch_size
    grid_h = h // self.patch_size

    # The stored patch embeddings are assumed to form a square grid.
    side = int(math.sqrt(num_pretrained))
    assert num_pretrained == side * side

    if self.interpolate_offset:
        # Historical kludge: add a small number to avoid floating point error in the
        # interpolation, see https://github.com/facebookresearch/dino/issues/8
        # Still needed for backward compatibility.
        resize_args = {
            "scale_factor": (
                float(grid_w + self.interpolate_offset) / side,
                float(grid_h + self.interpolate_offset) / side,
            )
        }
    else:
        # Ask for the exact output size instead of a scale factor.
        resize_args = {"size": (grid_w, grid_h)}

    as_image = grid_embed.reshape(1, side, side, channels).permute(0, 3, 1, 2)
    resized = nn.functional.interpolate(
        as_image,
        mode="bicubic",
        antialias=self.interpolate_antialias,
        **resize_args,
    )
    assert (grid_w, grid_h) == resized.shape[-2:]

    flat = resized.permute(0, 2, 3, 1).view(1, -1, channels)
    return torch.cat((cls_embed.unsqueeze(0), flat), dim=1).to(input_dtype)
207
+
208
def prepare_tokens_with_masks(self, x, masks=None):
    """Turn an image batch into the token sequence fed to the transformer blocks.

    Steps: patch-embed the images, optionally overwrite masked patches with
    the mask token, prepend the class token, add the (interpolated)
    positional embedding, and finally insert the register tokens right after
    the class token when the model has any.

    Args:
        x: 4-D image batch; the last two axes are forwarded to
            ``interpolate_pos_encoding``.
        masks: Optional boolean tensor selecting patches to replace with the
            mask token (presumably one flag per patch - confirm with callers).

    Returns:
        The assembled token tensor.
    """
    _, _, w, h = x.shape
    tokens = self.patch_embed(x)

    if masks is not None:
        mask_token = self.mask_token.to(tokens.dtype).unsqueeze(0)
        tokens = torch.where(masks.unsqueeze(-1), mask_token, tokens)

    batch = tokens.shape[0]
    tokens = torch.cat((self.cls_token.expand(batch, -1, -1), tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, w, h)

    if self.register_tokens is None:
        return tokens

    # Register tokens carry no positional embedding: they are spliced in
    # between the class token and the patch tokens after the addition above.
    registers = self.register_tokens.expand(batch, -1, -1)
    return torch.cat((tokens[:, :1], registers, tokens[:, 1:]), dim=1)
228
+
229
def forward_features_list(self, x_list, masks_list):
    """Run the backbone on a list of image batches and return one feature dict per entry.

    Args:
        x_list: List of image batches.
        masks_list: List of masks (or ``None`` entries) paired with ``x_list``.

    Returns:
        A list of dicts with the normalised class token, register tokens and
        patch tokens, the pre-norm tokens and the mask that was used.
    """
    token_list = [
        self.prepare_tokens_with_masks(images, mask)
        for images, mask in zip(x_list, masks_list)
    ]

    # Every block receives the whole list at once (presumably the blocks
    # support list inputs - confirm against the block implementation).
    for blk in self.blocks:
        if not self.use_checkpoint:
            token_list = blk(token_list)
        else:
            token_list = checkpoint(blk, token_list, use_reentrant=self.use_reentrant)

    first_patch = self.num_register_tokens + 1
    results = []
    for tokens, mask in zip(token_list, masks_list):
        normed = self.norm(tokens)
        results.append(
            {
                "x_norm_clstoken": normed[:, 0],
                "x_norm_regtokens": normed[:, 1:first_patch],
                "x_norm_patchtokens": normed[:, first_patch:],
                "x_prenorm": tokens,
                "masks": mask,
            }
        )
    return results
252
+
253
def forward_features(self, x, masks=None):
    """Run the backbone and return normalised and pre-norm token features.

    A list input is delegated to ``forward_features_list``; otherwise the
    tokens are prepared, passed through every block (with activation
    checkpointing when ``self.use_checkpoint`` is set) and normalised.

    Args:
        x: Image batch, or a list of image batches.
        masks: Optional mask (or list of masks when ``x`` is a list).

    Returns:
        A dict with the normalised class token, register tokens and patch
        tokens, the pre-norm tokens and the mask - or a list of such dicts
        for list input.
    """
    if isinstance(x, list):
        return self.forward_features_list(x, masks)

    tokens = self.prepare_tokens_with_masks(x, masks)
    for blk in self.blocks:
        if not self.use_checkpoint:
            tokens = blk(tokens)
        else:
            tokens = checkpoint(blk, tokens, use_reentrant=self.use_reentrant)

    normed = self.norm(tokens)
    first_patch = self.num_register_tokens + 1
    features = {
        "x_norm_clstoken": normed[:, 0],
        "x_norm_regtokens": normed[:, 1:first_patch],
        "x_norm_patchtokens": normed[:, first_patch:],
        "x_prenorm": tokens,
        "masks": masks,
    }
    return features
273
+
274
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
275
+ x = self.prepare_tokens_with_masks(x)
276
+ # If n is an int, take the n last blocks. If it's a list, take them
277
+ output, total_block_len = [], len(self.blocks)
278
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
279
+ for i, blk in enumerate(self.blocks):
280
+ x = blk(x)
281
+ if i in blocks_to_take:
282
+ output.append(x)
283
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
284
+ return output
285
+
286
+ def _get_intermediate_layers_chunked(self, x, n=1):
287
+ x = self.prepare_tokens_with_masks(x)
288
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
289
+ # If n is an int, take the n last blocks. If it's a list, take them
290
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
291
+ for block_chunk in self.blocks:
292
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
293
+ x = blk(x)
294
+ if i in blocks_to_take:
295
+ output.append(x)
296
+ i += 1
297
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
298
+ return output
299
+
300
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return patch-token features from selected transformer blocks.

    Args:
        x: 4-D image batch.
        n: Number of final blocks to take, or explicit block indices.
        reshape: When True, fold the patch tokens back into a channel-first
            feature map using the image size and ``self.patch_size``.
        return_class_token: When True, pair every feature with its class token.
        norm: When True, apply ``self.norm`` to each selected output first.

    Returns:
        A tuple of patch features, or a tuple of ``(patch_features,
        class_token)`` pairs when ``return_class_token`` is set.
    """
    if self.chunked_blocks:
        collect = self._get_intermediate_layers_chunked
    else:
        collect = self._get_intermediate_layers_not_chunked
    features = collect(x, n)

    if norm:
        features = [self.norm(feat) for feat in features]

    # Token layout: [class, registers..., patches...]
    first_patch = 1 + self.num_register_tokens
    class_tokens = [feat[:, 0] for feat in features]
    patches = [feat[:, first_patch:] for feat in features]

    if reshape:
        B, _, w, h = x.shape
        grid_w = w // self.patch_size
        grid_h = h // self.patch_size
        patches = [
            feat.reshape(B, grid_w, grid_h, -1).permute(0, 3, 1, 2).contiguous()
            for feat in patches
        ]

    if return_class_token:
        return tuple(zip(patches, class_tokens))
    return tuple(patches)
325
+
326
def forward(self, *args, is_training=True, **kwargs):
    """Compute backbone features.

    Args:
        *args: Forwarded to ``forward_features``.
        is_training: When True (default) the full feature dict is returned;
            otherwise only ``self.head`` applied to the normalised class token.
        **kwargs: Forwarded to ``forward_features``.

    Returns:
        The feature dict, or the head output for the class token.
    """
    features = self.forward_features(*args, **kwargs)
    if not is_training:
        return self.head(features["x_norm_clstoken"])
    return features
332
+
333
+
334
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Only ``nn.Linear`` modules are touched: the weight gets a truncated
    normal with std 0.02 and the bias, if there is one, is zeroed.

    Args:
        module: Module to initialise in place.
        name: Module name supplied by the caller; not used here.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
340
+
341
+
342
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the small ViT variant: embed dim 384, depth 12, 6 heads, MLP ratio 4.

    Args:
        patch_size: Patch size forwarded to the transformer.
        num_register_tokens: Number of register tokens.
        **kwargs: Extra keyword arguments forwarded to ``DinoVisionTransformer``.

    Returns:
        The constructed ``DinoVisionTransformer``.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
354
+
355
+
356
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the base ViT variant: embed dim 768, depth 12, 12 heads, MLP ratio 4.

    Args:
        patch_size: Patch size forwarded to the transformer.
        num_register_tokens: Number of register tokens.
        **kwargs: Extra keyword arguments forwarded to ``DinoVisionTransformer``.

    Returns:
        The constructed ``DinoVisionTransformer``.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
368
+
369
+
370
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the large ViT variant: embed dim 1024, depth 24, 16 heads, MLP ratio 4.

    Args:
        patch_size: Patch size forwarded to the transformer.
        num_register_tokens: Number of register tokens.
        **kwargs: Extra keyword arguments forwarded to ``DinoVisionTransformer``.

    Returns:
        The constructed ``DinoVisionTransformer``.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
382
+
383
+
384
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the giant ViT variant: embed dim 1536, depth 40, 24 heads, MLP ratio 4.

    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64.

    Args:
        patch_size: Patch size forwarded to the transformer.
        num_register_tokens: Number of register tokens.
        **kwargs: Extra keyword arguments forwarded to ``DinoVisionTransformer``.

    Returns:
        The constructed ``DinoVisionTransformer``.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=attention_block,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/models/aggregator.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from streamvggt.layers import PatchEmbed
14
+ from streamvggt.layers.block import Block
15
+ from streamvggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from streamvggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
21
+ _RESNET_STD = [0.229, 0.224, 0.225]
22
+
23
+
24
+ class Aggregator(nn.Module):
25
+ """
26
+ The Aggregator applies alternating-attention over input frames,
27
+ as described in VGGT: Visual Geometry Grounded Transformer.
28
+
29
+
30
+ Args:
31
+ img_size (int): Image size in pixels.
32
+ patch_size (int): Size of each patch for PatchEmbed.
33
+ embed_dim (int): Dimension of the token embeddings.
34
+ depth (int): Number of blocks.
35
+ num_heads (int): Number of attention heads.
36
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
37
+ num_register_tokens (int): Number of register tokens.
38
+ block_fn (nn.Module): The block type used for attention (Block by default).
39
+ qkv_bias (bool): Whether to include bias in QKV projections.
40
+ proj_bias (bool): Whether to include bias in the output projection.
41
+ ffn_bias (bool): Whether to include bias in MLP layers.
42
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
43
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
44
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
45
+ qk_norm (bool): Whether to apply QK normalization.
46
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
47
+ init_values (float): Init scale for layer scale.
48
+ """
49
+
50
def __init__(
    self,
    img_size=518,
    patch_size=14,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    mlp_ratio=4.0,
    num_register_tokens=4,
    block_fn=Block,
    qkv_bias=True,
    proj_bias=True,
    ffn_bias=True,
    patch_embed="dinov2_vitl14_reg",
    aa_order=None,
    aa_block_size=1,
    qk_norm=True,
    rope_freq=100,
    init_values=0.01,
):
    """Build the patch embedder, both attention block stacks and the special tokens.

    Args:
        img_size: Image size in pixels.
        patch_size: Size of each patch.
        embed_dim: Dimension of the token embeddings.
        depth: Number of frame blocks and of global blocks.
        num_heads: Number of attention heads.
        mlp_ratio: Ratio of MLP hidden dim to embedding dim.
        num_register_tokens: Number of register tokens.
        block_fn: Factory used for every attention block.
        qkv_bias: Whether to include bias in QKV projections.
        proj_bias: Whether to include bias in the output projection.
        ffn_bias: Whether to include bias in MLP layers.
        patch_embed: Patch embed type, e.g. "conv" or "dinov2_vitl14_reg".
        aa_order: Order of alternating attention. ``None`` (default) means
            ``["frame", "global"]``.
        aa_block_size: Blocks grouped under each attention type before switching.
        qk_norm: Whether to apply QK normalization.
        rope_freq: Base frequency for rotary embedding; values <= 0 disable it.
        init_values: Init scale for layer scale.

    Raises:
        ValueError: If ``depth`` is not divisible by ``aa_block_size``.
    """
    super().__init__()

    # A list literal as the default would be shared by every instance, so
    # the default is materialised per call and the caller's sequence is copied.
    if aa_order is None:
        aa_order = ["frame", "global"]

    # Fail fast, before any (potentially large) sub-module is allocated.
    if depth % aa_block_size != 0:
        raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

    self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

    # Initialize rotary position embedding if frequency > 0
    self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
    self.position_getter = PositionGetter() if self.rope is not None else None

    def build_blocks():
        # Frame and global stacks share one configuration.
        return nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

    # Construction order (frame first, then global) is kept so the random
    # initialisation stream matches the previous implementation.
    self.frame_blocks = build_blocks()
    self.global_blocks = build_blocks()

    self.depth = depth
    self.aa_order = list(aa_order)
    self.patch_size = patch_size
    self.aa_block_size = aa_block_size
    self.aa_block_num = self.depth // self.aa_block_size

    # Note: We have two camera tokens, one for the first frame and one for the rest
    # The same applies for register tokens
    self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
    self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

    # The patch tokens start after the camera and register tokens
    self.patch_start_idx = 1 + num_register_tokens

    # Initialize parameters with small values
    nn.init.normal_(self.camera_token, std=1e-6)
    nn.init.normal_(self.register_token, std=1e-6)

    # Register normalization constants as non-persistent buffers, shaped to
    # broadcast over a 5-D image tensor along its third axis.
    for name, value in (
        ("_resnet_mean", _RESNET_MEAN),
        ("_resnet_std", _RESNET_STD),
    ):
        self.register_buffer(
            name,
            torch.tensor(value, dtype=torch.float32).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
145
+
146
+
147
def __build_patch_embed__(
    self,
    patch_embed,
    img_size,
    patch_size,
    num_register_tokens,
    interpolate_antialias=True,
    interpolate_offset=0.0,
    block_chunks=0,
    init_values=1.0,
    embed_dim=1024,
):
    """Create ``self.patch_embed``.

    If the name contains "conv", a plain ``PatchEmbed`` conv layer is used.
    Otherwise the name selects one of the vision-transformer factories; note
    that ``embed_dim`` is only passed on in the conv case.

    Args:
        patch_embed: Patch embed type, e.g. "conv" or "dinov2_vitl14_reg".
        img_size: Image size in pixels.
        patch_size: Size of each patch.
        num_register_tokens: Number of register tokens of the ViT.
        interpolate_antialias: Forwarded to the ViT factory.
        interpolate_offset: Forwarded to the ViT factory.
        block_chunks: Forwarded to the ViT factory.
        init_values: Forwarded to the ViT factory.
        embed_dim: Embedding dimension of the conv patch embed.

    Raises:
        KeyError: If ``patch_embed`` is neither a conv variant nor a known ViT name.
    """
    if "conv" in patch_embed:
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        return

    backbone_factories = {
        "dinov2_vitl14_reg": vit_large,
        "dinov2_vitb14_reg": vit_base,
        "dinov2_vits14_reg": vit_small,
        "dinov2_vitg2_reg": vit_giant2,
    }
    build_backbone = backbone_factories[patch_embed]
    self.patch_embed = build_backbone(
        img_size=img_size,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
        block_chunks=block_chunks,
        init_values=init_values,
    )

    # Disable gradient updates for mask token
    if hasattr(self.patch_embed, "mask_token"):
        self.patch_embed.mask_token.requires_grad_(False)
187
+
188
def forward(
    self,
    images: torch.Tensor,
    past_key_values=None,
    use_cache=False,
    past_frame_idx=0
) -> Union[Tuple[List[torch.Tensor], int], Tuple[List[torch.Tensor], int, List]]:
    """Run alternating frame/global attention over a sequence of frames.

    Args:
        images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
        past_key_values: Per-global-block key/value cache, updated in place
            when ``use_cache`` is True. ``None`` starts a fresh cache.
        use_cache: Whether to run in streaming mode with the KV cache
            (expects one new frame per call, i.e. S == 1).
        past_frame_idx: Forwarded to the global attention step.

    Returns:
        ``(outputs, patch_start_idx)`` where ``outputs`` is the list of
        concatenated frame/global intermediates ([B, S, P, 2C] each) and
        ``patch_start_idx`` marks where patch tokens begin. With
        ``use_cache`` the updated ``past_key_values`` is appended as a
        third element.

    Raises:
        ValueError: If the images do not have 3 channels or ``aa_order``
            contains an unknown attention type.
    """
    B, S, C_in, H, W = images.shape

    if C_in != 3:
        raise ValueError(f"Expected 3 input channels, got {C_in}")

    if use_cache:
        # Previously a missing cache crashed on ``past_key_values[0]``.
        if past_key_values is None:
            past_key_values = [None] * self.depth
        if S > 1:
            logger.warning("Use KV cache expects S=1, got S=%d", S)

    if use_cache and past_key_values[0] is not None:
        # The third axis of the cached keys is read as the number of
        # frames seen so far; the current frame adds one.
        S_true = past_key_values[0][0].shape[2] + 1
    else:
        S_true = S

    # Normalize images and reshape to [B*S, C, H, W] for patch embedding
    images = (images - self._resnet_mean.to(images.device)) / self._resnet_std.to(images.device)
    images = images.reshape(B * S, C_in, H, W)
    patch_tokens = self.patch_embed(images)
    if isinstance(patch_tokens, dict):
        patch_tokens = patch_tokens["x_norm_patchtokens"]

    if use_cache:
        # Build the tokens for the whole sequence seen so far and keep only
        # the current (last) frame of every batch element. For B == 1 this is
        # identical to taking the last row of the flattened tensor.
        camera_full = slice_expand_and_flatten(self.camera_token, B, S_true)
        camera_token = camera_full.reshape(B, S_true, *camera_full.shape[1:])[:, -1]
        register_full = slice_expand_and_flatten(self.register_token, B, S_true)
        register_token = register_full.reshape(B, S_true, *register_full.shape[1:])[:, -1]
    else:
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)

    # Concatenate special tokens with patch tokens
    tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

    pos = None
    if self.rope is not None:
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = torch.zeros(
                B * S, self.patch_start_idx, 2, device=images.device, dtype=pos.dtype
            )
            pos = torch.cat([pos_special, pos], dim=1)

    # P now counts special tokens as well
    _, P, C = tokens.shape

    frame_idx = 0
    global_idx = 0
    output_list = []

    for _ in range(self.aa_block_num):
        for attn_type in self.aa_order:
            if attn_type == "frame":
                tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                    tokens, B, S, P, C, frame_idx, pos=pos
                )
            elif attn_type == "global":
                if use_cache:
                    tokens, global_idx, global_intermediates, new_kv = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos,
                        past_key_values_block=past_key_values[global_idx],
                        use_cache=True,
                        past_frame_idx=past_frame_idx
                    )
                    # global_idx has already advanced past the processed block.
                    past_key_values[global_idx - 1] = new_kv
                else:
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
            else:
                raise ValueError(f"Unknown attention type: {attn_type}")

        # concat frame and global intermediates, [B x S x P x 2C]
        output_list.extend(
            torch.cat([frame_inter, global_inter], dim=-1)
            for frame_inter, global_inter in zip(frame_intermediates, global_intermediates)
        )

    if use_cache:
        return output_list, self.patch_start_idx, past_key_values
    return output_list, self.patch_start_idx
295
+
296
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
297
+ """
298
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
299
+ """
300
+ # If needed, reshape tokens or positions:
301
+ if tokens.shape != (B * S, P, C):
302
+ tokens = tokens.reshape(B, S, P, C).reshape(B * S, P, C)
303
+
304
+ if pos is not None and pos.shape != (B * S, P, 2):
305
+ pos = pos.reshape(B, S, P, 2).reshape(B * S, P, 2)
306
+
307
+ intermediates = []
308
+
309
+ # by default, self.aa_block_size=1, which processes one block at a time
310
+ for _ in range(self.aa_block_size):
311
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
312
+ frame_idx += 1
313
+ intermediates.append(tokens.reshape(B, S, P, C))
314
+
315
+ return tokens, frame_idx, intermediates
316
+
317
+ def _process_global_attention(
318
+ self,
319
+ tokens,
320
+ B,
321
+ S,
322
+ P,
323
+ C,
324
+ global_idx,
325
+ pos=None,
326
+ past_key_values_block=None,
327
+ use_cache=False,
328
+ past_frame_idx=0
329
+ ) -> Union[Tuple[torch.Tensor, int, List[torch.Tensor]], Tuple[torch.Tensor, int, List[torch.Tensor], List]]:
330
+ """
331
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
332
+ """
333
+
334
+ if tokens.shape != (B, S * P, C):
335
+ tokens = tokens.reshape(B, S, P, C).reshape(B, S * P, C)
336
+
337
+ if pos is not None and pos.shape != (B, S * P, 2):
338
+ pos = pos.reshape(B, S, P, 2).reshape(B, S * P, 2)
339
+
340
+ intermediates = []
341
+
342
+ for _ in range(self.aa_block_size):
343
+ if not use_cache:
344
+ L = S * P
345
+ frame_ids = torch.arange(L, device=tokens.device) // P # [0,0,...,1,1,...,S-1]
346
+ future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
347
+ attn_mask = future_frame.to(tokens.dtype) * torch.finfo(tokens.dtype).min
348
+ else:
349
+ attn_mask = None
350
+
351
+ if use_cache:
352
+ tokens, block_kv = self.global_blocks[global_idx](
353
+ tokens,
354
+ pos=pos,
355
+ attn_mask=attn_mask,
356
+ past_key_values=past_key_values_block,
357
+ use_cache=True
358
+ )
359
+ else:
360
+ tokens = self.global_blocks[global_idx](tokens, pos=pos, attn_mask=attn_mask)
361
+ global_idx += 1
362
+ intermediates.append(tokens.reshape(B, S, P, C))
363
+
364
+ # if self.use_causal_global:
365
+ # del attn_mask
366
+ if use_cache:
367
+ return tokens, global_idx, intermediates, block_kv
368
+ return tokens, global_idx, intermediates
369
+
370
+
371
def slice_expand_and_flatten(token_tensor, B, S):
    """Lay out a pair of learned tokens over B sequences of S frames.

    ``token_tensor`` has shape (1, 2, X, C). Entry 0 along the second axis is
    used for the first frame of every sequence, entry 1 for the remaining
    S-1 frames. The result is flattened over batch and frame.

    Args:
        token_tensor: Tensor of shape (1, 2, X, C).
        B: Batch size.
        S: Number of frames per sequence.

    Returns:
        torch.Tensor: Tokens with shape (B*S, X, C).
    """
    trailing = token_tensor.shape[2:]

    # First-frame token => (B, 1, X, C); token for all later frames => (B, S-1, X, C)
    first_frame = token_tensor[:, :1].expand(B, 1, *trailing)
    later_frames = token_tensor[:, 1:].expand(B, S - 1, *trailing)

    per_frame = torch.cat((first_frame, later_frames), dim=1)  # (B, S, X, C)
    return per_frame.reshape(B * S, *trailing)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/models/streamvggt.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
4
+
5
+ from streamvggt.models.aggregator import Aggregator
6
+ from streamvggt.heads.camera_head import CameraHead
7
+ from streamvggt.heads.dpt_head import DPTHead
8
+ from streamvggt.heads.track_head import TrackHead
9
+ from transformers.file_utils import ModelOutput
10
+ from typing import Optional, Tuple, List, Any
11
+ from dataclasses import dataclass
12
+ import pdb
13
+
14
@dataclass
class StreamVGGTOutput(ModelOutput):
    """Container for the predictions returned by the StreamVGGT entry points."""

    # One prediction dict per frame/view (point map, confidences, depth,
    # camera pose and optional mask/track entries).
    ress: Optional[List[dict]] = None
    # The views the predictions belong to.
    # NOTE(review): annotated as a Tensor, but the call sites in this file
    # pass the list of per-view dicts - confirm the intended type.
    views: Optional[torch.Tensor] = None
18
+
19
+ class StreamVGGT(nn.Module, PyTorchModelHubMixin):
20
def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
    """Create the aggregator and the camera, point, depth and track heads.

    Args:
        img_size: Image size in pixels, forwarded to the aggregator.
        patch_size: Patch size, forwarded to the aggregator and the track head.
        embed_dim: Token embedding dimension of the aggregator. Every head
            takes ``2 * embed_dim`` input features because the aggregator
            concatenates frame and global tokens.
    """
    super().__init__()

    head_dim = 2 * embed_dim

    self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
    self.camera_head = CameraHead(dim_in=head_dim)
    self.point_head = DPTHead(
        dim_in=head_dim, output_dim=4, activation="inv_log", conf_activation="expp1"
    )
    self.depth_head = DPTHead(
        dim_in=head_dim, output_dim=2, activation="exp", conf_activation="expp1"
    )
    self.track_head = TrackHead(dim_in=head_dim, patch_size=patch_size)
28
+
29
+
30
+
31
def forward(
    self,
    views,
    query_points: torch.Tensor = None,
    history_info: Optional[dict] = None,
    past_key_values=None,
    use_cache=False,
    past_frame_idx=0
):
    """Predict camera pose, depth, point maps and optional tracks for all views.

    Args:
        views: Sequence of per-view dicts. Each must hold an ``"img"`` tensor,
            either batched (B, C, H, W) or unbatched (C, H, W), and may hold
            a ``"valid_mask"`` that is copied into the result.
        query_points: Optional points to track; a 2-D tensor gets a batch
            axis added.
        history_info: Accepted for interface compatibility; not used.
        past_key_values: Accepted for interface compatibility; not used.
        use_cache: Accepted for interface compatibility; not used.
        past_frame_idx: Accepted for interface compatibility; not used.

    Returns:
        StreamVGGTOutput: ``ress`` holds one prediction dict per view and
        ``views`` the input views.
    """
    images = torch.stack([view["img"] for view in views], dim=0)
    if images.dim() == 4:
        # Unbatched views stack to (S, C, H, W): add the batch axis in front.
        # (The former unconditional 5-axis permute failed on this input
        # before the batch-dimension check could ever run.)
        images = images.unsqueeze(0)
    else:
        # Batched views stack to (S, B, C, H, W): move batch to the front.
        images = images.permute(1, 0, 2, 3, 4)  # B S C H W

    if query_points is not None and len(query_points.shape) == 2:
        query_points = query_points.unsqueeze(0)

    aggregated_tokens_list, patch_start_idx = self.aggregator(images)
    predictions = {}

    with torch.cuda.amp.autocast(enabled=False):
        if self.camera_head is not None:
            pose_enc_list = self.camera_head(aggregated_tokens_list)
            predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )
            predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
            )
            predictions["world_points"] = pts3d
            predictions["world_points_conf"] = pts3d_conf

        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
            )
            predictions["track"] = track_list[-1]  # track of the last iteration
            predictions["vis"] = vis
            predictions["conf"] = conf
        predictions["images"] = images

    # Split every prediction along the frame axis into one dict per view.
    num_frames = images.shape[1]
    ress = []
    for s in range(num_frames):
        res = {
            'pts3d_in_other_view': predictions['world_points'][:, s],  # [B, H, W, 3]
            'conf': predictions['world_points_conf'][:, s],  # [B, H, W]
            'depth': predictions['depth'][:, s],  # [B, H, W, 1]
            'depth_conf': predictions['depth_conf'][:, s],  # [B, H, W]
            'camera_pose': predictions['pose_enc'][:, s, :],  # [B, 9]
        }
        if 'valid_mask' in views[s]:
            res['valid_mask'] = views[s]["valid_mask"]  # [B, H, W]
        if 'track' in predictions:
            res['track'] = predictions['track'][:, s]  # [B, N, 2]
            res['vis'] = predictions['vis'][:, s]  # [B, N]
            res['track_conf'] = predictions['conf'][:, s]
        ress.append(res)
    return StreamVGGTOutput(ress=ress, views=views)  # [S] [B, C, H, W]
105
+
106
def frontendT(self, frame):
    """Feed one frame through the aggregator in streaming (KV-cache) mode.

    The first call creates the streaming state on the instance
    (``frontend_images_size``, ``frontend_past_key_values``,
    ``frontend_kid``); every later call advances the frame counter and
    reuses the cache.

    Args:
        frame: A single image tensor (presumably (C, H, W) - two leading
            axes are added before it is handed to the aggregator).

    Returns:
        List of detached aggregator token tensors for this frame.
    """
    images = frame[None, None]  # 1,1,C,H,W
    B, _, C, H, W = images.shape

    with torch.no_grad():
        if hasattr(self, "frontend_past_key_values"):
            self.frontend_kid += 1
        else:
            # First frame: set up the streaming state.
            self.frontend_images_size = (B, C, H, W)
            self.frontend_past_key_values = [None] * self.aggregator.depth
            self.frontend_kid = 0

        tokens, _, self.frontend_past_key_values = self.aggregator(
            images,
            past_key_values=self.frontend_past_key_values,
            use_cache=True,
            past_frame_idx=self.frontend_kid
        )
        detached = [tok.detach() for tok in tokens]

    return detached
129
+
130
+
131
def extract(self, map_tokens, query_points=None):
    """Decode stored aggregator tokens into pose, depth and point-map predictions.

    Requires ``frontendT`` to have run at least once, because the image
    size is read from ``self.frontend_images_size``.

    Args:
        map_tokens: List of aggregator token tensors; the second axis of the
            first entry is taken as the number of frames S.
        query_points: Optional points; when given (and a track head exists)
            the track head is run, but its result is not returned.

    Returns:
        dict with keys ``pts3d_in_other_view``, ``conf``, ``depth``,
        ``depth_conf`` and ``camera_pose``. An entry is ``None`` when the
        corresponding head is disabled.
    """
    B, C, H, W = self.frontend_images_size
    S = map_tokens[0].shape[1]
    # Placeholder images of the right shape for the heads. They are created
    # on the tokens' device instead of a hard-coded 'cuda', so CPU and
    # multi-GPU setups work as well.
    images = torch.zeros((B, S, C, H, W), device=map_tokens[0].device)

    patch_start_idx = self.aggregator.patch_start_idx

    # Defaults so that a disabled head yields None instead of a NameError.
    camera_pose = depth = depth_conf = pts3d = pts3d_conf = None

    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                camera_pose = self.camera_head(map_tokens)[-1]  # 1,S,9

            if self.depth_head is not None:
                # 1,S,H,W,1 and 1,S,H,W
                depth, depth_conf = self.depth_head(
                    map_tokens, images=images, patch_start_idx=patch_start_idx
                )

            if self.point_head is not None:
                # 1,S,H,W,3 and 1,S,H,W
                pts3d, pts3d_conf = self.point_head(
                    map_tokens, images=images, patch_start_idx=patch_start_idx
                )

            if self.track_head is not None and query_points is not None:
                # NOTE(review): the tracking result has never been part of
                # the returned dict; the call is kept so behaviour is
                # unchanged - confirm whether it can be dropped or returned.
                self.track_head(
                    map_tokens, images=images, patch_start_idx=patch_start_idx, query_points=query_points
                )

    return {
        'pts3d_in_other_view': pts3d,
        'conf': pts3d_conf,
        'depth': depth,
        'depth_conf': depth_conf,
        'camera_pose': camera_pose,
    }
179
+
180
+
181
def inference(self, frames, query_points: torch.Tensor = None, past_key_values=None):
    """Process frames one at a time with KV caches and collect per-frame predictions.

    Args:
        frames: Iterable of per-frame dicts with an ``"img"`` tensor and an
            optional ``"valid_mask"``.
        query_points: Optional points to track. After each frame they are
            replaced by the track predicted for that frame.
        past_key_values: Ignored - a fresh aggregator cache is always created.

    Returns:
        StreamVGGTOutput: ``ress`` holds one prediction dict per frame and
        ``views`` the processed frames.
    """
    # Fresh caches for the aggregator and the camera head.
    past_key_values = [None] * self.aggregator.depth
    camera_cache = [None] * self.camera_head.trunk_depth

    results = []
    seen_frames = []

    for frame_index, frame in enumerate(frames):
        images = frame["img"].unsqueeze(0)
        agg_out = self.aggregator(
            images,
            past_key_values=past_key_values,
            use_cache=True,
            past_frame_idx=frame_index
        )

        # The aggregator returns the updated cache as a third element in cache mode.
        if isinstance(agg_out, tuple) and len(agg_out) == 3:
            tokens, patch_start_idx, past_key_values = agg_out
        else:
            tokens, patch_start_idx = agg_out

        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc, camera_cache = self.camera_head(
                    tokens, past_key_values_camera=camera_cache, use_cache=True
                )
                # Last refinement iteration, first (only) frame.
                camera_pose = pose_enc[-1][:, 0, :]

            if self.depth_head is not None:
                depth_all, depth_conf_all = self.depth_head(
                    tokens, images=images, patch_start_idx=patch_start_idx
                )
                depth = depth_all[:, 0]
                depth_conf = depth_conf_all[:, 0]

            if self.point_head is not None:
                pts3d_all, pts3d_conf_all = self.point_head(
                    tokens, images=images, patch_start_idx=patch_start_idx
                )
                pts3d = pts3d_all[:, 0]
                pts3d_conf = pts3d_conf_all[:, 0]

            if self.track_head is not None and query_points is not None:
                track_list, vis_all, conf_all = self.track_head(
                    tokens, images=images, patch_start_idx=patch_start_idx, query_points=query_points
                )
                track = track_list[-1][:, 0]
                # The predicted track becomes the query for the next frame.
                query_points = track
                vis = vis_all[:, 0]
                track_conf = conf_all[:, 0]

        frame_result = {
            'pts3d_in_other_view': pts3d,
            'conf': pts3d_conf,
            'depth': depth,
            'depth_conf': depth_conf,
            'camera_pose': camera_pose,
        }
        if 'valid_mask' in frame:
            frame_result['valid_mask'] = frame["valid_mask"]
        if query_points is not None:
            frame_result['track'] = track
            frame_result['vis'] = vis
            frame_result['track_conf'] = track_conf

        results.append(frame_result)
        seen_frames.append(frame)

    return StreamVGGTOutput(ress=results, views=seen_frames)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/geometry.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Torch tensors are accepted for any argument and are converted to NumPy
    arrays (detached and moved to the CPU) before processing.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """

    def _as_numpy(value):
        # detach() so tensors that still track gradients can be converted.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    depth_map = _as_numpy(depth_map)
    extrinsics_cam = _as_numpy(extrinsics_cam)
    intrinsics_cam = _as_numpy(intrinsics_cam)

    # Drop the trailing channel axis only when it is really there. An
    # unconditional squeeze(-1) on a NumPy array raises ValueError for the
    # documented (S, H, W) layout because the last axis (W) is not of size 1.
    if depth_map.ndim == 4 and depth_map.shape[-1] == 1:
        depth_map = depth_map[..., 0]

    world_points_list = []
    for frame_idx, frame_depth in enumerate(depth_map):
        cur_world_points, _, _ = depth_to_world_coords_points(
            frame_depth, extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
        )
        world_points_list.append(cur_world_points)

    return np.stack(world_points_list, axis=0)
42
+
43
+
44
def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Lift a single depth map into world-space 3D points.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4),
            OpenCV camera coordinate convention, cam from world.
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        eps: Depth values must be strictly greater than this to be marked valid.

    Returns:
        tuple: (world_points, cam_points, valid_mask) where world_points and
        cam_points are (H, W, 3) and valid_mask is a boolean (H, W) array.
        All three entries are None when depth_map is None.
    """
    if depth_map is None:
        return None, None, None

    # Pixels whose depth exceeds eps are considered valid.
    valid_mask = depth_map > eps

    # Back-project every pixel into the camera frame.
    points_in_cam = depth_to_cam_coords_points(depth_map, intrinsic)

    # closed_form_inverse_se3 is batched, so add a leading axis and strip it
    # again; the result is the 4x4 camera-to-world transform.
    world_from_cam = closed_form_inverse_se3(extrinsic[None])[0]
    rotation = world_from_cam[:3, :3]
    translation = world_from_cam[:3, 3]

    # (H, W, 3) . (3, 3)^T + (3,) -> (H, W, 3)
    points_in_world = np.dot(points_in_cam, rotation.T) + translation

    return points_in_world, points_in_cam, valid_mask
82
+
83
+
84
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
    """
    Convert a depth map to camera coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3) with zero skew.

    Returns:
        np.ndarray: Camera coordinates of shape (H, W, 3), dtype float32. The
        last axis holds (x, y, z) with z equal to the input depth.
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters: focal lengths and principal point.
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Pixel grid: u varies along the width, v along the height.
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Pinhole back-projection.
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
115
+
116
+
117
def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    The inverse of [R | t] is [R^T | -R^T t], so no general matrix inversion
    is needed.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: (..., 4, 4) or (..., 3, 4) array or tensor of SE3 matrices. Any
            number of leading batch dimensions (including none) is accepted.
        R (optional): (..., 3, 3) array or tensor of rotation matrices.
        T (optional): (..., 3, 1) array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices of shape (..., 4, 4). A NumPy input yields a
        float64 NumPy array; a tensor input yields a tensor with the dtype and
        device of `R`.

    Raises:
        ValueError: If the trailing two dimensions of `se3` are neither
            (4, 4) nor (3, 4).
    """
    # The NumPy and torch code paths are chosen from the type of se3.
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if tuple(se3.shape[-2:]) not in ((4, 4), (3, 4)):
        raise ValueError(f"se3 must be of shape (...,4,4) or (...,3,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[..., :3, :3]  # (..., 3, 3)
    if T is None:
        T = se3[..., :3, 3:]  # (..., 3, 1)

    batch_shape = tuple(R.shape[:-2])

    if is_numpy:
        R_transposed = np.swapaxes(R, -1, -2)
        top_right = -np.matmul(R_transposed, T)  # -R^T t
        inverted_matrix = np.tile(np.eye(4), batch_shape + (1, 1))
    else:
        R_transposed = R.transpose(-1, -2)
        top_right = -torch.matmul(R_transposed, T)  # -R^T t
        inverted_matrix = torch.eye(4, 4).repeat(*batch_shape, 1, 1)
        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)

    # Fill rotation and translation; the bottom row stays [0, 0, 0, 1].
    inverted_matrix[..., :3, :3] = R_transposed
    inverted_matrix[..., :3, 3:] = top_right

    return inverted_matrix
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/load_fn.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
def _pad_to_size(img, target_height, target_width, value=1.0):
    """Centre-pad a (C, H, W) tensor with `value` up to the requested height/width."""
    h_padding = target_height - img.shape[1]
    w_padding = target_width - img.shape[2]

    if h_padding > 0 or w_padding > 0:
        pad_top = h_padding // 2
        pad_bottom = h_padding - pad_top
        pad_left = w_padding // 2
        pad_right = w_padding - pad_left
        img = torch.nn.functional.pad(
            img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=value
        )
    return img


def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=518px while maintaining aspect ratio
          and height is center-cropped if larger than 518px
        - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
          and the smaller dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements;
          the resized side is never allowed to collapse below 14px
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518
    patch = 14  # spatial sizes are snapped to multiples of this

    # First process all images and collect their shapes
    for image_path in image_path_list:
        # The context manager closes the file handle once the pixels have been
        # decoded; convert() always returns an independent in-memory image.
        with Image.open(image_path) as opened:
            if opened.mode == "RGBA":
                # Blend onto a white background so transparent areas become white.
                background = Image.new("RGBA", opened.size, (255, 255, 255, 255))
                img = Image.alpha_composite(background, opened)
            else:
                img = opened
            img = img.convert("RGB")

        width, height = img.size

        if mode == "pad" and height > width:
            # Portrait image in pad mode: the height is the largest dimension.
            new_height = target_size
            new_width = max(patch, round(width * (new_height / height) / patch) * patch)
        else:
            # Crop mode, or landscape/square image in pad mode: fix the width.
            new_width = target_size
            # max() guards against extreme aspect ratios rounding the side to 0,
            # which would make resize() fail.
            new_height = max(patch, round(height * (new_width / width) / patch) * patch)

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y: start_y + target_size, :]

        # For pad mode, pad with white to a square of target_size x target_size
        if mode == "pad":
            img = _pad_to_size(img, target_size, target_size)

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Pad everything (with white) up to the largest height and width seen
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)
        images = [_pad_to_size(img, max_height, max_width) for img in images]

    # torch.stack always adds the batch axis, so a single image is already (1, C, H, W).
    return torch.stack(images)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/pose_enc.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from .rotation import quat_to_mat, mat_to_quat
9
+
10
+
11
def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
):
    """Pack camera extrinsics and intrinsics into a 9-D pose encoding.

    Args:
        extrinsics (torch.Tensor): Camera-from-world extrinsics of shape BxSx3x4
            in the [R|t] layout (OpenCV convention: x-right, y-down, z-forward),
            where B is batch size and S is sequence length.
        intrinsics (torch.Tensor): Pixel-unit intrinsics of shape BxSx3x3:
            [[fx, 0, cx],
             [0, fy, cy],
             [0,  0,  1]]
        image_size_hw (tuple): (height, width) of the image in pixels, needed
            to turn focal lengths into fields of view. For example: (256, 512).
        pose_encoding_type (str): Only "absT_quaR_FoV" is supported
            (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: float32 tensor of shape BxSx9 laid out as
            - [:3]  = absolute translation vector T
            - [3:7] = rotation as quaternion
            - [7:]  = field of view (height first, then width)

    Raises:
        NotImplementedError: For any other `pose_encoding_type`.
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError

    rotation = extrinsics[:, :, :3, :3]  # BxSx3x3
    translation = extrinsics[:, :, :3, 3]  # BxSx3
    quaternion = mat_to_quat(rotation)

    # Note the order: height first, then width.
    height, width = image_size_hw
    focal_y = intrinsics[..., 1, 1]
    focal_x = intrinsics[..., 0, 0]
    fov_h = 2 * torch.atan((height / 2) / focal_y)
    fov_w = 2 * torch.atan((width / 2) / focal_x)

    parts = [translation, quaternion, fov_h[..., None], fov_w[..., None]]
    return torch.cat(parts, dim=-1).float()
63
+
64
+
65
def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
    build_intrinsics=True,
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    This function performs the inverse operation of extri_intri_to_pose_encoding,
    reconstructing the full camera parameters from the compact encoding.

    Args:
        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape
            (..., 9), typically BxSx9 where B is batch size and S is sequence length.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D), height first then width
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required when build_intrinsics is True. For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding used. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
            If False, only extrinsics are returned and intrinsics will be None.

    Returns:
        tuple: (extrinsics, intrinsics)
            - extrinsics (torch.Tensor): Shape (..., 3, 4), format [R|t], camera
              from world transformation in the OpenCV coordinate system.
            - intrinsics (torch.Tensor or None): Shape (..., 3, 3) in pixels, or
              None if build_intrinsics is False. Format:
                [[fx, 0, cx],
                 [0, fy, cy],
                 [0,  0,  1]]
              with the principal point assumed at the image center (W/2, H/2).
              Allocated with torch's default dtype on the device of `pose_encoding`.

    Raises:
        NotImplementedError: For any other `pose_encoding_type`.
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError

    T = pose_encoding[..., :3]
    quat = pose_encoding[..., 3:7]
    fov_h = pose_encoding[..., 7]
    fov_w = pose_encoding[..., 8]

    R = quat_to_mat(quat)
    extrinsics = torch.cat([R, T[..., None]], dim=-1)

    intrinsics = None
    if build_intrinsics:
        H, W = image_size_hw
        fy = (H / 2.0) / torch.tan(fov_h / 2.0)
        fx = (W / 2.0) / torch.tan(fov_w / 2.0)
        # Use every leading dimension (not just the first two) so the batch
        # shape matches fx/fy for inputs that are not exactly BxSx9.
        intrinsics = torch.zeros(pose_encoding.shape[:-1] + (3, 3), device=pose_encoding.device)
        intrinsics[..., 0, 0] = fx
        intrinsics[..., 1, 1] = fy
        intrinsics[..., 0, 2] = W / 2
        intrinsics[..., 1, 2] = H / 2
        intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1

    return extrinsics, intrinsics
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/rotation.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import torch
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+
14
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn scalar-last quaternions (XYZW / ijkr order) into rotation matrices.

    The quaternion does not have to be unit length: the squared norm is
    divided out through the `scale` factor.

    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    x, y, z, w = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    row0 = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
    )
    row1 = (
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
    )
    row2 = (
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )

    # Nine entries in row-major order, then fold into (..., 3, 3).
    flat = torch.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Turn rotation matrices into scalar-last quaternions (XYZW / ijkr order).

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4),
        passed through standardize_quaternion before being returned.

    Raises:
        ValueError: If the trailing two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    flat = matrix.reshape(batch_dim + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(flat, dim=-1)

    # Candidate magnitudes for the r, i, j, k components, in that order.
    diag_terms = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    q_abs = _sqrt_positive_part(diag_terms)

    # Each row is the desired quaternion multiplied by one of r, i, j, k.
    scaled_by_r = torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1)
    scaled_by_i = torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1)
    scaled_by_j = torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1)
    scaled_by_k = torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1)
    quat_by_rijk = torch.stack([scaled_by_r, scaled_by_i, scaled_by_j, scaled_by_k], dim=-2)

    # The 0.1 floor only keeps the division finite; a candidate with a small
    # q_abs is never the one selected below.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor))

    # All candidates agree up to sign in exact arithmetic; keep the
    # best-conditioned one, i.e. the one with the largest denominator.
    best = F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5
    rijk = quat_candidates[best, :].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr.
    ijkr = rijk[..., [1, 2, 3, 0]]

    return standardize_quaternion(ijkr)
110
+
111
+
112
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Compute sqrt(max(0, x)) element-wise, with a zero subgradient where x is 0.
    """
    is_positive = x > 0
    if not torch.is_grad_enabled():
        return torch.where(is_positive, torch.sqrt(x), torch.zeros_like(x))

    # With autograd on, only positive entries go through sqrt, so
    # non-positive inputs contribute a zero gradient.
    out = torch.zeros_like(x)
    out[is_positive] = torch.sqrt(x[is_positive])
    return out
124
+
125
+
126
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Flip quaternions so that the real (last) component is non-negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # Keep the last axis (3:4, not 3) so the mask broadcasts over all four components.
    real_part = quaternions[..., 3:4]
    return torch.where(real_part < 0, -quaternions, quaternions)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/streamvggt/utils/visual_track.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+
12
+
13
def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y to [0,1].
    2) Combine them into a single scalar c in [0,1].
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).

    Args:
        x, y: Pixel position.
        W, H: Image width and height used for normalization.
        cmap_name: Name of a registered matplotlib colormap.

    Returns:
        Tuple (r, g, b) with components in [0, 1], RGB order.
    """
    import matplotlib

    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; prefer the
    # colormap registry and fall back to the legacy accessor on old versions.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        cmap = registry.get_cmap(cmap_name)
    else:
        import matplotlib.cm

        cmap = matplotlib.cm.get_cmap(cmap_name)

    # cmap(c) -> (r,g,b,a) in [0,1]
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # in [0,1], RGB order
35
+
36
+
37
def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
    """
    Assign one RGB color per track from where the track first becomes visible.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
        image_width, image_height: used for normalizing (x, y).
        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        track_colors: np.ndarray of shape (N, 3), dtype uint8, each row is
        (R,G,B) in [0,255]. Tracks that are never visible stay black.
    """
    num_frames, num_tracks, _ = tracks_b.shape
    track_colors = np.zeros((num_tracks, 3), dtype=np.uint8)

    if vis_mask_b is None:
        # No mask supplied: every track counts as visible in every frame.
        vis_mask_b = torch.ones(num_frames, num_tracks, dtype=torch.bool, device=tracks_b.device)

    for track_idx in range(num_tracks):
        frames_seen = torch.where(vis_mask_b[:, track_idx])[0]
        if len(frames_seen) > 0:
            # Color comes from the (x, y) position in the first visible frame.
            first_frame = int(frames_seen[0].item())
            x, y = tracks_b[first_frame, track_idx].tolist()
            red, green, blue = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
            # Scale from [0,1] to [0,255].
            track_colors[track_idx] = (int(red * 255), int(green * 255), int(blue * 255))
        else:
            track_colors[track_idx] = (0, 0, 0)

    return track_colors
78
+
79
+
80
def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # New parameter for grid layout
    save_grid=True,  # Flag to control whether to save the grid image
):
    """
    Draw tracks on every frame and save the results as PNG files.

    Each track's color is determined by its (x,y) position in the first
    visible frame (or frame 0 if always visible). Every frame is written as
    `frame_XXXX.png`; optionally all frames are also tiled into
    `tracks_grid.png` with `frames_per_row` frames per row.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations (created if missing).
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).

    Notes:
        A leading batch axis of size 1 is squeezed when `tracks` is 4-D.
        The matplotlib backend is switched to "Agg" as a side effect.
    """
    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Move to CPU; detach() so tensors that track gradients can become NumPy
    # arrays. Nothing below mutates these tensors.
    images = images.detach().cpu()
    tracks = tracks.detach().cpu()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.detach().cpu()

    # Infer H, W from images shape
    if image_format == "CHW":
        H, W = images.shape[2], images.shape[3]
    else:
        H, W = images.shape[1], images.shape[2]

    # Pre-compute the color for each track based on first visible position
    track_colors_rgb = get_track_colors_by_position(
        tracks,  # shape (S, N, 2)
        vis_mask_b=track_vis_mask,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    # Frames stay in BGR from here on: that is what OpenCV draws on and what
    # cv2.imwrite expects, so no RGB round trips are needed.
    frame_images_bgr = []

    for s in range(S):
        img = images[s]

        # Convert to (H, W, 3)
        if image_format == "CHW":
            img = img.permute(1, 2, 0)

        img = img.numpy().astype(np.float32)

        # Scale to [0,255] if needed
        if normalize_mode == "[0,1]":
            img = np.clip(img, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            img = (img + 1.0) * 0.5 * 255.0
            img = np.clip(img, 0, 255.0)
        # else no normalization

        img = img.astype(np.uint8)

        # For drawing in OpenCV, convert to BGR
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Draw each visible track
        if track_vis_mask is not None:
            valid_indices = torch.where(track_vis_mask[s])[0]
        else:
            valid_indices = range(N)

        cur_tracks_np = tracks[s].numpy()  # shape (N, 2)
        for i in valid_indices:
            x, y = cur_tracks_np[i]
            pt = (int(round(x)), int(round(y)))

            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
            R, G, B = track_colors_rgb[i]
            color_bgr = (int(B), int(G), int(R))
            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

        # Save individual frame
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        cv2.imwrite(frame_path, img_bgr)

        frame_images_bgr.append(img_bgr)

    # Only build the grid when requested and when there is at least one frame;
    # with no frames there is nothing to tile.
    if save_grid and S > 0:
        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division

        grid_rows = []
        for row in range(num_rows):
            start_idx = row * frames_per_row
            end_idx = min(start_idx + frames_per_row, S)

            # Concatenate this row horizontally
            row_img = np.concatenate(frame_images_bgr[start_idx:end_idx], axis=1)

            # If this row has fewer than frames_per_row images, pad with black
            missing = frames_per_row - (end_idx - start_idx)
            if missing > 0:
                padding = np.zeros((H, missing * W, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            grid_rows.append(row_img)

        grid_img_bgr = np.concatenate(grid_rows, axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        cv2.imwrite(out_path, grid_img_bgr)
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/camera_head.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from vggt.layers import Mlp
15
+ from vggt.layers.block import Block
16
+ from vggt.heads.head_act import activate_pose
17
+
18
+
19
class CameraHead(nn.Module):
    """
    Predict camera encodings from camera tokens via iterative refinement.

    The camera token is modulated by an embedding of the current pose estimate, passed
    through a stack of transformer blocks (the "trunk"), and mapped to a pose update.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Applied to the trailing (focal length / FoV) entries of the encoding.
    ):
        super().__init__()

        # "absT_quaR_FoV" is the only supported encoding; it is 9 values wide.
        if pose_encoding_type != "absT_quaR_FoV":
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        self.target_dim = 9

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Trunk: `trunk_depth` identical transformer blocks applied in sequence.
        trunk_blocks = [
            Block(
                dim=dim_in,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                init_values=init_values,
            )
            for _ in range(trunk_depth)
        ]
        self.trunk = nn.Sequential(*trunk_blocks)

        # Layer norms for the incoming camera token and for the trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable placeholder pose used on the first iteration, plus its embedding layer.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Produces the three modulation terms (shift, scale, gate) in a single projection.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Parameter-free layer norm applied before modulation.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Predict camera encodings from the last entry of ``aggregated_tokens_list``.

        Args:
            aggregated_tokens_list (list): Token tensors from the network; only the last
                one is read, and index 0 along its third dimension is taken as the camera token.
            num_iterations (int, optional): Number of refinement steps. Defaults to 4.

        Returns:
            list: One activated camera encoding per refinement step.
        """
        last_tokens = aggregated_tokens_list[-1]
        camera_tokens = self.token_norm(last_tokens[:, :, 0])
        return self.trunk_fn(camera_tokens, num_iterations)

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Run the refinement loop on normalized camera tokens.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens, a 3-D tensor [B, S, C].
                NOTE(review): the original note said S is expected to be 1, but the code
                expands the empty pose to any S - confirm against callers.
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: Activated camera encodings, one per iteration.
        """
        B, S, _ = pose_tokens.shape
        pose_estimate = None
        outputs = []

        for _ in range(num_iterations):
            if pose_estimate is not None:
                # Cut the graph so gradients do not flow back through earlier iterations.
                pose_estimate = pose_estimate.detach()
                conditioning = self.embed_pose(pose_estimate)
            else:
                # First iteration: condition on the learned empty pose.
                conditioning = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))

            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(conditioning).chunk(3, dim=-1)

            # Gated, modulated normalization with a residual connection to the input tokens.
            modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            modulated = modulated + pose_tokens

            trunk_out = self.trunk(modulated)
            delta = self.pose_branch(self.trunk_norm(trunk_out))

            # The branch output is the estimate itself at first, an additive update afterwards.
            pose_estimate = delta if pose_estimate is None else pose_estimate + delta

            outputs.append(
                activate_pose(
                    pose_estimate,
                    trans_act=self.trans_act,
                    quat_act=self.quat_act,
                    fl_act=self.fl_act,
                )
            )

        return outputs
155
+
156
+
157
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Scale ``x`` by ``1 + scale`` and then add ``shift``.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    scaled = x * (1 + scale)
    return scaled + shift
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        # The list defaults above are shared between calls; take private copies so that
        # mutating an instance's configuration can never alter the defaults.
        out_channels = list(out_channels)

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = list(intermediate_layer_idx)

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers: x4 upsample, x2 upsample, identity, and x2 downsample respectively.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest fusion block has no skip input to merge, hence no residual unit.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            # The prediction head only exists when activations are requested.
            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or not smaller than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: feature maps with shape [B, S, C, H', W'].
                - Otherwise: (predictions, confidence) as produced by ``activate_head`` and
                  reshaped to lead with [B, S]; predictions are channels-last
                  [B, S, H', W', output_dim - 1] and confidence is [B, S, H', W'].
                H' and W' are the patch-aligned image size divided by ``down_ratio``.
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or not smaller than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        Processes either the whole sequence or, when both frame indices are given,
        the frame slice [frames_start_idx, frames_end_idx).

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index (exclusive) for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        if frames_start_idx is not None and frames_end_idx is not None:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []

        for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
            # Drop the leading non-patch tokens.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            # Fold batch and frame dimensions together for the 2-D convolutions.
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Token sequence -> channels-first spatial map on the patch grid.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.reshape(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        # Restore the separate batch and frame dimensions.
        preds = preds.reshape(B, S, *preds.shape[1:])
        conf = conf.reshape(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Add a positional embedding, scaled by ``ratio``, to the feature map ``x``.

        The embedding is built on the spatial grid of ``x`` using the W / H aspect ratio
        of the input image and is broadcast over the leading dimension of ``x``.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): Exactly four feature maps, ordered as in
                ``intermediate_layer_idx``.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the last map to the first, resizing to the next map's spatial size;
        # intermediates are deleted as soon as they are consumed to lower peak memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
305
+
306
+
307
+ ################################################################################
308
+ # Modules
309
+ ################################################################################
310
+
311
+
312
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """
    Build a FeatureFusionBlock with this file's fixed configuration.

    The block uses an in-place ReLU, no deconvolution, no batch norm, no channel
    expansion, and corner-aligned resizing; only ``features``, ``size``,
    ``has_residual`` and ``groups`` vary per call.
    """
    fixed_options = {
        "deconv": False,
        "bn": False,
        "expand": False,
        "align_corners": True,
    }
    relu = nn.ReLU(inplace=True)
    return FeatureFusionBlock(
        features,
        relu,
        size=size,
        has_residual=has_residual,
        groups=groups,
        **fixed_options,
    )
324
+
325
+
326
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
327
+ scratch = nn.Module()
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(
342
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ scratch.layer2_rn = nn.Conv2d(
345
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer3_rn = nn.Conv2d(
348
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ if len(in_shape) >= 4:
351
+ scratch.layer4_rn = nn.Conv2d(
352
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
353
+ )
354
+ return scratch
355
+
356
+
357
class ResidualConvUnit(nn.Module):
    """Two activation + 3x3 convolution stages with a skip connection around them."""

    def __init__(self, features, activation, bn, groups=1):
        """Set up the two convolutions.

        Args:
            features (int): number of input and output channels
            activation (callable): activation applied before each convolution
            bn: stored as given; no normalization layers are created here
            groups (int): ``groups`` argument of both convolutions
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_options)
        self.conv2 = nn.Conv2d(features, features, **conv_options)

        # Normalization slots; left empty, so the forward pass skips them.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply both stages and add the result to ``x``.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        hidden = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            # NOTE: if the activation is in-place, the first pass also rewrites ``x``,
            # which is the tensor added back below.
            hidden = conv(self.activation(hidden))
            if norm is not None:
                hidden = norm(hidden)

        return self.skip_add.add(hidden, x)
400
+
401
+
402
class FeatureFusionBlock(nn.Module):
    """Merge an optional skip input into a feature map, refine, resize and project it."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Build the block.

        Args:
            features (int): number of input channels
            activation (callable): activation handed to the residual units
            deconv: stored as given
            bn: forwarded to the residual units
            expand (bool): if True, the output projection halves the channel count
            align_corners (bool): interpolation setting used when resizing
            size: default output size used when ``forward`` receives none
            has_residual (bool): whether a second input is merged in ``forward``
            groups (int): ``groups`` argument of the convolutions
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        self.has_residual = has_residual
        self.size = size

        out_features = features // 2 if self.expand == True else features
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        # The first residual unit only exists when there is a skip input to process.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs, size=None):
        """Fuse the inputs.

        Args:
            *xs: the main feature map, followed by the skip feature map when the
                block was built with ``has_residual=True``
            size: output spatial size; falls back to ``self.size``, and to a x2
                upscale when both are None

        Returns:
            tensor: output
        """
        fused = xs[0]

        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        target_size = size if size is not None else self.size
        resize_kwargs = {"scale_factor": 2} if target_size is None else {"size": target_size}

        fused = custom_interpolate(fused, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)
470
+
471
+
472
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Resize ``x`` with ``nn.functional.interpolate``, splitting very large jobs along dim 0.

    When ``size`` is None it is derived from the last two dimensions of ``x`` and
    ``scale_factor``. If the output element count exceeds the internal limit, the input
    is chunked along dim 0, each chunk is interpolated, and the pieces are concatenated.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    # Common case: small enough for a single call.
    if not output_elements > INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    num_chunks = (output_elements // INT_MAX) + 1
    resized_parts = []
    for part in torch.chunk(x, chunks=num_chunks, dim=0):
        resized_parts.append(
            nn.functional.interpolate(part, size=size, mode=mode, align_corners=align_corners)
        )
    return torch.cat(resized_parts, dim=0).contiguous()
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Apply a separate activation to each segment of a pose encoding.

    The last dimension is split into entries 0-2 (translation), 3-6 (quaternion) and
    7 onward (focal length or field of view); the activated segments are concatenated
    back in the same order.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor
    """
    translation = base_pose_act(pred_pose_enc[..., :3], trans_act)
    rotation = base_pose_act(pred_pose_enc[..., 3:7], quat_act)
    focal = base_pose_act(pred_pose_enc[..., 7:], fl_act)  # or fov

    return torch.cat([translation, rotation, focal], dim=-1)
36
+
37
+
38
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply the activation named by ``act_type`` to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters ("linear" returns the input unchanged)

    Raises:
        ValueError: If ``act_type`` is not one of the supported names.
    """
    if act_type == "linear":
        return pose_enc
    if act_type == "relu":
        return F.relu(pose_enc)
    if act_type == "exp":
        return torch.exp(pose_enc)
    if act_type == "inv_log":
        return inverse_log_transform(pose_enc)

    raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Split a network output into point predictions and a confidence map and activate both.

    Args:
        out: Network output tensor (B, C, H, W); the last channel is the confidence
            logit and the first C-1 channels are the point prediction.
        activation: Activation type for the point channels ("norm_exp", "norm", "exp",
            "relu", "inv_log", "xy_inv_log", "sigmoid" or "linear")
        conf_activation: Activation type for confidence values ("expp1", "expp0" or "sigmoid")

    Returns:
        Tuple of (points tensor with shape (B, H, W, C-1), confidence tensor with shape (B, H, W))

    Raises:
        ValueError: If either activation name is not recognized.
    """
    # Move channels from dim 1 to the last dim => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if activation == "norm_exp":
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        # Clamp the norm as the "norm_exp" branch does, so an all-zero vector maps to
        # zero instead of producing NaN from 0 / 0.
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # The first two channels are multiplied by the activated third channel.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        # exp(.) + 1 keeps the confidence strictly above 1.
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
113
+
114
+
115
def inverse_log_transform(y):
    """
    Compute the signed inverse log transform: sign(y) * (exp(|y|) - 1)

    Args:
        y: Input tensor

    Returns:
        Transformed tensor
    """
    magnitude = torch.expm1(torch.abs(y))
    return torch.sign(y) * magnitude
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_head.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ from .dpt_head import DPTHead
9
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
10
+
11
+
12
class TrackHead(nn.Module):
    """
    Point-tracking head: a feature-only DPTHead followed by a BaseTrackerPredictor.

    The DPT head turns backbone tokens into per-frame feature maps, and the tracker
    refines point predictions on those maps over several iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
            features (int): Channel count of the feature maps; also the tracker's latent_dim.
            iters (int): Default number of refinement iterations for tracking predictions.
            predict_conf (bool): Whether to predict confidence scores for tracked points.
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels.
            corr_radius (int): Radius for correlation computation.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size

        # DPT head in feature-only mode: no prediction head or activation, output at
        # half resolution (down_ratio=2), positional embedding disabled.
        dpt_options = dict(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,
            down_ratio=2,
            pos_embed=False,
        )
        self.feature_extractor = DPTHead(**dpt_options)

        # The tracker consumes the feature maps, so latent_dim is set to `features`.
        tracker_options = dict(
            latent_dim=features,
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )
        self.tracker = BaseTrackerPredictor(**tracker_options)

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
                                   B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Query points to track; passed to the
                tracker unchanged.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.

        Returns:
            tuple:
                - coord_preds: Predicted coordinates for tracked points.
                - vis_scores: Visibility scores for tracked points.
                - conf_scores: Confidence scores for tracked points (if predict_conf=True).
        """
        # Unpacking doubles as a check that `images` is 5-dimensional.
        _batch, _frames, _, _height, _width = images.shape

        # Feature maps are (B, S, C, H//2, W//2) because the extractor uses down_ratio=2.
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        num_iters = self.iters if iters is None else iters

        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=num_iters,
        )

        return coord_preds, vis_scores, conf_scores
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ from .blocks import EfficientUpdateFormer, CorrBlock
13
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
14
+ from .modules import Mlp
15
+
16
+
17
class BaseTrackerPredictor(nn.Module):
    """
    The base template to create a track predictor.

    Starting from query points in the first frame, the module iteratively refines
    per-frame track coordinates: at each step it samples local correlations around
    the current estimates, combines them with flow embeddings and track features,
    and lets an ``EfficientUpdateFormer`` predict coordinate and feature deltas.

    Modified from https://github.com/facebookresearch/co-tracker/
    and https://github.com/facebookresearch/vggsfm
    """

    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        """
        Args:
            stride: Extra down-scaling factor applied to the query coordinates.
            corr_levels: Number of correlation pyramid levels.
            corr_radius: Radius of the local correlation window.
            latent_dim: Channel dimension of the feature maps / track features.
            hidden_size: Hidden width of the correlation MLP and the update transformer.
            use_spaceatt: Whether the update transformer uses space attention.
            depth: Number of time (and, if enabled, space) attention blocks.
            max_scale: Divisor used to normalise raw flows before concatenation.
            predict_conf: Whether to build and run the confidence head.
        """
        super().__init__()

        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        self.flows_emb_dim = latent_dim // 2

        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )

        # flow embedding (latent_dim + 4 raw-flow channels) + correlation + track features
        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4

        # Two learned tokens: index 0 is added to the query frame, index 1 to every other frame.
        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth

        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())

        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
        """
        Args:
            query_points: B x N x 2, the number of batches, tracks, and xy. Coordinates are
                divided by ``down_ratio`` and ``self.stride`` to reach feature-map scale.
            fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
                Note HH and WW is the size of feature maps instead of original images.
                Required even though the parameter defaults to None.
            iters: Number of refinement iterations.
            return_feat: If True, also return the final track features and the query features.
            down_ratio: Additional down-scaling between the images and ``fmaps``.
            apply_sigmoid: If True, squash visibility/confidence logits with a sigmoid.

        Returns:
            ``(coord_preds, vis_e, conf_e)`` or, when ``return_feat`` is True,
            ``(coord_preds, vis_e, track_feats, query_track_feat, conf_e)``.
            ``coord_preds`` is a list with one B x S x N x 2 tensor per iteration, rescaled by
            ``stride`` (and ``down_ratio``); ``vis_e`` and ``conf_e`` are B x S x N and
            ``conf_e`` is None when ``predict_conf`` is False.

        Raises:
            ValueError: If ``query_points`` or ``fmaps`` is None.
        """
        if query_points is None:
            raise ValueError("query_points must be provided (expected a B x N x 2 tensor)")
        if fmaps is None:
            raise ValueError("fmaps must be provided (expected a B x S x C x HH x WW tensor)")

        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here (LayerNorm acts on the last dim, hence the permutes)
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)

        # 2D positional embedding sampled at the query locations.
        # coords[:, 0] is reset to the query points at the end of every iteration, so this
        # is loop-invariant and computed once instead of once per refinement step.
        pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
        sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
        sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)

        # Query/reference token layout (query frame first, then S - 1 reference frames);
        # also independent of the iteration.
        query_ref_token = torch.cat(
            [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
        )

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            # The normalised flow is appended twice; this supplies the "+ 4" in transformer_dim.
            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import bilinear_sampler
16
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
17
+
18
+
19
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    Tokens are laid out as (B, N, T, C). Time attention runs along T for every track;
    optional space attention exchanges information between the point tracks and a set
    of learned "virtual tracks" at each time step.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        """
        Args:
            space_depth: Number of space-attention stages (used when ``add_space_attn``).
            time_depth: Number of time-attention blocks.
            input_dim: Channel dimension of the input tokens.
            hidden_size: Internal token width.
            num_heads: Attention heads in every block.
            output_dim: Channel dimension produced by the final projection.
            mlp_ratio: Hidden/width ratio of the per-block MLPs.
            add_space_attn: Whether to build and run the space-attention stages.
            num_virtual_tracks: Number of learned virtual tracks.
        """
        super().__init__()

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn

        # Add input LayerNorm before linear projection
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)

        # Add output LayerNorm before final projection
        self.output_norm = nn.LayerNorm(hidden_size)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        # NOTE: the attribute name "virual_tracks" (sic) is kept because it is part of
        # the state_dict key set of existing checkpoints.
        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=nn.MultiheadAttention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=nn.MultiheadAttention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-initialise every nn.Linear, zero its bias, then give flow_head tiny weights."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Near-zero output weights so the initial predicted deltas are small.
        # This used to run inside the apply() callback, i.e. once per visited submodule;
        # only the last call mattered, so it is now done exactly once after apply().
        torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)

    def forward(self, input_tensor, mask=None):
        """
        Args:
            input_tensor: Tokens of shape (B, N, T, input_dim).
            mask: Optional attention mask forwarded to the point<->virtual cross-attention blocks.

        Returns:
            tuple: ``(flow, None)`` where ``flow`` has shape (B, N, T, output_dim).
        """
        # Apply input LayerNorm
        input_tensor = self.input_norm(input_tensor)
        tokens = self.input_transform(input_tensor)

        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        # A space stage runs after every `space_every`-th time block. Guard against a zero
        # space depth (previously ZeroDivisionError) and against running more space stages
        # than exist when time_depth is not a multiple of space_depth (previously IndexError).
        num_space = len(self.space_virtual_blocks) if self.add_space_attn else 0
        space_every = len(self.time_blocks) // num_space if num_space else 0

        j = 0
        for i, time_block in enumerate(self.time_blocks):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C

            time_tokens = time_block(time_tokens)

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if space_every and j < num_space and i % space_every == 0:
                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                # virtual tracks gather from points, mix among themselves, then broadcast back
                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            # drop the virtual tracks again
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        # Apply output LayerNorm before final projection
        tokens = self.output_norm(tokens)
        flow = self.flow_head(tokens)

        return flow, None
145
+
146
+
147
class CorrBlock:
    """Feature pyramid with on-the-fly local correlation sampling around track positions."""

    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        """
        Build a pyramid of feature maps from the input.

        fmaps: Tensor (B, S, C, H, W)
        num_levels: number of pyramid levels (each downsampled by factor 2)
        radius: search radius for sampling correlation
        multiple_track_feats: if True, split the target features per pyramid level
        padding_mode: passed to grid_sample / bilinear_sampler
        """
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.num_levels = num_levels
        self.radius = radius
        self.padding_mode = padding_mode
        self.multiple_track_feats = multiple_track_feats

        # Build pyramid: each level is half the spatial resolution of the previous
        self.fmaps_pyramid = [fmaps]  # level 0 is full resolution
        current_fmaps = fmaps
        for _ in range(num_levels - 1):
            B, S, C, H, W = current_fmaps.shape
            # Merge batch & sequence dimensions, then avg pool down by factor 2
            pooled = F.avg_pool2d(current_fmaps.reshape(B * S, C, H, W), kernel_size=2, stride=2)
            _, _, H_new, W_new = pooled.shape
            current_fmaps = pooled.reshape(B, S, C, H_new, W_new)
            self.fmaps_pyramid.append(current_fmaps)

        # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
        # This grid is added to the (scaled) coordinate centroids.
        r = self.radius
        dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        # delta[i, j] = (dy[i], dx[j]); both axes span the same symmetric range.
        self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)  # shape: (2r+1, 2r+1, 2)

    def corr_sample(self, targets, coords):
        """
        Instead of storing the entire correlation pyramid, we compute each level's correlation
        volume, sample it immediately, then discard it. This saves GPU memory.

        Args:
            targets: Tensor (B, S, N, C) — features for the current targets.
            coords: Tensor (B, S, N, 2) — coordinates at full resolution.

        Returns:
            Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
        """
        B, S, N, C = targets.shape

        # If you have multiple track features, split them per level.
        if self.multiple_track_feats:
            targets_split = torch.split(targets, C // self.num_levels, dim=-1)

        # The delta grid is identical for every level: move it to the coords' device/dtype once.
        window = 2 * self.radius + 1
        delta_lvl = self.delta.to(coords.device).to(coords.dtype).view(1, window, window, 2)

        out_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            # Get current spatial resolution H, W for this pyramid level.
            B, S, C, H, W = fmaps.shape
            # Flatten the spatial dims: (B, S, C, H*W). reshape (not view) so that
            # feature maps with arbitrary strides, e.g. after permute(), are accepted.
            fmap2s = fmaps.reshape(B, S, C, H * W)
            # Choose appropriate target features.
            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)

            # Compute correlation directly
            corrs = compute_corr_level(fmap1, fmap2s, C)
            corrs = corrs.reshape(B, S, N, H, W)

            # Scale down the coordinates for the current level and add the window offsets.
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
            coords_lvl = centroid_lvl + delta_lvl

            # Sample from the correlation volume using bilinear interpolation.
            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
            corrs_sampled = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
            )
            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)
            out_pyramid.append(corrs_sampled)

        # Concatenate all levels along the last dimension.
        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out
239
+
240
+
241
def compute_corr_level(fmap1, fmap2s, C):
    """
    Scaled dot-product correlation between track features and a flattened feature map.

    Args:
        fmap1: (B, S, N, C) track features.
        fmap2s: (B, S, C, H*W) feature maps with the spatial dims flattened.
        C: value whose square root divides the raw correlation.

    Returns:
        Tensor (B, S, N, H*W).
    """
    batch, frames, tracks = fmap1.shape[:3]
    similarity = fmap1 @ fmap2s  # (B, S, N, H*W)
    similarity = similarity.view(batch, frames, tracks, -1)
    return similarity / math.sqrt(C)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
def exists(val):
    """Return True for any value other than None."""
    if val is None:
        return False
    return True
30
+
31
+
32
def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    return d if val is None else val
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
class ResidualBlock(nn.Module):
    """
    ResidualBlock: construct a block of two conv layers with residual connections
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        """
        Args:
            in_planes: Input channels.
            planes: Output channels of both convolutions.
            norm_fn: One of "group", "batch", "instance", "none".
            stride: Stride of the first convolution; when != 1 a 1x1 strided
                convolution + norm is applied to the skip path.
            kernel_size: Kernel size of both convolutions. Odd sizes keep the spatial
                size (at stride 1), which the residual addition requires.

        Raises:
            NotImplementedError: For an unknown ``norm_fn``.
        """
        super(ResidualBlock, self).__init__()

        # "Same" padding for odd kernels. This was hard-coded to 1, which only matches
        # kernel_size=3 and made the residual add fail for any other kernel size.
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        # At least one group so that planes < 8 does not request zero groups.
        num_groups = max(planes // 8, 1)

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError

        if stride == 1:
            self.downsample = None
        else:
            # norm3 is intentionally registered both as self.norm3 and inside
            # self.downsample; both state_dict keys are kept for checkpoint compatibility.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        """Apply conv-norm-relu twice and add the (optionally downsampled) input."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
109
+
110
+
111
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        """
        Args:
            in_features: Input width.
            hidden_features: Hidden width; falls back to ``in_features`` when falsy.
            out_features: Output width; falls back to ``in_features`` when falsy.
            act_layer: Activation class instantiated without arguments.
            norm_layer: Accepted for interface compatibility; not used here.
            bias: Bool or pair of bools for the first / second projection.
            drop: Dropout probability or pair of probabilities.
            use_conv: Use 1x1 ``nn.Conv2d`` projections instead of ``nn.Linear``.
        """
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        if use_conv:
            make_projection = partial(nn.Conv2d, kernel_size=1)
        else:
            make_projection = nn.Linear

        self.fc1 = make_projection(in_features, hidden_width, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.fc2 = make_projection(hidden_width, output_width, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Apply fc1 -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
145
+
146
+
147
class AttnBlock(nn.Module):
    """Self-attention block: LayerNorm, attention with residual, then an MLP with residual."""

    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        """
        Self attention block

        Args:
            hidden_size: Token width.
            num_heads: Number of attention heads.
            attn_class: Attention constructor, called with ``embed_dim``, ``num_heads``,
                ``batch_first=True`` and ``block_kwargs``.
            mlp_ratio: Hidden/width ratio of the MLP.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), drop=0)

    def forward(self, x, mask=None):
        """
        Args:
            x: Tokens, batch first.
            mask: Accepted for interface symmetry; it is not forwarded to the attention call.
        """
        # Normalize before attention. The residual below is taken from the
        # normalised tokens, not from the raw input.
        normed = self.norm1(x)

        # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
        attended = self.attn(normed, normed, normed)[0]

        # Add & Norm
        out = normed + attended
        return out + self.mlp(self.norm2(out))
185
+
186
+
187
class CrossAttnBlock(nn.Module):
    """Cross-attention block: queries ``x`` attend to ``context``, followed by an MLP."""

    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """
        Cross attention block

        Args:
            hidden_size: Width of the query tokens (and of the output).
            context_dim: Width of the context tokens. It was previously ignored (the
                context norm and the attention were both sized with ``hidden_size``);
                it is now honoured. When it equals ``hidden_size`` the parameter
                layout is identical to before.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden/width ratio of the MLP.
            **block_kwargs: Extra keyword arguments for ``nn.MultiheadAttention``.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm_context = nn.LayerNorm(context_dim)
        self.norm2 = nn.LayerNorm(hidden_size)

        # Keys/values come from the context, so their input width is context_dim.
        # setdefault keeps any explicit kdim/vdim supplied by the caller.
        block_kwargs.setdefault("kdim", context_dim)
        block_kwargs.setdefault("vdim", context_dim)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        """
        Args:
            x: Query tokens, batch first.
            context: Key/value tokens, batch first.
            mask: Optional mask passed to the attention call as ``attn_mask``.
        """
        # Normalize inputs. The residual below is taken from the normalised
        # queries, not from the raw input.
        x = self.norm1(x)
        context = self.norm_context(context)

        # Apply cross attention
        # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)

        # Add & Norm
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
    """
    Build a dense 2D sine/cosine positional embedding for a regular grid.

    This is a wrapper of get_2d_sincos_pos_embed_from_grid that first creates the grid.

    Args:
    - embed_dim: The embedding dimension.
    - grid_size: The grid size, either a single int (square grid) or a (height, width) tuple.
    - return_grid: If True, also return the coordinate grid of shape (2, 1, height, width).
    Returns:
    - pos_embed: The embedding reshaped to (1, height, width, -1) and permuted to
      channels-first; with ``return_grid`` a ``(pos_embed, grid)`` tuple.
    """
    rows, cols = grid_size if isinstance(grid_size, tuple) else (grid_size, grid_size)

    row_coords = torch.arange(rows, dtype=torch.float)
    col_coords = torch.arange(cols, dtype=torch.float)
    # "xy" indexing: channel 0 holds the column (x) index, channel 1 the row (y) index.
    mesh = torch.meshgrid(col_coords, row_coords, indexing="xy")
    grid = torch.stack(mesh, dim=0).reshape([2, 1, rows, cols])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    embedding = embedding.reshape(1, rows, cols, -1).permute(0, 3, 1, 2)

    if return_grid:
        return (embedding, grid)
    return embedding
44
+
45
+
46
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    """
    Generate a 2D sine/cosine positional embedding from a coordinate grid.

    Each of the two grid channels is encoded with half of the embedding dimensions and
    the two halves are concatenated along the last axis.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - grid: The grid to generate the embedding from; ``grid[0]`` and ``grid[1]`` are
      the two coordinate channels.

    Returns:
    - emb: The generated 2D positional embedding of shape (1, M, embed_dim), where M is
      the number of grid positions.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One (1, M, D/2) embedding per coordinate channel, first channel first.
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]

    return torch.cat(per_axis, dim=2)  # (1, M, D)
65
+
66
+
67
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for a set of positions.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The positions to encode; flattened to (M,) internally.

    Returns:
    - emb: float32 tensor of shape (1, M, embed_dim): all sine terms first, then all
      cosine terms.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1, computed in double precision.
    exponents = torch.arange(half_dim, dtype=torch.double)
    exponents /= embed_dim / 2.0
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = torch.einsum("m,d->md", flat_pos, frequencies)  # (M, D/2), outer product

    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (M, D)
    return embedding[None].float()
91
+
92
+
93
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    Generate a 2D sinusoidal embedding from given coordinates.

    For each of x and y, ``C`` channels are produced with sine values at even indices
    and cosine values at odd indices.

    Args:
    - xy: The coordinates to generate the embedding from, shape (B, N, 2).
    - C: The size of the embedding per coordinate.
    - cat_coords: A flag to indicate whether to concatenate the original coordinates
      in front of the embedding.

    Returns:
    - pe: The generated embedding of shape (B, N, 2*C), or (B, N, 2*C + 2) when
      ``cat_coords`` is True.
    """
    B, N, D = xy.shape
    assert D == 2

    frequencies = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(
        1, 1, int(C / 2)
    )

    def _interleaved_sincos(coord):
        # (B, N, C/2) phases -> (B, N, C/2, 2) -> (B, N, C) with sin at even, cos at odd slots
        phase = coord * frequencies
        pair = torch.stack([torch.sin(phase), torch.cos(phase)], dim=-1)
        return pair.to(torch.float32).reshape(B, N, C)

    pe = torch.cat([_interleaved_sincos(xy[:, :, 0:1]), _interleaved_sincos(xy[:, :, 1:2])], dim=2)  # (B, N, 2*C)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, 2*C + 2)
    return pe
125
+
126
+
127
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It wraps
    `torch.nn.functional.grid_sample()` but takes coordinates in pixel units
    rather than in the normalised :math:`[-1, 1]` range.

    For a 4D input :math:`(B, C, H, W)`, :attr:`coords` has shape
    :math:`(B, H_o, W_o, 2)` and holds points :math:`(x_i, y_i)`.

    For a 5D input :math:`(B, C, T, H, W)`, sample points are triplets
    :math:`(t_i, x_i, y_i)`; they are reordered to :math:`(x_i, y_i, t_i)`
    before being handed to `grid_sample()`.

    With `align_corners=True`, :math:`x` ranges over :math:`[0, W-1]` (pixel
    centres); with `align_corners=False` it ranges over :math:`[0, W]` (pixel
    edges). The same holds for :math:`y` with :math:`H` and :math:`t` with
    :math:`T`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates. Not modified; gradients do not
            flow through the coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    # Work on a detached copy living on the input's device and dtype.
    grid = coords.detach().to(input.device).to(input.dtype)

    spatial_sizes = input.shape[2:]

    assert len(spatial_sizes) in [2, 3]

    if len(spatial_sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        grid = grid[..., [1, 2, 0]]

    # Per-axis factor mapping pixel units onto a span of 2, ordered x, y(, t).
    if align_corners:
        factors = [2 / max(extent - 1, 1) for extent in reversed(spatial_sizes)]
    else:
        factors = [2 / extent for extent in reversed(spatial_sizes)]
    scale = torch.tensor(factors, device=grid.device, dtype=grid.dtype)

    # Shift into grid_sample's [-1, 1] convention.
    grid = grid * scale - 1

    return F.grid_sample(input, grid, align_corners=align_corners, padding_mode=padding_mode)
194
+
195
+
196
def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input`, a 4D tensor :math:`(B, C, H, W)`, at the points in
    :attr:`coords` using bilinear interpolation.

    :attr:`coords` is assumed to be of shape :math:`(B, R, 2)` with points in
    the format :math:`(x_i, y_i)`; it is passed to :func:`bilinear_sampler`
    with that function's default arguments.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features, one per point, of shape :math:`(B, R, C)`.
    """
    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2, so the points form an R x 1 output grid
    point_grid = coords.unsqueeze(2)

    # B C R 1
    sampled = bilinear_sampler(input, point_grid)

    channels, trailing = sampled.shape[1], sampled.shape[3]
    # B C R 1 -> B R C
    return sampled.permute(0, 2, 1, 3).view(B, -1, channels * trailing)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/heads/utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Turn a 2D coordinate grid (H x W x 2) into sinusoidal embeddings (H x W x C).

    Each of the two coordinate channels is embedded independently with
    `make_sincos_pos_embed` using `embed_dim // 2` channels, and the two
    results are concatenated along the channel axis.

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings
        omega_0: Frequency base forwarded to `make_sincos_pos_embed`

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    height, width, num_axes = pos_grid.shape
    assert num_axes == 2

    # (H, W, 2) -> (H*W, 2): one row per grid location.
    flat_coords = pos_grid.reshape(-1, num_axes)
    channels_per_axis = embed_dim // 2

    # One (H*W, embed_dim // 2) embedding per coordinate channel, channel 0 first.
    axis_embeddings = [
        make_sincos_pos_embed(channels_per_axis, flat_coords[:, axis], omega_0=omega_0)
        for axis in range(2)
    ]

    # (H*W, embed_dim) -> (H, W, embed_dim)
    merged = torch.cat(axis_embeddings, dim=-1)
    return merged.view(height, width, embed_dim)
34
+
35
+
36
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Build a 1D sine/cosine positional embedding for a set of positions.

    Args:
    - embed_dim: The embedding dimension (must be even).
    - pos: The positions to embed; flattened to 1D before use.
    - omega_0: Base of the geometric frequency progression.

    Returns:
    - emb: float32 tensor of shape (M, embed_dim), where M is the number of
      positions. The first half of the channels holds the sines and the
      second half the cosines.
    """
    assert embed_dim % 2 == 0
    num_freqs = embed_dim // 2

    # Exponents 0, 1/(D/2), 2/(D/2), ... computed in double precision.
    exponents = torch.arange(num_freqs, dtype=torch.double, device=pos.device)
    exponents /= embed_dim / 2.0
    inv_freq = 1.0 / omega_0**exponents  # (D/2,)

    # Outer product of positions and frequencies: (M,) x (D/2,) -> (M, D/2)
    angles = torch.einsum("m,d->md", pos.reshape(-1), inv_freq)

    # Sines first, then cosines; cast back down to float32 on the way out.
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The image plane is scaled so that its half-extents are
    (span_x, span_y) = (aspect_ratio, 1) / sqrt(aspect_ratio**2 + 1), i.e.
    the plane corners lie on the unit circle. Samples are taken at pixel
    centres, so the first/last sample along each axis sits half a pixel
    inside the plane border: at +/- span * (n - 1) / n.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor. `grid[i, j]` is
        `(u, v)` where `u` depends only on the column `j` and `v` only on
        the row `i`.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans for X and Y
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Establish the linspace boundaries (centres of the border pixels)
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # indexing="xy" puts y on the first axis: both outputs are (height, width)
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """Return `x` as a pair: tuples pass through, an int `n` becomes `(n, n)`.

    Asserts that a tuple has exactly two elements and that any non-tuple
    input is an int.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
23
+
24
+
25
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    The image is split into non-overlapping patches by a strided
    convolution whose kernel size and stride both equal the patch size.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer. If falsy, no normalization is applied.
        flatten_embedding: If True (default), `forward` returns (B, N, D);
            otherwise it returns (B, H', W', D) with H', W' the patch grid size.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of images.

        Args:
            x: Input of shape (B, C, H, W); H and W must be multiples of the
                patch height and width (checked with assertions). They do
                not have to equal the `img_size` given at construction.

        Returns:
            Tensor of shape (B, N, D), or (B, H', W', D) when
            `flatten_embedding` is False.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B D H' W'
        H, W = x.size(2), x.size(3)  # patch-grid size from here on
        x = x.flatten(2).transpose(1, 2)  # B H'W' D
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' D
        return x

    def flops(self) -> float:
        """Estimate the multiply-accumulate count for one image of size `img_size`.

        Counts the patch projection, plus one operation per output element
        for the normalization layer when a real one is configured.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # `self.norm` is never None (it falls back to nn.Identity), so test for
        # the placeholder explicitly; an Identity layer does no work.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/models/aggregator.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from vggt.layers import PatchEmbed
14
+ from vggt.layers.block import Block
15
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+
18
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Per-channel (3-value) mean and standard deviation that Aggregator registers as
# buffers and uses in `forward` to normalize its input images.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
22
+
23
+
24
class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.

    Frame attention runs on tokens shaped (B*S, P, C), i.e. each frame
    attends to itself; global attention runs on tokens shaped (B, S*P, C),
    i.e. all frames of a sequence attend to each other.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks (of each attention type).
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str], optional): The order of alternating attention.
            Defaults to ["frame", "global"]. `forward` concatenates the frame
            and global outputs of every group, so both types must be present.
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. Values <= 0 disable it.
        init_values (float): Init scale for layer scale.

    Raises:
        ValueError: If `depth` is not divisible by `aa_block_size`.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=None,
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        # A list default would be shared by every instance; build a fresh one.
        if aa_order is None:
            aa_order = ["frame", "global"]

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        def build_blocks():
            # Frame and global stacks are configured identically.
            return nn.ModuleList(
                [
                    block_fn(
                        dim=embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        proj_bias=proj_bias,
                        ffn_bias=ffn_bias,
                        init_values=init_values,
                        qk_norm=qk_norm,
                        rope=self.rope,
                    )
                    for _ in range(depth)
                ]
            )

        self.frame_blocks = build_blocks()
        self.global_blocks = build_blocks()

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as non-persistent buffers so they
        # follow the module's device but are not written to the state dict.
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).reshape(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.

        Any `patch_embed` string containing "conv" selects the conv layer;
        other values must be one of the keys of `vit_models` below,
        otherwise a KeyError is raised.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Disable gradient updates for mask token
            if hasattr(self.patch_embed, "mask_token"):
                self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks (one entry per
                block index, each the frame and global outputs concatenated
                on the channel axis, shape [B, S, P, 2C]),
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape
        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean.to(images.device)) / self._resnet_std.to(images.device)

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.reshape(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        # ViT backbones return a dict; keep only the normalized patch tokens.
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)

        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        # Only touch `pos` when rotary embedding is enabled; with RoPE disabled
        # `pos` stays None and must be passed through unchanged.
        if pos is not None and self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens and shift patch positions by one
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=images.device, dtype=pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        # P counts special + patch tokens per frame
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # concat frame and global intermediates, [B x S x P x 2C]
            for frame_inter, global_inter in zip(frame_intermediates, global_intermediates):
                output_list.append(torch.cat([frame_inter, global_inter], dim=-1))

        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).

        Runs `aa_block_size` consecutive blocks of `self.frame_blocks`
        starting at `frame_idx`.

        Returns:
            (tokens, next frame_idx, list of per-block outputs reshaped to (B, S, P, C)).
        """
        # reshape is a no-op when the layout already matches
        tokens = tokens.reshape(B * S, P, C)
        if pos is not None:
            pos = pos.reshape(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.reshape(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(
        self, tokens, B, S, P, C, global_idx, pos=None
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).

        Runs `aa_block_size` consecutive blocks of `self.global_blocks`
        starting at `global_idx`.

        Returns:
            (tokens, next global_idx, list of per-block outputs reshaped to (B, S, P, C)).
        """
        # reshape is a no-op when the layout already matches
        tokens = tokens.reshape(B, S * P, C)
        if pos is not None:
            pos = pos.reshape(B, S * P, 2)

        intermediates = []

        for _ in range(self.aa_block_size):
            tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.reshape(B, S, P, C))

        return tokens, global_idx, intermediates
305
+
306
+
307
+
308
+
309
def slice_expand_and_flatten(token_tensor, B, S):
    """
    Lay out a (1, 2, X, C) pair of special tokens across B sequences of S frames.

    Slot 0 of the second axis is used for the first frame of every sequence
    and slot 1 for each of the remaining S-1 frames. Both are broadcast to
    batch size B, joined along the frame axis into (B, S, X, C), and the
    batch and frame axes are then merged.

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """
    trailing = token_tensor.shape[2:]

    # Token for frame 0 => (B, 1, ...)
    first_frame = token_tensor[:, :1].expand(B, 1, *trailing)
    # Token shared by frames 1..S-1 => (B, S-1, ...)
    later_frames = token_tensor[:, 1:].expand(B, S - 1, *trailing)

    # (B, S, ...) -> (B*S, ...)
    per_frame = torch.cat((first_frame, later_frames), dim=1)
    return per_frame.reshape(B * S, *trailing)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/models/vggt.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
+
11
+ from vggt.models.aggregator import Aggregator
12
+ from vggt.heads.camera_head import CameraHead
13
+ from vggt.heads.dpt_head import DPTHead
14
+ from vggt.heads.track_head import TrackHead
15
+ from transformers.file_utils import ModelOutput
16
+ from typing import Optional, Tuple, List, Any
17
+ from dataclasses import dataclass
18
+
19
@dataclass
class VGGTOutput(ModelOutput):
    """Container returned by `VGGT.forward` and `VGGT.inference`."""

    # One prediction dict per frame of the input sequence.
    ress: Optional[List[dict]] = None
    # The `views` argument handed to the model, passed through unchanged.
    # NOTE(review): annotated as a Tensor, but VGGT indexes `views` per frame
    # and reads view["img"], i.e. it is used as a sequence of dicts - confirm.
    views: Optional[torch.Tensor] = None
23
+
24
class VGGT(nn.Module, PyTorchModelHubMixin):
    """Aggregator backbone plus camera, depth, point-map and track heads.

    `forward` and `inference` run the same computation and differ only in
    how much of the pose encoding is exposed per frame (see each method).
    """

    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()

        # Heads take 2 * embed_dim channels: the aggregator concatenates the
        # frame- and global-attention outputs on the channel axis.
        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
        self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

    def forward(
        self,
        views,
        query_points: torch.Tensor = None,
    ):
        """
        Forward pass of the VGGT model.

        Args:
            views (sequence of dict): One dict per frame (S frames). Each must
                hold "img", a batch of images of shape [B, 3, H, W]; an
                optional "valid_mask" entry is copied into that frame's result.
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None

        Returns:
            VGGTOutput: `ress` is a list with one dict per frame containing
                - pts3d_in_other_view: world points for that frame, [B, H, W, 3]
                - conf: world-point confidence, [B, H, W]
                - depth: depth map, [B, H, W, 1]
                - depth_conf: depth confidence, [B, H, W]
                - camera_pose: the first 7 components of the pose encoding, [B, 7]
                - valid_mask: only if present in the corresponding view
                - track / vis / track_conf: only if query_points was given
              `views` is the input `views`, passed through unchanged.
        """
        predictions = self._run_heads(views, query_points)
        return self._pack_per_frame(predictions, views, pose_dims=7)

    def inference(
        self,
        views,
        query_points: torch.Tensor = None,
    ):
        """
        Same computation as `forward`, but each per-frame `camera_pose`
        holds the full pose encoding rather than its first 7 components.

        Args:
            views (sequence of dict): One dict per frame; see `forward`.
            query_points (torch.Tensor, optional): Query points for tracking;
                see `forward`.

        Returns:
            VGGTOutput: as in `forward`, except for `camera_pose`.
        """
        predictions = self._run_heads(views, query_points)
        return self._pack_per_frame(predictions, views, pose_dims=None)

    def _run_heads(self, views, query_points):
        """Stack the views, run the aggregator and every head, return the raw predictions dict."""
        # Stacking gives [S, B, C, H, W]; swap the first two axes -> [B, S, C, H, W].
        # The 5-index permute requires 5-D input, so each view["img"] must
        # already carry a batch dimension.
        images = torch.stack(
            [view["img"] for view in views], dim=0
        ).permute(1, 0, 2, 3, 4)

        # If query points come without a batch dimension, add it
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)
        predictions = {}

        # Heads run with autocast disabled.
        # NOTE(review): the nesting of the track head relative to this guard
        # could not be recovered from the flattened source; it is kept inside
        # the guard here - confirm against the original file.
        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["depth_conf"] = depth_conf

            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["world_points"] = pts3d
                predictions["world_points_conf"] = pts3d_conf

            if self.track_head is not None and query_points is not None:
                track_list, vis, conf = self.track_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx,
                    query_points=query_points
                )
                predictions["track"] = track_list[-1]  # track of the last iteration
                predictions["vis"] = vis
                predictions["conf"] = conf

        predictions["images"] = images
        return predictions

    def _pack_per_frame(self, predictions, views, pose_dims):
        """Split [B, S, ...] predictions into one dict per frame and wrap them in a VGGTOutput.

        `pose_dims` is the number of leading pose-encoding components to
        keep per frame; None keeps all of them.
        """
        num_frames = predictions["images"].shape[1]
        ress = []
        for s in range(num_frames):
            res = {
                'pts3d_in_other_view': predictions['world_points'][:, s],  # [B, H, W, 3]
                'conf': predictions['world_points_conf'][:, s],  # [B, H, W]
                'depth': predictions['depth'][:, s],  # [B, H, W, 1]
                'depth_conf': predictions['depth_conf'][:, s],  # [B, H, W]
                'camera_pose': predictions['pose_enc'][:, s, :pose_dims],
            }
            if 'valid_mask' in views[s]:
                res['valid_mask'] = views[s]["valid_mask"]  # [B, H, W]
            if 'track' in predictions:
                res['track'] = predictions['track'][:, s]  # [B, N, 2]
                res['vis'] = predictions['vis'][:, s]  # [B, N]
                res['track_conf'] = predictions['conf'][:, s]
            ress.append(res)
        return VGGTOutput(ress=ress, views=views)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/augmentation.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Dict
8
+ from torchvision import transforms
9
+
10
+
11
def get_image_augmentation(
    color_jitter: Optional[Dict[str, float]] = None,
    gray_scale: bool = True,
    gau_blur: bool = False
) -> Optional[transforms.Compose]:
    """Assemble the image augmentation pipeline.

    The pipeline always starts with a randomly applied color jitter and
    optionally appends random grayscale and random gaussian blur, in that
    order.

    Args:
        color_jitter: Overrides for the color jitter settings. Missing keys
            fall back to these defaults:
            - brightness: 0.5
            - contrast: 0.5
            - saturation: 0.5
            - hue: 0.1
            - p: 0.9 (probability of applying the jitter)
            If None, the defaults are used as-is.
        gray_scale: Whether to append random grayscale with p=0.05 (default: True)
        gau_blur: Whether to append gaussian blur (kernel 5, sigma in
            (0.1, 1.0)) applied with p=0.05 (default: False)

    Returns:
        A Compose object of the selected transforms. The None fallback for
        an empty pipeline is kept, but the jitter step is always present.
    """
    jitter_cfg = {
        "brightness": 0.5,
        "contrast": 0.5,
        "saturation": 0.5,
        "hue": 0.1,
        "p": 0.9,
    }
    # Caller-supplied values win over the defaults.
    if color_jitter is not None:
        jitter_cfg = {**jitter_cfg, **color_jitter}

    jitter = transforms.ColorJitter(
        brightness=jitter_cfg["brightness"],
        contrast=jitter_cfg["contrast"],
        saturation=jitter_cfg["saturation"],
        hue=jitter_cfg["hue"],
    )
    pipeline = [transforms.RandomApply([jitter], p=jitter_cfg["p"])]

    if gray_scale:
        pipeline.append(transforms.RandomGrayscale(p=0.05))

    if gau_blur:
        blur = transforms.GaussianBlur(5, sigma=(0.1, 1.0))
        pipeline.append(transforms.RandomApply([blur], p=0.05))

    if not pipeline:
        return None
    return transforms.Compose(pipeline)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/general.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import os
10
+ import math
11
+ import random
12
+ import numpy as np
13
+ from typing import Union, Optional
14
+ import logging
15
+ from iopath.common.file_io import g_pathmgr
16
+ import torch.distributed as dist
17
+ from pathlib import Path
18
+ from typing import Dict, Iterable, List
19
+
20
+
21
+
22
+ from collections import defaultdict
23
+ from dataclasses import fields, is_dataclass
24
+ from typing import Any, Mapping, Protocol, runtime_checkable
25
+
26
+
27
+
28
+
29
def check_and_fix_inf_nan(input_tensor, loss_name="default", hard_max=100):
    """
    Replace non-finite entries of 'input_tensor' with zero and clamp extreme values.

    Args:
        input_tensor (torch.Tensor): The tensor to check and fix. ``None`` is
            passed through unchanged.
        loss_name (str): Name used in the warning that is logged when a
            non-finite value is found.
        hard_max (float, optional): Maximum absolute value allowed. Values outside
                                   [-hard_max, hard_max] will be clamped. If None,
                                   no clamping is performed. Defaults to 100.

    Returns:
        The sanitised tensor (a new tensor whenever something was replaced or
        clamping is enabled), or ``None`` if ``input_tensor`` was ``None``.
    """
    if input_tensor is None:
        return input_tensor

    # One mask covers both NaN and +/-Inf; it is reused for the replacement so
    # the tensor is not scanned a second time.
    non_finite = ~torch.isfinite(input_tensor)
    if non_finite.any():
        logging.warning(f"Tensor {loss_name} contains inf or nan values. Replacing with zeros.")
        input_tensor = torch.where(
            non_finite,
            torch.zeros_like(input_tensor),
            input_tensor
        )

    # Apply hard clamping if specified
    if hard_max is not None:
        input_tensor = torch.clamp(input_tensor, min=-hard_max, max=hard_max)

    return input_tensor
58
+
59
+
60
def get_resume_checkpoint(checkpoint_save_dir):
    """Return the path to ``checkpoint.pt`` inside *checkpoint_save_dir*, or None.

    ``None`` is returned when the directory does not exist or when it holds no
    ``checkpoint.pt`` file. Both checks go through ``g_pathmgr``.
    """
    if g_pathmgr.isdir(checkpoint_save_dir):
        candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
        if g_pathmgr.isfile(candidate):
            return candidate
    return None
68
+
69
class DurationMeter:
    """Holds one duration value and prints it via ``human_readable_time``."""

    def __init__(self, name, device, fmt=":f"):
        self.name, self.device, self.fmt = name, device, fmt
        self.val = 0

    def reset(self):
        """Set the stored duration back to zero."""
        self.val = 0

    def update(self, val):
        """Overwrite the stored duration with *val*."""
        self.val = val

    def add(self, val):
        """Accumulate *val* onto the stored duration."""
        self.val += val

    def __str__(self):
        return "{}: {}".format(self.name, human_readable_time(self.val))
87
+
88
+
89
def human_readable_time(time_seconds):
    """Format a duration in seconds as ``DDd HHh MMm``.

    The input is truncated to an integer first and the leftover seconds are
    not shown.
    """
    total_minutes = int(time_seconds) // 60
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
95
+
96
+
97
+
98
class ProgressMeter:
    """Assembles and logs a one-line progress report for a batch.

    ``meters`` are rendered with ``str()``. Each entry of ``real_meters``
    (a mapping of name to meter) must provide ``compute()`` returning a mapping
    of sub-name to numeric value.
    """

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch):
        """Log the progress line for batch index *batch* at INFO level."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        for name, meter in self.real_meters.items():
            stats = meter.compute().items()
            parts.append(
                " | ".join(
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in stats
                )
            )
        logging.info(" | ".join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``[{:Nd}/total]`` template padded to the width of the total."""
        width = len(str(num_batches // 1))
        slot = "{:" + str(width) + "d}"
        return "[" + slot + "/" + slot.format(num_batches) + "]"
123
+
124
+
125
+
126
+ @runtime_checkable
127
+ class _CopyableData(Protocol):
128
+ def to(self, device: torch.device, *args: Any, **kwargs: Any):
129
+ """Copy data to the specified device"""
130
+ ...
131
+
132
+
133
+ def _is_named_tuple(x) -> bool:
134
+ return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
135
+
136
+
137
def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
    """Recursively copy a nested structure to a torch.device.

    Containers are rebuilt with the same type and their elements are processed
    recursively. Named tuples, lists/tuples, ``defaultdict``, other mappings and
    dataclass instances are handled; any remaining object that has a ``to``
    method is moved with it, and everything else is returned as is.

    Args:
        data: The data to copy to device
        device: The device to which the data should be copied
        args: positional arguments that will be passed to the `to` call
        kwargs: keyword arguments that will be passed to the `to` call

    Returns:
        The data on the correct device
    """

    def _move(item):
        return copy_data_to_device(item, device, *args, **kwargs)

    # Named tuples must be tested before plain tuples: they need keyword
    # reconstruction rather than a single iterable argument.
    if _is_named_tuple(data):
        return type(data)(**_move(data._asdict()))

    if isinstance(data, (list, tuple)):
        return type(data)(_move(e) for e in data)

    # defaultdict needs its factory passed back in as the first argument.
    if isinstance(data, defaultdict):
        return type(data)(
            data.default_factory,
            {k: _move(v) for k, v in data.items()},
        )

    # Mappings that are also dataclasses are handled by the dataclass branch.
    if isinstance(data, Mapping) and not is_dataclass(data):
        return type(data)({k: _move(v) for k, v in data.items()})

    if is_dataclass(data) and not isinstance(data, type):
        init_kwargs = {
            f.name: _move(getattr(data, f.name))
            for f in fields(data)
            if f.init
        }
        moved = type(data)(**init_kwargs)
        # Fields excluded from __init__ are copied over afterwards.
        for f in fields(data):
            if not f.init:
                setattr(moved, f.name, _move(getattr(data, f.name)))
        return moved

    if isinstance(data, _CopyableData):
        return data.to(device, *args, **kwargs)

    return data
194
+
195
+
196
+
197
def safe_makedirs(path: str):
    """Create *path* (and missing parents), tolerating an existing directory.

    Returns:
        ``False`` (after logging a warning) when *path* is empty, ``True`` once
        the directory exists.

    Raises:
        OSError: Re-raised after being logged if the directory cannot be created.
        Exception: Any other error is likewise logged and re-raised.
    """
    if not path:
        logging.warning("safe_makedirs called with an empty path. No operation performed.")
        return False

    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        logging.error(f"Failed to create directory '{path}'. Reason: {e}")
        raise
    except Exception as e:
        # Anything that is not an OSError is unexpected here.
        logging.error(f"An unexpected error occurred while creating directory '{path}'. Reason: {e}")
        raise
    else:
        return True
212
+
213
+
214
+
215
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed python's ``random``, numpy and torch for this process, plus the CUDA
    generators when CUDA is available.

    The effective seed is ``(seed_value + dist_rank) * max_epochs``, so it
    varies with the rank.
    """
    seed_value = (seed_value + dist_rank) * max_epochs
    logging.info(f"GPU SEED: {seed_value}")

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # for multi-GPU
229
+
230
+
231
+
232
+
233
def log_env_variables():
    """Log every environment variable as sorted ``KEY=value`` lines at INFO level.

    NOTE(review): the whole environment is written to the log, including any
    credentials stored in it.
    """
    dump = "".join(f"{key}={os.environ[key]}\n" for key in sorted(os.environ.keys()))
    logging.info("Logging ENV_VARIABLES")
    logging.info(dump)
241
+
242
+
243
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialised."""
    return dist.is_available() and dist.is_initialized()
249
+
250
+
251
+
252
class AverageMeter:
    """Tracks the latest value of a metric together with its weighted mean.

    Args:
        name (str): Name of the metric being tracked
        device (torch.device, optional): Stored on the instance; not used by the
            meter itself. Defaults to None.
        fmt (str): Format spec applied to both numbers in ``str()``. Defaults to ":f"
    """

    def __init__(self, name: str, device: Optional[torch.device] = None, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record *val* as observed *n* times.

        Raises:
            ValueError: If *n* is zero or negative.
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")

        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0.0

    def __str__(self) -> str:
        """Render as ``name: <latest> (<average>)`` using ``fmt`` for both numbers."""
        spec = self.fmt
        template = f"{{name}}: {{val{spec}}} ({{avg{spec}}})"
        return template.format(**vars(self))

    @property
    def value(self) -> float:
        """Most recently recorded value."""
        return self.val

    @property
    def average(self) -> float:
        """Weighted mean of everything recorded since the last reset."""
        return self.avg
296
+
297
+ #################
298
+
299
+
300
+ _UNITS = ('', ' K', ' M', ' B', ' T') # U+202F = thin-space for nicer look
301
+
302
def pretty_int(n: int) -> str:
    """Abbreviate a non-negative integer with a unit suffix from ``_UNITS``.

    Values below 1000 are returned in full with thousands separators. Larger
    values are scaled by a power of 1000, shown with at most one decimal
    (a trailing ".0" is dropped) and followed by the matching suffix. The
    largest suffix is reused for anything beyond its range.

    Raises:
        AssertionError: If *n* is negative.
    """
    assert n >= 0, 'pretty_int() expects a non-negative int'
    if n < 1_000:
        return f'{n:,}'
    max_exp = len(_UNITS) - 1                        # cap at the largest unit
    exp = min(int(math.log10(n) // 3), max_exp)      # group of 3 digits
    text = f'{n / 10 ** (3 * exp):.1f}'
    # Rounding to one decimal can lift the mantissa to 1000 (999_950 would
    # otherwise print as 1000 thousands); step up to the next unit instead.
    if float(text) >= 1000 and exp < max_exp:
        exp += 1
        text = f'{n / 10 ** (3 * exp):.1f}'
    return text.rstrip('0').rstrip('.') + _UNITS[exp]
311
+
312
+
313
def model_summary(model: torch.nn.Module,
                  *,
                  log_file = None,
                  prefix: str = '') -> None:
    """
    Print / save a compact parameter summary.

    Nothing happens on processes whose ``get_rank()`` is non-zero.

    Args
    ----
    model     : The PyTorch nn.Module to inspect.
    log_file  : Optional path – if given, `str(model)` is written to it and two
                per-parameter lists, ``trainable.txt`` and ``frozen.txt``, are
                written next to it (same directory).
    prefix    : Optional string printed at the beginning of every log line
                (handy when several models share the same stdout).
    """
    if get_rank():                      # only rank-0 prints
        return

    # --- counts -------------------------------------------------------------
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in model.parameters())
    frozen    = total - trainable

    print(prefix + '='*60)
    print(prefix + f'Model type : {model.__class__.__name__}')
    print(prefix + f'Total      : {pretty_int(total)} parameters')
    print(prefix + f'  trainable: {pretty_int(trainable)}')
    print(prefix + f'  frozen   : {pretty_int(frozen)}')
    print(prefix + '='*60)

    # --- optional file dump -------------------------------------------------
    if log_file is None:
        return

    log_file = Path(log_file)
    log_file.write_text(str(model))     # full architecture

    # Built once and shared by both dumps; rebuilding this mapping for every
    # parameter name made the dump quadratic in the number of parameters.
    named = dict(model.named_parameters())

    # two extra detailed lists
    def _dump(names: Iterable[str], fname: str):
        """Write a formatted per-parameter list to *log_file.with_name(fname)*."""
        with open(log_file.with_name(fname), 'w') as f:
            for n in names:
                p = named[n]
                shape = str(tuple(p.shape))
                f.write(f'{n:<60s} {shape:<20} {p.numel()}\n')

    _dump([n for n, p in named.items() if p.requires_grad], 'trainable.txt')
    _dump([n for n, p in named.items() if not p.requires_grad], 'frozen.txt')
362
+
363
+
364
def get_rank():
    """Return the distributed rank, or 0 when no process group is initialised."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
368
+
369
+
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/normalization.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import logging
9
+ from typing import Optional, Tuple
10
+ from vggt.utils.geometry import closed_form_inverse_se3
11
+ from vggt.train_utils.general import check_and_fix_inf_nan
12
+
13
+
14
def check_valid_tensor(input_tensor: Optional[torch.Tensor], name: str = "tensor") -> None:
    """
    Log a warning when a tensor holds NaN or Inf; the tensor is not modified.

    Args:
        input_tensor: The tensor to check; ``None`` is silently accepted
        name: Name of the tensor for logging purposes
    """
    if input_tensor is None:
        return
    if torch.isnan(input_tensor).any() or torch.isinf(input_tensor).any():
        logging.warning(f"NaN or Inf found in tensor: {name}")
25
+
26
+
27
def normalize_camera_extrinsics_and_points_batch(
    extrinsics: torch.Tensor,
    cam_points: Optional[torch.Tensor] = None,
    world_points: Optional[torch.Tensor] = None,
    depths: Optional[torch.Tensor] = None,
    scale_by_points: bool = True,
    point_masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Normalize camera extrinsics and corresponding 3D points.

    Every extrinsic is right-multiplied by the inverse of the first camera's
    extrinsic and the world points are mapped through the first camera's
    extrinsic, so the first camera becomes the reference frame. With
    ``scale_by_points`` the translations, points and depths are also divided by
    the masked mean norm of the transformed world points (one scale per batch
    element, clamped to [1e-6, 1e6]).

    Args:
        extrinsics: Camera extrinsic matrices of shape (B, S, 3, 4)
        cam_points: 3D points in camera coordinates of shape (B, S, H, W, 3)
        world_points: 3D points in world coordinates of shape (B, S, H, W, 3).
            Required when ``scale_by_points`` is true.
        depths: Depth maps of shape (B, S, H, W)
        scale_by_points: Whether to normalize the scale based on point distances
        point_masks: Masks for valid points of shape (B, S, H, W). Required
            when ``scale_by_points`` is true.

    Returns:
        Tuple containing:
            - Normalized camera extrinsics of shape (B, S, 3, 4)
            - Normalized camera points, or None if ``cam_points`` was None
            - Normalized world points, or None if ``world_points`` was None
            - Normalized depths, or None if ``depths`` was None
        Without ``scale_by_points``, ``cam_points`` and ``depths`` are returned
        unchanged and no inf/nan clean-up is applied.

    Raises:
        ValueError: If ``scale_by_points`` is true but ``world_points`` or
            ``point_masks`` is missing.
    """
    # The scale is derived from the masked world points, so both are mandatory
    # on that path; fail early with a readable message.
    if scale_by_points and (world_points is None or point_masks is None):
        raise ValueError(
            "scale_by_points=True requires both world_points and point_masks."
        )

    # Validate inputs
    check_valid_tensor(extrinsics, "extrinsics")
    check_valid_tensor(cam_points, "cam_points")
    check_valid_tensor(world_points, "world_points")
    check_valid_tensor(depths, "depths")

    B, S, _, _ = extrinsics.shape
    device = extrinsics.device

    # Convert extrinsics to homogeneous form: (B, N,4,4)
    extrinsics_homog = torch.cat(
        [
            extrinsics,
            torch.zeros((B, S, 1, 4), device=device),
        ],
        dim=-2,
    )
    extrinsics_homog[:, :, -1, -1] = 1.0

    # first_cam_extrinsic_inv, the inverse of the first camera's extrinsic matrix
    # which can be also viewed as the cam_to_world extrinsic matrix
    first_cam_extrinsic_inv = closed_form_inverse_se3(extrinsics_homog[:, 0])
    new_extrinsics = torch.matmul(extrinsics_homog, first_cam_extrinsic_inv.unsqueeze(1))  # (B,N,4,4)

    if world_points is not None:
        # since we are transforming the world points to the first camera's coordinate system
        # we directly use the cam_from_world extrinsic matrix of the first camera
        # instead of using the inverse of the first camera's extrinsic matrix
        R = extrinsics[:, 0, :3, :3]
        t = extrinsics[:, 0, :3, 3]
        new_world_points = (world_points @ R.transpose(-1, -2).unsqueeze(1).unsqueeze(2)) + t.unsqueeze(1).unsqueeze(2).unsqueeze(3)
    else:
        new_world_points = None

    if not scale_by_points:
        return new_extrinsics[:, :, :3], cam_points, new_world_points, depths

    # Masked mean distance of the transformed points from the origin, per batch element.
    dist = new_world_points.norm(dim=-1)
    dist_sum = (dist * point_masks).sum(dim=[1, 2, 3])
    valid_count = point_masks.sum(dim=[1, 2, 3])
    avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6)

    new_world_points = new_world_points / avg_scale.view(-1, 1, 1, 1, 1)
    new_extrinsics[:, :, :3, 3] = new_extrinsics[:, :, :3, 3] / avg_scale.view(-1, 1, 1)
    # Division already yields new tensors, so the inputs are left untouched.
    # Missing inputs simply stay None instead of crashing on `.clone()`.
    new_depths = depths / avg_scale.view(-1, 1, 1, 1) if depths is not None else None
    if cam_points is not None:
        new_cam_points = cam_points / avg_scale.view(-1, 1, 1, 1, 1)
    else:
        new_cam_points = None

    new_extrinsics = new_extrinsics[:, :, :3]  # 4x4 -> 3x4
    new_extrinsics = check_and_fix_inf_nan(new_extrinsics, "new_extrinsics", hard_max=None)
    if new_cam_points is not None:
        new_cam_points = check_and_fix_inf_nan(new_cam_points, "new_cam_points", hard_max=None)
    new_world_points = check_and_fix_inf_nan(new_world_points, "new_world_points", hard_max=None)
    if new_depths is not None:
        new_depths = check_and_fix_inf_nan(new_depths, "new_depths", hard_max=None)

    return new_extrinsics, new_cam_points, new_world_points, new_depths
126
+
127
+
128
+
129
+
130
+
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/train_utils/normalization_v37.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import logging
9
+ from typing import Optional, Tuple
10
+ from vggt.utils.geometry import closed_form_inverse_se3
11
+ from vggt.train_utils.general import check_and_fix_inf_nan
12
+
13
+
14
def check_valid_tensor(input_tensor: Optional[torch.Tensor], name: str = "tensor") -> None:
    """
    Log a warning when a tensor holds NaN or Inf; the tensor is not modified.

    Args:
        input_tensor: The tensor to check; ``None`` is silently accepted
        name: Name of the tensor for logging purposes
    """
    if input_tensor is None:
        return
    if torch.isnan(input_tensor).any() or torch.isinf(input_tensor).any():
        logging.warning(f"NaN or Inf found in tensor: {name}")
25
+
26
+
27
def normalize_camera_extrinsics_and_points_batch(
    extrinsics: torch.Tensor,
    cam_points: Optional[torch.Tensor] = None,
    world_points: Optional[torch.Tensor] = None,
    depths: Optional[torch.Tensor] = None,
    scale_by_points: bool = True,
    point_masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Normalize camera extrinsics and corresponding 3D points.

    Every extrinsic is right-multiplied by the inverse of the first camera's
    extrinsic and the world points are mapped through the first camera's
    extrinsic, so the first camera becomes the reference frame. With
    ``scale_by_points`` the translations, points and depths are also divided by
    the masked mean norm of the transformed world points (one scale per batch
    element, clamped to [1e-6, 1e6]).

    Args:
        extrinsics: Camera extrinsic matrices of shape (B, S, 3, 4)
        cam_points: 3D points in camera coordinates of shape (B, S, H, W, 3)
        world_points: 3D points in world coordinates of shape (B, S, H, W, 3).
            Required when ``scale_by_points`` is true.
        depths: Depth maps of shape (B, S, H, W)
        scale_by_points: Whether to normalize the scale based on point distances
        point_masks: Masks for valid points of shape (B, S, H, W). Required
            when ``scale_by_points`` is true.

    Returns:
        Tuple containing:
            - Normalized camera extrinsics of shape (B, S, 3, 4)
            - Normalized camera points, or None if ``cam_points`` was None
            - Normalized world points, or None if ``world_points`` was None
            - Normalized depths, or None if ``depths`` was None
        Without ``scale_by_points``, ``cam_points`` and ``depths`` are returned
        unchanged and no inf/nan clean-up is applied.

    Raises:
        ValueError: If ``scale_by_points`` is true but ``world_points`` or
            ``point_masks`` is missing.
    """
    # The scale is derived from the masked world points, so both are mandatory
    # on that path; fail early with a readable message.
    if scale_by_points and (world_points is None or point_masks is None):
        raise ValueError(
            "scale_by_points=True requires both world_points and point_masks."
        )

    # Validate inputs
    check_valid_tensor(extrinsics, "extrinsics")
    check_valid_tensor(cam_points, "cam_points")
    check_valid_tensor(world_points, "world_points")
    check_valid_tensor(depths, "depths")

    B, S, _, _ = extrinsics.shape
    device = extrinsics.device

    # Convert extrinsics to homogeneous form: (B, N,4,4)
    extrinsics_homog = torch.cat(
        [
            extrinsics,
            torch.zeros((B, S, 1, 4), device=device),
        ],
        dim=-2,
    )
    extrinsics_homog[:, :, -1, -1] = 1.0

    # first_cam_extrinsic_inv, the inverse of the first camera's extrinsic matrix
    # which can be also viewed as the cam_to_world extrinsic matrix
    first_cam_extrinsic_inv = closed_form_inverse_se3(extrinsics_homog[:, 0])
    new_extrinsics = torch.matmul(extrinsics_homog, first_cam_extrinsic_inv.unsqueeze(1))  # (B,N,4,4)

    if world_points is not None:
        # since we are transforming the world points to the first camera's coordinate system
        # we directly use the cam_from_world extrinsic matrix of the first camera
        # instead of using the inverse of the first camera's extrinsic matrix
        R = extrinsics[:, 0, :3, :3]
        t = extrinsics[:, 0, :3, 3]
        new_world_points = (world_points @ R.transpose(-1, -2).unsqueeze(1).unsqueeze(2)) + t.unsqueeze(1).unsqueeze(2).unsqueeze(3)
    else:
        new_world_points = None

    if not scale_by_points:
        return new_extrinsics[:, :, :3], cam_points, new_world_points, depths

    # Masked mean distance of the transformed points from the origin, per batch element.
    dist = new_world_points.norm(dim=-1)
    dist_sum = (dist * point_masks).sum(dim=[1, 2, 3])
    valid_count = point_masks.sum(dim=[1, 2, 3])
    avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6)

    new_world_points = new_world_points / avg_scale.view(-1, 1, 1, 1, 1)
    new_extrinsics[:, :, :3, 3] = new_extrinsics[:, :, :3, 3] / avg_scale.view(-1, 1, 1)
    # Division already yields new tensors, so the inputs are left untouched.
    # Missing inputs simply stay None instead of crashing on `.clone()`.
    new_depths = depths / avg_scale.view(-1, 1, 1, 1) if depths is not None else None
    if cam_points is not None:
        new_cam_points = cam_points / avg_scale.view(-1, 1, 1, 1, 1)
    else:
        new_cam_points = None

    new_extrinsics = new_extrinsics[:, :, :3]  # 4x4 -> 3x4
    new_extrinsics = check_and_fix_inf_nan(new_extrinsics, "new_extrinsics", hard_max=None)
    if new_cam_points is not None:
        new_cam_points = check_and_fix_inf_nan(new_cam_points, "new_cam_points", hard_max=None)
    new_world_points = check_and_fix_inf_nan(new_world_points, "new_world_points", hard_max=None)
    if new_depths is not None:
        new_depths = check_and_fix_inf_nan(new_depths, "new_depths", hard_max=None)

    return new_extrinsics, new_cam_points, new_world_points, new_depths
126
+
127
+
128
+
129
+
130
+
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/geometry.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Torch tensors are accepted for any argument and converted to NumPy on the
    CPU first.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy()
    if isinstance(extrinsics_cam, torch.Tensor):
        extrinsics_cam = extrinsics_cam.cpu().numpy()
    if isinstance(intrinsics_cam, torch.Tensor):
        intrinsics_cam = intrinsics_cam.cpu().numpy()

    # Drop the trailing channel axis only when there is one: squeezing the last
    # axis of an (S, H, W) batch raises unless W happens to be 1.
    if depth_map.ndim == 4:
        depth_map = depth_map.squeeze(-1)

    world_points_list = [
        depth_to_world_coords_points(
            depth_map[frame_idx], extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
        )[0]
        for frame_idx in range(depth_map.shape[0])
    ]
    return np.stack(world_points_list, axis=0)
42
+
43
+
44
def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a depth map to world coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        eps: Depths at or below this value are marked invalid in the mask.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3),
        camera coordinates (H, W, 3) and valid depth mask (H, W). All three are
        None when ``depth_map`` is None.
    """
    if depth_map is None:
        return None, None, None

    # Pixels whose depth exceeds eps count as valid.
    point_mask = depth_map > eps

    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # The inverse helper is batched, so add a leading axis and strip it again.
    cam_to_world = closed_form_inverse_se3(extrinsic[None])[0]
    rotation, translation = cam_to_world[:3, :3], cam_to_world[:3, 3]

    # Row-vector form of R @ p + t for every pixel: HxWx3, 3x3 -> HxWx3
    world_coords_points = np.dot(cam_coords_points, rotation.T) + translation

    return world_coords_points, cam_coords_points, point_mask
82
+
83
+
84
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
    """
    Convert a depth map to camera coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3) with zero skew.

    Returns:
        np.ndarray: Camera coordinates of shape (H, W, 3), dtype float32.

    Raises:
        AssertionError: If ``intrinsic`` is not 3x3 or has non-zero skew terms.
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Pixel grid: u runs along the width (columns), v along the height (rows)
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Unproject to camera coordinates
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    # Stack to form camera coordinates
    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    return cam_coords
115
+
116
+
117
def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    The inverse is assembled in closed form as ``[R^T | -R^T t]`` over a 4x4
    identity, so no general matrix inversion is performed.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Nx4x4 inverted SE3 matrices. For tensor input the result has the dtype
        and device of `R`; for NumPy input it is a float64 array.

    Raises:
        ValueError: If the last two dimensions of `se3` are neither (4, 4) nor (3, 4).
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if tuple(se3.shape[-2:]) not in ((4, 4), (3, 4)):
        raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    if is_numpy:
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t
        top_right = -np.matmul(R_transposed, T)
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        # Allocate the identity directly with the target dtype and device
        # rather than building it on the CPU and converting afterwards.
        inverted_matrix = torch.eye(4, dtype=R.dtype, device=R.device).repeat(len(R), 1, 1)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/load_fn.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
def _pad_center_white(img, target_height, target_width):
    """Pad a (C, H, W) tensor symmetrically with white (1.0) up to the target size."""
    h_padding = target_height - img.shape[1]
    w_padding = target_width - img.shape[2]
    if h_padding <= 0 and w_padding <= 0:
        return img

    pad_top = h_padding // 2
    pad_bottom = h_padding - pad_top
    pad_left = w_padding // 2
    pad_right = w_padding - pad_left
    return torch.nn.functional.pad(
        img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
    )


def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.

    Each image is opened, composited onto a white background if it has an alpha
    channel, converted to RGB, resized with bicubic resampling and converted to
    a float tensor in [0, 1]. The working resolution is 224 px (16 * 14).

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Resizes every image directly to 224x224.
              NOTE: the aspect ratio is NOT preserved in this mode; the image
              is stretched/squashed to a square.
            - "pad": Preserves the aspect ratio by making the largest dimension
              224px, rounding the smaller one to a multiple of 14, and padding
              it with white to reach a square shape (224x224).

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Padding always uses white (value=1.0)
        - If the processed images end up with different shapes, a warning is
          printed and all of them are padded to the largest height/width
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 224

    # First process all images and collect their shapes
    for image_path in image_path_list:
        # Use a context manager so the underlying file handle is released as
        # soon as the pixel data has been copied into the RGB image.
        with Image.open(image_path) as source:
            if source.mode == "RGBA":
                # Blend transparent areas onto a white background
                background = Image.new("RGBA", source.size, (255, 255, 255, 255))
                img = Image.alpha_composite(background, source)
            else:
                img = source
            # convert() returns a new image that does not depend on the open file
            img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Largest dimension becomes target_size; the other one keeps the
            # aspect ratio, rounded to a multiple of 14 and never below 14 so
            # that extreme aspect ratios cannot produce a zero-sized resize.
            if width >= height:
                new_width = target_size
                new_height = max(14, round(height * (new_width / width) / 14) * 14)
            else:
                new_height = target_size
                new_width = max(14, round(width * (new_height / height) / 14) * 14)
        else:  # mode == "crop"
            # Resize straight to a square; no aspect-ratio preservation and
            # therefore no center crop is required afterwards.
            new_width = target_size
            new_height = target_size

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            img = _pad_center_white(img, target_size, target_size)

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Safety net: if the images still differ in shape, pad all of them to the
    # largest height/width so they can be stacked.
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)
        images = [_pad_center_white(img, max_height, max_width) for img in images]

    # torch.stack always adds the batch dimension, so a single image already
    # comes out as (1, C, H, W).
    return torch.stack(images)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/pose_enc.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from .rotation import quat_to_mat, mat_to_quat
9
+
10
+
11
def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    Args:
        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
            where B is batch size and S is sequence length.
            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
            Defined in pixels, with format:
            [[fx, 0, cx],
             [0, fy, cy],
             [0,  0,  1]]
            where fx, fy are focal lengths and (cx, cy) is the principal point
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for computing field of view values. For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding to use. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: Encoded camera pose parameters with shape BxSx9 (float32).
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D), ordered (fov_h, fov_w)

    Raises:
        NotImplementedError: If `pose_encoding_type` is not "absT_quaR_FoV".
        ValueError: If `image_size_hw` is not provided.
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError(f"Unsupported pose_encoding_type: {pose_encoding_type!r}")

    # Fail early with a clear message instead of a TypeError from unpacking None
    if image_size_hw is None:
        raise ValueError("image_size_hw=(height, width) is required to compute the field of view")

    # Index from the right so the slicing matches the `...` indexing used for
    # the intrinsics below.
    R = extrinsics[..., :3, :3]  # BxSx3x3
    T = extrinsics[..., :3, 3]  # BxSx3

    quat = mat_to_quat(R)

    # Note the order of h and w here
    H, W = image_size_hw
    # The vertical FoV uses fy (intrinsics[1, 1]); the horizontal FoV uses fx
    fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
    fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])

    pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    return pose_encoding
63
+
64
+
65
def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
    build_intrinsics=True,
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    Args:
        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
            where B is batch size and S is sequence length.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D), ordered (fov_h, fov_w)
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for reconstructing intrinsics from field of view values
            (i.e. whenever `build_intrinsics` is True). For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding used. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
            If False, only extrinsics are returned and intrinsics will be None.

    Returns:
        tuple: (extrinsics, intrinsics)
            - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
              in the format [R|t] where R is a 3x3 rotation matrix and t is
              a 3x1 translation vector.
            - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
              or None if build_intrinsics is False. Defined in pixels, with format:
              [[fx, 0, cx],
               [0, fy, cy],
               [0,  0,  1]]
              where fx, fy are focal lengths and (cx, cy) is the principal point,
              set to the center of the image (W/2, H/2).

    Raises:
        NotImplementedError: If `pose_encoding_type` is not "absT_quaR_FoV".
        ValueError: If `build_intrinsics` is True but `image_size_hw` is not provided.
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError(f"Unsupported pose_encoding_type: {pose_encoding_type!r}")

    # Fail early with a clear message instead of a TypeError from unpacking None
    if build_intrinsics and image_size_hw is None:
        raise ValueError("image_size_hw=(height, width) is required when build_intrinsics=True")

    T = pose_encoding[..., :3]
    quat = pose_encoding[..., 3:7]
    fov_h = pose_encoding[..., 7]
    fov_w = pose_encoding[..., 8]

    R = quat_to_mat(quat)
    # Append the translation as a fourth column: [R | t]
    extrinsics = torch.cat([R, T[..., None]], dim=-1)

    intrinsics = None
    if build_intrinsics:
        H, W = image_size_hw
        # Invert fov = 2 * atan((size / 2) / f)
        fy = (H / 2.0) / torch.tan(fov_h / 2.0)
        fx = (W / 2.0) / torch.tan(fov_w / 2.0)
        # All leading dimensions are kept, consistent with the `...` indexing
        # above; for BxSx9 input this is the same BxSx3x3 buffer as before.
        intrinsics = torch.zeros(pose_encoding.shape[:-1] + (3, 3), device=pose_encoding.device)
        intrinsics[..., 0, 0] = fx
        intrinsics[..., 1, 1] = fy
        intrinsics[..., 0, 2] = W / 2
        intrinsics[..., 1, 2] = H / 2
        intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1

    return extrinsics, intrinsics
outdoor_v48_4gpu_v2/code/05_02-14:21:58/vggt/utils/visual_track.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+
12
+
13
def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y by (W - 1) and (H - 1).
    2) Combine them into a single scalar c as the mean of the two.
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., some other function of (x, y).

    Args:
        x, y: Pixel coordinates of the point.
        W, H: Image width and height in pixels used for normalization.
        cmap_name: Name of a registered matplotlib colormap.

    Returns:
        Tuple (r, g, b), taken from the first three components of the
        colormap output, in RGB order.
    """
    import matplotlib

    # max(..., 1) avoids a division by zero for 1-pixel-wide/high images
    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in
    # 3.9; prefer the colormap registry and only fall back to the old API on
    # matplotlib versions that do not provide the registry yet.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        cmap = registry[cmap_name]
    else:
        import matplotlib.cm

        cmap = matplotlib.cm.get_cmap(cmap_name)

    # cmap(c) -> (r,g,b,a)
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # RGB order
35
+
36
+
37
def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
    """
    Assign one RGB color per track, based on where the track first becomes visible.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, every track is treated as
            visible in every frame.
        image_width, image_height: passed to color_from_xy for normalizing (x, y).
        cmap_name: matplotlib colormap name (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        np.ndarray of shape (N, 3) and dtype uint8; row i holds (R,G,B) for
        track i. Tracks that are never visible stay black (0, 0, 0).
    """
    num_frames, num_tracks, _ = tracks_b.shape

    if vis_mask_b is None:
        # No mask given: everything counts as visible
        vis_mask_b = torch.ones(num_frames, num_tracks, dtype=torch.bool, device=tracks_b.device)

    # Zero-initialised, so never-visible tracks are already black
    palette = np.zeros((num_tracks, 3), dtype=np.uint8)

    for track_idx in range(num_tracks):
        frames_seen = torch.where(vis_mask_b[:, track_idx])[0]
        if len(frames_seen) == 0:
            palette[track_idx] = (0, 0, 0)
            continue

        # Position of the track in its first visible frame
        earliest = int(frames_seen[0].item())
        px, py = tracks_b[earliest, track_idx].tolist()

        # Colormap lookup gives floats; scale and truncate them to 0..255 integers
        unit_rgb = color_from_xy(px, py, W=image_width, H=image_height, cmap_name=cmap_name)
        palette[track_idx] = tuple(int(channel * 255) for channel in unit_rgb)

    return palette
78
+
79
+
80
def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # New parameter for grid layout
    save_grid=True,  # Flag to control whether to save the grid image
):
    """
    Draw tracked points on every frame and save the results as PNG files.

    Each track's color is determined by its (x,y) position in the first
    visible frame (or frame 0 if no mask is given). Every frame is written to
    `out_dir/frame_XXXX.png`; optionally all frames are also tiled into
    `out_dir/tracks_grid.png` with `frames_per_row` frames per row, padding an
    incomplete last row with black.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
            If `tracks` is 4-D, dimension 0 is squeezed from `tracks`,
            `images` and `track_vis_mask`.
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations (created if missing).
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).

    Raises:
        ValueError: If `save_grid` is True and `frames_per_row` is smaller than 1.
    """
    if save_grid and frames_per_row < 1:
        raise ValueError(f"frames_per_row must be >= 1, got {frames_per_row}")

    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Move to CPU
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from images shape
    if image_format == "CHW":
        # e.g. images[s].shape = (3, H, W)
        H, W = images.shape[2], images.shape[3]
    else:
        # e.g. images[s].shape = (H, W, 3)
        H, W = images.shape[1], images.shape[2]

    # Pre-compute the color for each track i based on first visible position
    track_colors_rgb = get_track_colors_by_position(
        tracks,  # shape (S, N, 2)
        vis_mask_b=track_vis_mask,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    # Drawn frames are kept in BGR, the channel order cv2.imwrite expects, so
    # no conversion back and forth is needed before writing.
    frames_bgr = []

    for s in range(S):
        # shape => either (3, H, W) or (H, W, 3)
        img = images[s]

        # Convert to (H, W, 3)
        if image_format == "CHW":
            img = img.permute(1, 2, 0)  # (H, W, 3)
        # else "HWC", do nothing

        img = img.numpy().astype(np.float32)

        # Scale to [0,255] if needed
        if normalize_mode == "[0,1]":
            img = np.clip(img, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            img = (img + 1.0) * 0.5 * 255.0
            img = np.clip(img, 0, 255.0)
        # else no normalization

        # Convert to uint8
        img = img.astype(np.uint8)

        # For drawing in OpenCV, convert to BGR
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Indices of the tracks that are visible in this frame
        if track_vis_mask is not None:
            valid_indices = torch.where(track_vis_mask[s])[0].tolist()
        else:
            valid_indices = range(N)

        cur_tracks_np = tracks[s].numpy()  # shape (N, 2)
        for i in valid_indices:
            x, y = cur_tracks_np[i]
            # NaN/inf coordinates cannot be converted to pixel positions; skip
            # them instead of failing in int(round(...)).
            if not (np.isfinite(x) and np.isfinite(y)):
                continue
            pt = (int(round(x)), int(round(y)))

            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
            R, G, B = track_colors_rgb[i]
            color_bgr = (int(B), int(G), int(R))
            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

        # Save individual frame
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        cv2.imwrite(frame_path, img_bgr)

        frames_bgr.append(img_bgr)

    # Only create and save the grid image if save_grid is True (and there is
    # at least one frame to put into it)
    if save_grid and frames_bgr:
        grid_rows = []
        for start_idx in range(0, S, frames_per_row):
            end_idx = min(start_idx + frames_per_row, S)

            # Concatenate this row horizontally
            row_img = np.concatenate(frames_bgr[start_idx:end_idx], axis=1)

            # If this row has fewer than frames_per_row images, pad with black
            missing = frames_per_row - (end_idx - start_idx)
            if missing > 0:
                padding = np.zeros((H, missing * W, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            grid_rows.append(row_img)

        # Stack all rows vertically in one go
        grid_img_bgr = np.concatenate(grid_rows, axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        cv2.imwrite(out_path, grid_img_bgr)
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
outdoor_v48_4gpu_v2/mytrain.log ADDED
The diff for this file is too large to render. See raw diff