bulatko commited on
Commit
8c8811e
·
1 Parent(s): edd96e4

fix: repair broken torch.nn imports in vggt modules

Browse files
vggt/heads/track_head.py CHANGED
@@ -4,7 +4,8 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- import torch # Attention computation.nn as nn
 
8
  from .dpt_head import DPTHead
9
  from .track_modules.base_track_predictor import BaseTrackerPredictor
10
 
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ import torch
8
+ import torch.nn as nn
9
  from .dpt_head import DPTHead
10
  from .track_modules.base_track_predictor import BaseTrackerPredictor
11
 
vggt/layers/attention.py CHANGED
@@ -13,7 +13,8 @@ import warnings
13
 
14
  from torch import Tensor
15
  from torch import nn
16
- import torch # Self-attention layers.nn.functional as F
 
17
 
18
  XFORMERS_AVAILABLE = False
19
 
 
13
 
14
  from torch import Tensor
15
  from torch import nn
16
+ import torch
17
+ import torch.nn.functional as F
18
 
19
  XFORMERS_AVAILABLE = False
20
 
vggt/models/vggt.py CHANGED
@@ -4,8 +4,8 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- import torch # Core tensor operations
8
- import torch # Core tensor operations.nn as nn
9
  from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
 
11
  from vggt.models.aggregator import Aggregator
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ import torch
8
+ import torch.nn as nn
9
  from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
 
11
  from vggt.models.aggregator import Aggregator
vggt/utils/rotation.py CHANGED
@@ -6,9 +6,9 @@
6
 
7
  # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
 
9
- import torch # Rotation matrices
 
10
  import numpy as np
11
- import torch # Rotation matrices.nn.functional as F
12
 
13
 
14
  def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
 
6
 
7
  # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
 
9
+ import torch
10
+ import torch.nn.functional as F
11
  import numpy as np
 
12
 
13
 
14
  def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: