bfshi commited on
Commit
c0c592e
·
1 Parent(s): 7e3b296
Files changed (38) hide show
  1. Dockerfile +35 -0
  2. README.md +4 -6
  3. app.py +18 -8
  4. autogaze/__init__.py +1 -0
  5. autogaze/__pycache__/__init__.cpython-310.pyc +0 -0
  6. autogaze/__pycache__/utils.cpython-310.pyc +0 -0
  7. autogaze/datasets/__init__.py +1 -0
  8. autogaze/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  9. autogaze/datasets/__pycache__/video_utils.cpython-310.pyc +0 -0
  10. autogaze/datasets/video_utils.py +133 -0
  11. autogaze/models/__init__.py +1 -0
  12. autogaze/models/__pycache__/__init__.cpython-310.pyc +0 -0
  13. autogaze/models/autogaze/__init__.py +17 -0
  14. autogaze/models/autogaze/__pycache__/__init__.cpython-310.pyc +0 -0
  15. autogaze/models/autogaze/__pycache__/autogaze.cpython-310.pyc +0 -0
  16. autogaze/models/autogaze/__pycache__/configuration_autogaze.cpython-310.pyc +0 -0
  17. autogaze/models/autogaze/__pycache__/modeling_autogaze.cpython-310.pyc +0 -0
  18. autogaze/models/autogaze/__pycache__/modeling_llama_multi_token_pred.cpython-310.pyc +0 -0
  19. autogaze/models/autogaze/autogaze.py +432 -0
  20. autogaze/models/autogaze/configuration_autogaze.py +326 -0
  21. autogaze/models/autogaze/modeling_autogaze.py +431 -0
  22. autogaze/models/autogaze/modeling_llama_multi_token_pred.py +471 -0
  23. autogaze/tasks/__init__.py +1 -0
  24. autogaze/tasks/__pycache__/__init__.cpython-310.pyc +0 -0
  25. autogaze/tasks/video_mae_reconstruction/__init__.py +1 -0
  26. autogaze/tasks/video_mae_reconstruction/__pycache__/__init__.cpython-310.pyc +0 -0
  27. autogaze/tasks/video_mae_reconstruction/__pycache__/configuration_video_mae.cpython-310.pyc +0 -0
  28. autogaze/tasks/video_mae_reconstruction/__pycache__/modeling_video_mae.cpython-310.pyc +0 -0
  29. autogaze/tasks/video_mae_reconstruction/__pycache__/task_video_mae_reconstruction.cpython-310.pyc +0 -0
  30. autogaze/tasks/video_mae_reconstruction/__pycache__/visualize_video_mae_reconstruction.cpython-310.pyc +0 -0
  31. autogaze/tasks/video_mae_reconstruction/configuration_video_mae.py +159 -0
  32. autogaze/tasks/video_mae_reconstruction/modeling_video_mae.py +1412 -0
  33. autogaze/tasks/video_mae_reconstruction/task_video_mae_reconstruction.py +182 -0
  34. autogaze/tasks/video_mae_reconstruction/visualize_video_mae_reconstruction.py +134 -0
  35. autogaze/utils.py +205 -0
  36. demo_utils.py +18 -16
  37. packages.txt +16 -0
  38. requirements.txt +2 -2
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the AutoGaze Gradio demo.
FROM python:3.11-slim

WORKDIR /app

# System dependencies:
#  - git / git-lfs: pulling model weights stored with LFS
#  - ffmpeg + libav*/libsw* dev packages: required to build/run PyAV video decoding
#  - libsm6 / libxext6 / libgl1: runtime libs for imaging stacks (e.g. OpenCV)
#  - cmake / pkg-config / rsync: build tooling for source wheels
RUN apt-get update && apt-get install -y \
    git \
    git-lfs \
    ffmpeg \
    pkg-config \
    libavcodec-dev \
    libavformat-dev \
    libavutil-dev \
    libswscale-dev \
    libswresample-dev \
    libavdevice-dev \
    libavfilter-dev \
    libsm6 \
    libxext6 \
    cmake \
    rsync \
    libgl1 \
    && rm -rf /var/lib/apt/lists/* \
    && git lfs install

# Copy requirements first so the pip layer is cached across code-only changes.
COPY requirements.txt .

RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Gradio's default port.
EXPOSE 7860

# Bind to all interfaces so the app is reachable from outside the container.
ENV GRADIO_SERVER_NAME="0.0.0.0"

CMD ["python", "app.py"]
README.md CHANGED
@@ -1,14 +1,12 @@
1
  ---
2
  title: AutoGaze
3
- emoji: 📉
4
  colorFrom: blue
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
- short_description: AutoGaze can remove redundant patches in any video.
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: AutoGaze
3
+ emoji: 👀
4
  colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 6.5.1
 
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import tempfile
3
  import os
@@ -8,13 +16,6 @@ import av
8
  from PIL import Image
9
  import numpy as np
10
 
11
- try:
12
- import spaces
13
- ZEROGPU_AVAILABLE = True
14
- except ImportError:
15
- ZEROGPU_AVAILABLE = False
16
- print("Warning: spaces module not available. Running without ZeroGPU support.")
17
-
18
  model_cache = {}
19
 
20
  def get_model(device):
@@ -22,7 +23,16 @@ def get_model(device):
22
  model_cache[device] = load_model(device=device)
23
  return model_cache[device]
24
 
25
- device = "cuda" if torch.cuda.is_available() or ZEROGPU_AVAILABLE else "cpu"
 
 
 
 
 
 
 
 
 
26
 
27
  def cleanup_gpu():
28
  """Clean up GPU memory."""
 
1
+ # IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
2
+ try:
3
+ import spaces
4
+ ZEROGPU_AVAILABLE = True
5
+ except ImportError:
6
+ ZEROGPU_AVAILABLE = False
7
+ print("Warning: spaces module not available. Running without ZeroGPU support.")
8
+
9
  import gradio as gr
10
  import tempfile
11
  import os
 
16
  from PIL import Image
17
  import numpy as np
18
 
 
 
 
 
 
 
 
19
  model_cache = {}
20
 
21
  def get_model(device):
 
23
  model_cache[device] = load_model(device=device)
24
  return model_cache[device]
25
 
26
+ # Determine device: use CUDA if available locally or if ZeroGPU will provide it
27
+ if ZEROGPU_AVAILABLE:
28
+ device = "cuda" # ZeroGPU will provide GPU
29
+ print("Using ZeroGPU (CUDA device will be allocated on demand)")
30
+ elif torch.cuda.is_available():
31
+ device = "cuda"
32
+ print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
33
+ else:
34
+ device = "cpu"
35
+ print("No GPU available, using CPU")
36
 
37
  def cleanup_gpu():
38
  """Clean up GPU memory."""
autogaze/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """AutoGaze package for video patch reduction."""
autogaze/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
autogaze/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.09 kB). View file
 
autogaze/datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """AutoGaze datasets and utilities."""
autogaze/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (196 Bytes). View file
 
autogaze/datasets/__pycache__/video_utils.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
autogaze/datasets/video_utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Common utilities for video loading and processing."""
2
+
3
+ import av
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
def get_relative_video_path(path):
    """
    Return the last three components of *path* as the video's relative path.

    Args:
        path (str): File path; forward or backward slashes are accepted.
    Returns:
        str: The final three components joined by "/", or the original path
            unchanged when it has fewer than three components.
    """
    components = path.replace("\\", "/").split("/")
    if len(components) < 3:
        return path
    return "/".join(components[-3:])
18
+
19
+
20
def read_video_pyav(container, indices):
    """
    Decode selected frames of a video with the PyAV decoder.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): Frame indices to decode (assumed ascending).
    Returns:
        result (np.ndarray): Decoded frames of shape (num_frames, height, width, 3).
    """
    # Build a set once for O(1) membership tests; the original scanned the
    # whole `indices` sequence for every decoded frame.
    wanted = {int(i) for i in indices}
    start_index = indices[0]
    end_index = indices[-1]
    container.seek(0)
    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break  # every requested frame has been seen; stop decoding early
        if i >= start_index and i in wanted:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
39
+
40
+
41
def sample_frame_indices(clip_len, frame_sample_rate, seg_len, random_sample_frame=False):
    """
    Sample `clip_len` frame indices from a video.

    Args:
        clip_len (`int`): Total number of frames to sample.
        frame_sample_rate (`int`): Sample every n-th frame.
        seg_len (`int`): Maximum allowed index of sample's last frame.
        random_sample_frame (`bool`): If True, place the sampling window at a
            random position in the video instead of at the start.
    Returns:
        indices (`List[int]`): List of sampled frame indices.
    """
    span = int(clip_len * frame_sample_rate)
    if seg_len <= span:
        # Video too short: take the leading frames and edge-pad (repeat the
        # last index) up to clip_len.
        available = np.arange(min(clip_len, seg_len))
        pad_amount = max(0, clip_len - len(available))
        return np.pad(available, (0, pad_amount), mode="edge").astype(np.int64)

    if random_sample_frame:
        end_idx = np.random.randint(span, seg_len)
        start_idx = end_idx - span
    else:
        start_idx, end_idx = 0, span

    # Evenly space clip_len samples in [start_idx, end_idx), clipping the
    # final endpoint back inside the window.
    sampled = np.linspace(start_idx, end_idx, num=clip_len)
    return np.clip(sampled, start_idx, end_idx - 1).astype(np.int64)
66
+
67
+
68
def process_video_frames(video, clip_len):
    """
    Pad or truncate a video so it has exactly `clip_len` frames.

    Args:
        video (np.ndarray): Video frames of shape (num_frames, H, W, 3)
        clip_len (int): Target number of frames
    Returns:
        video (np.ndarray): Processed video of shape (clip_len, H, W, 3)
    """
    num_frames = video.shape[0]
    if num_frames < clip_len:
        # Too short: repeat the final frame until the clip is full length.
        filler = np.repeat(video[-1:], clip_len - num_frames, axis=0)
        video = np.concatenate([video, filler], axis=0)
    elif num_frames > clip_len:
        # Too long: keep only the leading clip_len frames.
        video = video[:clip_len]

    assert video.shape[0] == clip_len, (
        f"Video has {video.shape[0]} frames, expected {clip_len}"
    )
    assert video.ndim == 4 and video.shape[-1] == 3, (
        f"Video shape is {video.shape}, expected (clip_len, H, W, 3)"
    )

    return video
97
+
98
+
99
def transform_video_for_pytorch(video, transform=None):
    """
    Apply an optional transform to video frames and convert to PyTorch layout.

    Args:
        video (np.ndarray): Video frames of shape (clip_len, H, W, 3)
        transform: Optional transform (e.g. an HF image processor) whose
            output exposes `.pixel_values`; if None, the raw frames are used.
    Returns:
        torch.Tensor: Transformed video of shape (clip_len, C, H, W).
            (The original docstring claimed np.ndarray; the function has
            always returned a tensor.)
    """
    if transform is not None:
        imgs = transform(list(video)).pixel_values
        # Some processors wrap the frame list in an extra python list.
        img = imgs[0] if isinstance(imgs[0], list) else imgs
        img = np.stack(img)
    else:
        img = video  # fallback: return raw video

    # Ensure output is (clip_len, C, H, W) for pytorch.
    # NOTE(review): this heuristic is ambiguous when H == 3 or W == 3 — the
    # channel axis is guessed from position; confirm inputs avoid that case.
    if img.shape[1] == 3 and img.shape[-1] != 3:
        pass  # already (clip_len, C, H, W)
    elif img.shape[-1] == 3:
        img = np.transpose(img, (0, 3, 1, 2))  # (T, H, W, 3) -> (T, 3, H, W)
    else:
        raise ValueError(f"Unexpected image shape after transform: {img.shape}")

    # The original also asserted img.shape[0] == clip_len with clip_len taken
    # from img.shape[0] itself — a tautology, dropped here.
    assert img.shape[1] == 3, f"Output img shape: {img.shape}"

    return torch.tensor(img)
autogaze/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """AutoGaze models."""
autogaze/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (178 Bytes). View file
 
autogaze/models/autogaze/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .autogaze import AutoGaze
2
+ from .configuration_autogaze import (
3
+ AutoGazeConfig,
4
+ GazeModelConfig,
5
+ VisionModelConfig,
6
+ ConnectorConfig,
7
+ GazeDecoderConfig,
8
+ )
9
+
10
+ __all__ = [
11
+ "AutoGaze",
12
+ "AutoGazeConfig",
13
+ "GazeModelConfig",
14
+ "VisionModelConfig",
15
+ "ConnectorConfig",
16
+ "GazeDecoderConfig",
17
+ ]
autogaze/models/autogaze/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (420 Bytes). View file
 
autogaze/models/autogaze/__pycache__/autogaze.cpython-310.pyc ADDED
Binary file (15.7 kB). View file
 
autogaze/models/autogaze/__pycache__/configuration_autogaze.cpython-310.pyc ADDED
Binary file (9.83 kB). View file
 
autogaze/models/autogaze/__pycache__/modeling_autogaze.cpython-310.pyc ADDED
Binary file (14.2 kB). View file
 
autogaze/models/autogaze/__pycache__/modeling_llama_multi_token_pred.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
autogaze/models/autogaze/autogaze.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from copy import deepcopy
3
+ import math
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from contextlib import nullcontext
8
+ from einops import rearrange
9
+ from omegaconf import OmegaConf
10
+
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from autogaze.utils import get_gazing_pos_from_gazing_mask
13
+ from .modeling_autogaze import AutoGazeModel
14
+ from .configuration_autogaze import AutoGazeConfig
15
+
16
+ class AutoGaze(PreTrainedModel):
17
+ config_class = AutoGazeConfig
18
+
19
    def __init__(self, config: AutoGazeConfig):
        """Build the AutoGaze wrapper around the gazing model.

        Args:
            config: AutoGazeConfig holding gazing-ratio sampling settings,
                the multi-scale token layout, normalization statistics, and
                the sub-config of the underlying AutoGazeModel.
        """
        super().__init__(config)

        self.config = config
        # Sampling configs for the per-video and per-frame gazing ratios.
        self.gazing_ratio_config = config.gazing_ratio_config
        self.gazing_ratio_each_frame_config = config.gazing_ratio_each_frame_config
        # `config.scales` is a '+'-separated string of scales, sorted ascending.
        self.scales = sorted([int(scale) for scale in str(config.scales).split('+')])
        self.num_vision_tokens_each_frame = config.num_vision_tokens_each_frame
        # Split the per-frame token budget across scales proportionally to scale**2.
        self.num_vision_tokens_each_scale_each_frame = [int(scale**2 / sum([scale**2 for scale in self.scales]) * self.num_vision_tokens_each_frame) for scale in self.scales]
        # One gazing step covers `temporal_patch_size` raw video frames.
        self.frame_sampling_rate = config.gaze_model_config.vision_model_config.temporal_patch_size
        self.image_mean = config.image_mean
        self.image_std = config.image_std
        self.attn_mode = config.attn_mode

        # Create the gazing model
        self.gazing_model = AutoGazeModel(config.gaze_model_config)

        # Task loss requirement (whether/how to condition gazing on a target task loss)
        self.has_task_loss_requirement_during_training = config.has_task_loss_requirement_during_training
        self.has_task_loss_requirement_during_inference = config.has_task_loss_requirement_during_inference
        self.task_loss_requirement_config = config.task_loss_requirement_config
40
+
41
+ def get_gazing_ratio(self, sync_across_ranks=True):
42
+ """
43
+ Sample the gazing ratio for the whole video according to the config.
44
+ """
45
+ sample_strategy = self.gazing_ratio_config['sample_strategy_during_training'] if self.training else self.gazing_ratio_config['sample_strategy_during_inference']
46
+ if sample_strategy == 'fixed':
47
+ ratio = self.gazing_ratio_config['fixed']['gazing_ratio']
48
+ elif sample_strategy == 'uniform':
49
+ ratio = random.uniform(self.gazing_ratio_config['uniform']['gazing_ratio_min'], self.gazing_ratio_config['uniform']['gazing_ratio_max'])
50
+ elif sample_strategy == 'exponential':
51
+ ratio = random.expovariate(self.gazing_ratio_config['exponential']['lambda'])
52
+ while ratio < self.gazing_ratio_config['exponential']['gazing_ratio_min'] or ratio > self.gazing_ratio_config['exponential']['gazing_ratio_max']:
53
+ ratio = random.expovariate(self.gazing_ratio_config['exponential']['lambda'])
54
+
55
+ if sync_across_ranks:
56
+ ratio = torch.tensor(ratio).cuda()
57
+ if torch.distributed.is_initialized():
58
+ torch.distributed.broadcast(ratio, src=0) # Make every rank use the same gazing ratio. Otherwise, each rank will have different gazing ratio, and the train/inference time is bounded by the slowest rank (with highest gazing ratio).
59
+ ratio = ratio.item()
60
+
61
+ return ratio
62
+
63
    def get_gazing_ratio_each_frame(self, inputs, video, gazing_ratio_mean, num_frames, temperature, use_cache):
        """
        Sample the gazing ratio for each frame according to the config.

        Strategies:
            'uniform'   — every frame gets `gazing_ratio_mean`.
            'dirichlet' — split num_frames * mean across frames via a
                          Dirichlet draw, clamped to [0, 1] per frame.
            'self'      — let the gazing model itself decide per-frame budgets
                          (batch size 1 only, no KV cache).

        Returns:
            torch.Tensor of shape (num_frames,) with one ratio per frame.
        """
        sample_strategy = self.gazing_ratio_each_frame_config['sample_strategy_during_training'] if self.training else self.gazing_ratio_each_frame_config['sample_strategy_during_inference']
        if sample_strategy == 'uniform':
            gazing_ratio_each_frame = torch.ones(num_frames) * gazing_ratio_mean
        elif sample_strategy == 'dirichlet':
            # Total budget to distribute across frames.
            gazing_ratio_agg = gazing_ratio_mean * num_frames
            alpha = self.gazing_ratio_each_frame_config['dirichlet']['alpha']
            if isinstance(alpha, str):
                # Comma-separated per-frame concentration parameters.
                alpha = torch.tensor([float(a) for a in alpha.split(',')])
                assert len(alpha) == num_frames, "The number of alpha values must be equal to the number of frames"
            gazing_ratio_each_frame = torch.distributions.dirichlet.Dirichlet(torch.ones(num_frames) * alpha).sample() * gazing_ratio_agg
            # A Dirichlet split can push a frame above 1; clamp back to a valid ratio.
            gazing_ratio_each_frame = gazing_ratio_each_frame.clamp(min=0, max=1)
        elif sample_strategy == 'self':
            assert use_cache == False, "using cache is not supported for self-predicted gazing ratio"

            # Only preserve one sample for each group
            if "group_size" in inputs:
                video = rearrange(video, '(g b) t c h w -> g b t c h w', g=inputs["group_size"])[0]

            assert video.shape[0] == 1, "Currently only batch_size=1 is supported because otherwise we need to support different gazing ratio constraints in the same batch in model.generate()"

            # Max gazing ratio for each frame; at least one gaze token per frame.
            max_gazing_ratio_each_frame = torch.ones(num_frames) * gazing_ratio_mean
            max_num_gaze_tokens_each_frame = (max_gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)

            # Sample task loss requirement
            task_loss_requirement = self.get_task_loss_requirement(video, force_sampling=True)

            # Sample the gazing (stochastic during training, greedy otherwise).
            with torch.no_grad():
                if self.training:
                    gazing_info = self.gazing_model.generate(
                        video,
                        max_gaze_tokens_each_frame=max_num_gaze_tokens_each_frame,
                        task_loss_requirement=task_loss_requirement,
                        do_sample=True,
                        temperature=temperature,
                    )
                else:
                    gazing_info = self.gazing_model.generate(
                        video,
                        max_gaze_tokens_each_frame=max_num_gaze_tokens_each_frame,
                        task_loss_requirement=task_loss_requirement,
                        do_sample=False,
                    )

            # Count the non-padded gazes the model actually emitted per frame.
            if_padded_gazing = gazing_info["if_padded_gazing"]
            num_gazing_each_frame = gazing_info["num_gazing_each_frame"]
            if_padded_gazing = if_padded_gazing.split(num_gazing_each_frame.tolist(), dim=1)
            num_non_padded_gazing_each_frame = torch.stack([(~if_padded_gazing[i]).sum(dim=-1) for i in range(len(if_padded_gazing))], dim=1)  # (B, num_frames)

            # Batch size is 1 (asserted above), so index the single sample.
            gazing_ratio_each_frame = num_non_padded_gazing_each_frame[0] / self.num_vision_tokens_each_frame
        else:
            raise NotImplementedError(f"Sample strategy {sample_strategy} not implemented.")

        return gazing_ratio_each_frame
122
+
123
+ def get_task_loss_requirement(self, video, sync_across_ranks=True, force_sampling=False):
124
+ """
125
+ Sample the task loss requirement for each frame according to the config.
126
+
127
+ inputs:
128
+ video: tensor of shape (B, T, C, H, W)
129
+ returns:
130
+ task_loss_requirement: tensor of shape (B, T // frame_sampling_rate), representing the task loss requirement for each frame of each video. None if no task loss requirement is used.
131
+ """
132
+ has_task_loss_requirement = self.has_task_loss_requirement_during_training if self.training else self.has_task_loss_requirement_during_inference
133
+ if not has_task_loss_requirement and not force_sampling:
134
+ return None
135
+
136
+ B, T = video.shape[:2]
137
+ sample_strategy = self.task_loss_requirement_config['sample_strategy_during_training'] if self.training else self.task_loss_requirement_config['sample_strategy_during_inference']
138
+ if sample_strategy == 'fixed':
139
+ task_loss_requirement = self.task_loss_requirement_config['fixed']['task_loss_requirement']
140
+ task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
141
+ elif sample_strategy == 'uniform':
142
+ task_loss_requirement_min = self.task_loss_requirement_config['uniform']['task_loss_requirement_min']
143
+ task_loss_requirement_max = self.task_loss_requirement_config['uniform']['task_loss_requirement_max']
144
+ task_loss_requirement = random.uniform(task_loss_requirement_min, task_loss_requirement_max)
145
+ task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
146
+ else:
147
+ raise NotImplementedError(f"Task loss requirement sample strategy {self.task_loss_requirement_config['sample_strategy']} not implemented")
148
+
149
+ if sync_across_ranks:
150
+ if torch.distributed.is_initialized():
151
+ torch.distributed.broadcast(task_loss_requirement, src=0) # Make every rank use the same gazing ratio. Otherwise, each rank will have different gazing ratio, and the train/inference time is bounded by the slowest rank (with highest gazing ratio).
152
+
153
+ return task_loss_requirement
154
+
155
    def get_mask_from_gazing_pos(self, video, gazing_pos, if_padded_gazing):
        """
        Create the video gazing mask from the gazing positions.

        inputs:
            video: B, T, C, H, W
            gazing_pos: B, N — flat token indices into the per-video token grid
            if_padded_gazing: B, N — True where a position is padding, not a real gaze
        returns:
            mask: list of B * T * N_each_scale (one binary mask tensor per scale)
        """
        B, T = video.shape[:2]
        # One extra "dump" slot at the end so padded positions can be scattered
        # somewhere harmless and then sliced off.
        mask = torch.zeros(B, self.num_vision_tokens_each_frame * (T // self.frame_sampling_rate) + 1, device=video.device)  # +1 for the padded gazing positions
        tmp_gazing_pos = gazing_pos.clone()
        tmp_gazing_pos[if_padded_gazing] = mask.shape[1] - 1  # Set the padded gazing positions to the last position
        # Scatter 1s at every (batch, position) pair in a single advanced-indexing write.
        mask[torch.arange(B)[:, None], tmp_gazing_pos] = 1
        mask = mask[:, :-1]  # Remove the last position (padded gazing positions)
        mask = mask.reshape(B, T // self.frame_sampling_rate, self.num_vision_tokens_each_frame)
        # Split the per-frame token axis into contiguous per-scale segments.
        mask = [mask[:, :, sum(self.num_vision_tokens_each_scale_each_frame[:i]):sum(self.num_vision_tokens_each_scale_each_frame[:i+1])] for i in range(len(self.scales))]  # list of B * T * N_each_scale

        return mask
176
+
177
    def input_res_adapt(self, pixel_values, target_scales, target_patch_size):
        """
        Preprocess the input to adapt to the target scales and patch size.

        Pads the video so it tiles evenly, then splits it into tiles that are
        folded into the batch dimension; `recover_output_from_res_adapt`
        undoes this on the model outputs.

        inputs:
            pixel_values: B, T, C, H, W
            target_scales: per-scale resolutions of the downstream vision model
            target_patch_size: patch size of the downstream vision model
        returns:
            pixel_values: (B * num_tiles) , T, C, tile_H, tile_W
            res_adapt_info: dict, the information of resolution adaptation, for future recovery.
        """
        B, T, C, H, W = pixel_values.shape
        assert H == W == target_scales[-1], "Now we need the input video to be the same size as the largest scale of the vision model"  # FIXME: in the future we should use relative resize ratio as the scales, e.g., 0.125+0.25+0.5+1. In this way we can also support naflex ViT.
        assert len(self.scales) == len(target_scales), "The scales of the gaze model and the vision model must be the same"
        # Side length (in tokens) of one tile at each scale.
        tile_feature_map_size_each_scale = [int(self.num_vision_tokens_each_scale_each_frame[i] ** 0.5) for i in range(len(self.scales))]
        # Feature-map size (in tokens) of the full frame at each target scale.
        original_feature_map_height_each_scale = [target_scales[i] // target_patch_size for i in range(len(target_scales))]
        original_feature_map_width_each_scale = [target_scales[i] // target_patch_size for i in range(len(target_scales))]
        # Tile count is determined by the largest scale.
        num_tiles_height = math.ceil(original_feature_map_height_each_scale[-1] / tile_feature_map_size_each_scale[-1])
        num_tiles_width = math.ceil(original_feature_map_width_each_scale[-1] / tile_feature_map_size_each_scale[-1])
        # Zero-pad bottom/right so the frame divides evenly into tiles.
        pad_H = num_tiles_height * tile_feature_map_size_each_scale[-1] * target_patch_size - H
        pad_W = num_tiles_width * tile_feature_map_size_each_scale[-1] * target_patch_size - W
        pixel_values = F.pad(pixel_values, (0, pad_W, 0, pad_H))
        # Fold tiles into the batch dimension so each tile is processed independently.
        pixel_values = rearrange(pixel_values, 'b t c (nh sh) (nw sw) -> (b nh nw) t c sh sw', nh=num_tiles_height, nw=num_tiles_width)
        res_adapt_info = {
            'tile_feature_map_size_each_scale': tile_feature_map_size_each_scale,
            'original_feature_map_height_each_scale': original_feature_map_height_each_scale,
            'original_feature_map_width_each_scale': original_feature_map_width_each_scale,
            'num_tiles_height': num_tiles_height,
            'num_tiles_width': num_tiles_width,
            'pad_H': pad_H,
            'pad_W': pad_W,
        }

        return pixel_values, res_adapt_info
210
+
211
    def recover_output_from_res_adapt(self, gaze_outputs, res_adapt_info):
        """
        Postprocess the output to recover from resolution adaptation.

        Unfolds the per-tile gazing masks back into full-frame masks, drops
        gazes that landed in the padded border, and re-derives the gazing
        positions / padding flags from the merged masks.

        inputs:
            gaze_outputs: dict, the outputs of the gazing model.
            res_adapt_info: dict, the information of resolution adaptation
                (as produced by `input_res_adapt`).
        returns:
            gaze_outputs: dict, the outputs of the gazing model, updated in place.
        """
        num_tiles_height = res_adapt_info['num_tiles_height']
        num_tiles_width = res_adapt_info['num_tiles_width']
        tile_feature_map_size_each_scale = res_adapt_info['tile_feature_map_size_each_scale']
        original_feature_map_height_each_scale = res_adapt_info['original_feature_map_height_each_scale']
        original_feature_map_width_each_scale = res_adapt_info['original_feature_map_width_each_scale']

        # Recover the gazing mask. Remove the gazing for the padded regions.
        new_gazing_mask = []
        for scale_idx in range(len(gaze_outputs['scales'])):
            cur_gazing_mask = gaze_outputs['gazing_mask'][scale_idx]
            # Unfold tiles out of the batch dim back into spatial position.
            cur_gazing_mask = rearrange(cur_gazing_mask, '(b nh nw) t (sh sw) -> b t (nh sh) (nw sw)', nh=num_tiles_height, nw=num_tiles_width, sh=tile_feature_map_size_each_scale[scale_idx], sw=tile_feature_map_size_each_scale[scale_idx])
            # Crop off the padded border added by input_res_adapt.
            cur_gazing_mask = cur_gazing_mask[:, :, :original_feature_map_height_each_scale[scale_idx], :original_feature_map_width_each_scale[scale_idx]]
            cur_gazing_mask = cur_gazing_mask.flatten(-2, -1)  # (b t (nh sh) (nw sw)) -> (b t (nh sh * nw sw))
            new_gazing_mask.append(cur_gazing_mask)

        # Recover the num_gazing_each_frame and num_vision_tokens_each_frame
        new_num_vision_tokens_each_frame = sum([mask.shape[-1] for mask in new_gazing_mask])

        # Recover the gazing pos, if_padded_gazing, and num_gazing_each_frame, by inferring from the gazing mask. Note this will lose the original order of the gazing!
        new_gazing_mask_all_scales = torch.cat(new_gazing_mask, dim=-1)  # B, T, N
        B, T = new_gazing_mask_all_scales.shape[:2]
        new_gazing_pos, new_if_padded_gazing = get_gazing_pos_from_gazing_mask(new_gazing_mask_all_scales.flatten(0, 1))
        new_gazing_pos, new_if_padded_gazing = rearrange(new_gazing_pos, '(b t) n -> b t n', b=B, t=T), rearrange(new_if_padded_gazing, '(b t) n -> b t n', b=B, t=T)
        # Per frame: the largest non-padded gaze count across the batch.
        max_num_gazing_each_frame = (~new_if_padded_gazing).sum(dim=-1).max(dim=0)[0]
        assert all([torch.all(new_if_padded_gazing[:, t, num:] == True) for t, num in enumerate(max_num_gazing_each_frame)]), "The removed gazing should all be padded."
        # Truncate each frame to its max gaze count, then flatten frames back
        # into one sequence with frame-offset positions.
        new_gazing_pos = [new_gazing_pos[:, t, :num] for t, num in enumerate(max_num_gazing_each_frame)]
        new_if_padded_gazing = [new_if_padded_gazing[:, t, :num] for t, num in enumerate(max_num_gazing_each_frame)]
        new_gazing_pos = [gazing_pos + new_num_vision_tokens_each_frame * t for t, gazing_pos in enumerate(new_gazing_pos)]
        new_gazing_pos, new_if_padded_gazing = torch.cat(new_gazing_pos, dim=1), torch.cat(new_if_padded_gazing, dim=1)
        new_num_gazing_each_frame = max_num_gazing_each_frame

        # Update the outputs
        gaze_outputs['gazing_pos'] = new_gazing_pos
        gaze_outputs['gazing_mask'] = new_gazing_mask
        # NOTE(review): self-assignment is a no-op; presumably left over from
        # an earlier recomputation of frame_sampling_rate.
        gaze_outputs['frame_sampling_rate'] = gaze_outputs['frame_sampling_rate']
        gaze_outputs['num_vision_tokens_each_frame'] = new_num_vision_tokens_each_frame
        gaze_outputs['num_gazing_each_frame'] = new_num_gazing_each_frame
        gaze_outputs['if_padded_gazing'] = new_if_padded_gazing

        # Currently we haven't reordered action probs and task loss prediction based on the new gazing pos, so delete it for now for safety.
        del(gaze_outputs['log_action_probs'])
        del(gaze_outputs['task_loss_prediction'])

        return gaze_outputs
265
+
266
+ #FIXME: separate forward and generate functions
267
+ def forward(
268
+ self,
269
+ inputs,
270
+ target_scales=None,
271
+ target_patch_size=None,
272
+ target_image_mean=None,
273
+ target_image_std=None,
274
+ gazing_info=None,
275
+ temperature=1,
276
+ gazing_ratio=None,
277
+ task_loss_requirement=None,
278
+ generate_only=False,
279
+ use_cache=False,
280
+ past_key_values=None,
281
+ past_inputs_embeds=None,
282
+ past_attention_mask=None,
283
+ past_conv_values=None,
284
+ ):
285
+ """
286
+ inputs:
287
+ video: B, T, C, H, W
288
+ target_scales: list of scales for downstream vision model. If None, then use the scales in the gaze model.
289
+ target_patch_size: patch size for downstream vision model. If None, then use the patch size in the gaze model.
290
+ target_image_mean: image mean for downstream vision model. If None, then use the image mean in the gaze model.
291
+ target_image_std: image std for downstream vision model. If None, then use the image std in the gaze model.
292
+ gazing_info: dict, the ground truth gazing information for NTP pre-training. If None, then run the gazing model to predict gazing positions.
293
+ temperature: temperature for generating gazing.
294
+ gazing_ratio: gazing ratio for the gazing model. If None, then sample the gazing ratio according to the config.
295
+ task_loss_requirement: task loss requirement for the gazing model. If None, then sample the task loss requirement according to the config.
296
+ generate_only: whether to only generate the gazing positions, or to also calculate the probability of taking such gaze.
297
+ use_cache: whether to use the cache for the gazing model.
298
+ past_key_values: the past key values for the gazing model.
299
+ past_inputs_embeds: the past inputs embeds for the gazing model.
300
+ past_attention_mask: the past attention mask for the gazing model.
301
+ past_conv_values: the past conv values for the gazing model.
302
+ returns:
303
+ to_return: dict, the outputs of the gazing model.
304
+ """
305
+ if not generate_only:
306
+ assert past_key_values is None and past_inputs_embeds is None and past_attention_mask is None and past_conv_values is None, \
307
+ "If not in generate-only mode, we don't support past_key_values, past_inputs_embeds, past_attention_mask, and past_conv_values yet."
308
+
309
+ video = inputs['video']
310
+
311
+ # Preprocess the input to fix the image mean and std
312
+ if target_image_mean is not None and target_image_std is not None:
313
+ video = rearrange(video, 'b t c h w -> b t h w c')
314
+ video = video * torch.tensor(target_image_std, device=video.device, dtype=video.dtype) + torch.tensor(target_image_mean, device=video.device, dtype=video.dtype)
315
+ video = video * 2 - 1 # Vivit preprocessor has a rescaling factor of 1/127.5 instead of 1/255, and it has an offset of -1.
316
+ video = (video - torch.tensor(self.image_mean, device=video.device, dtype=video.dtype)) / torch.tensor(self.image_std, device=video.device, dtype=video.dtype)
317
+ video = rearrange(video, 'b t h w c -> b t c h w')
318
+
319
+ # Preprocess the input for resolution adaptation
320
+ if target_scales is not None and target_patch_size is not None:
321
+ if not (target_scales == self.scales and [(scale // target_patch_size) ** 2 for scale in target_scales] == self.num_vision_tokens_each_scale_each_frame):
322
+ video, res_adapt_info = self.input_res_adapt(video, target_scales, target_patch_size)
323
+
324
+ B, T = video.shape[:2]
325
+
326
+ # If gazing_pos is already provided, then directly calculate the probability of taking such gaze. Usually in the cases of calculating pi(a|s) in PPO/GRPO/etc.
327
+ # Otherwise, run the gazing model first to predict gazing positions.
328
+ if gazing_info is None or len(gazing_info) == 0:
329
+ with torch.autocast("cuda", dtype=torch.bfloat16) if self.attn_mode == "flash_attention_2" else nullcontext():
330
+
331
+ if gazing_ratio is not None and task_loss_requirement is not None:
332
+ # If the user specifies the gazing ratio and task loss requirement, then use gazing ratio as the max gazing ratio and use task loss requirement to control when to stop
333
+ if isinstance(gazing_ratio, list):
334
+ assert len(gazing_ratio) == T // self.frame_sampling_rate, "The number of gazing ratios must be equal to the number of frames"
335
+ gazing_ratio = torch.tensor(gazing_ratio)
336
+ gazing_ratio_each_frame = torch.ones(T // self.frame_sampling_rate) * gazing_ratio
337
+ num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
338
+ task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
339
+ elif gazing_ratio is not None:
340
+ # If the user specifies the gazing ratio, then turn off the task loss requirement
341
+ if isinstance(gazing_ratio, list):
342
+ assert len(gazing_ratio) == T // self.frame_sampling_rate, "The number of gazing ratios must be equal to the number of frames"
343
+ gazing_ratio = torch.tensor(gazing_ratio)
344
+ gazing_ratio_each_frame = torch.ones(T // self.frame_sampling_rate) * gazing_ratio
345
+ num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
346
+ task_loss_requirement = None
347
+ elif task_loss_requirement is not None:
348
+ # If the user specifies the task loss requirement, then turn off the gazing ratio limit
349
+ gazing_ratio = 1
350
+ gazing_ratio_each_frame = torch.ones(T // self.frame_sampling_rate) * gazing_ratio
351
+ num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
352
+ task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
353
+ else:
354
+ gazing_ratio = self.get_gazing_ratio()
355
+ gazing_ratio_each_frame = self.get_gazing_ratio_each_frame(inputs, video, gazing_ratio, T // self.frame_sampling_rate, temperature, use_cache)
356
+ num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
357
+ task_loss_requirement = self.get_task_loss_requirement(video)
358
+
359
+ if self.training:
360
+ gazing_info = self.gazing_model.generate(
361
+ video,
362
+ max_gaze_tokens_each_frame=num_gaze_tokens_each_frame,
363
+ task_loss_requirement=task_loss_requirement,
364
+ do_sample=True,
365
+ temperature=temperature,
366
+ use_cache=use_cache,
367
+ past_key_values=past_key_values,
368
+ past_inputs_embeds=past_inputs_embeds,
369
+ past_attention_mask=past_attention_mask,
370
+ past_conv_values=past_conv_values,
371
+ )
372
+ else:
373
+ gazing_info = self.gazing_model.generate(
374
+ video,
375
+ max_gaze_tokens_each_frame=num_gaze_tokens_each_frame,
376
+ task_loss_requirement=task_loss_requirement,
377
+ do_sample=False,
378
+ use_cache=use_cache,
379
+ past_key_values=past_key_values,
380
+ past_inputs_embeds=past_inputs_embeds,
381
+ past_attention_mask=past_attention_mask,
382
+ past_conv_values=past_conv_values,
383
+ )
384
+
385
+ # Unpack gazing_info
386
+ gazing_pos = gazing_info["gazing_pos"]
387
+ num_gazing_each_frame = gazing_info["num_gazing_each_frame"]
388
+ if_padded_gazing = gazing_info["if_padded_gazing"]
389
+ task_loss_requirement = gazing_info.get("task_loss_requirement", None)
390
+ new_past_key_values = gazing_info.get("past_key_values", None)
391
+ new_past_inputs_embeds = gazing_info.get("past_inputs_embeds", None)
392
+ new_past_attention_mask = gazing_info.get("past_attention_mask", None)
393
+ new_past_conv_values = gazing_info.get("past_conv_values", None)
394
+
395
+ # Get the log probablity of taking such gaze (log_action_probs)
396
+ if not generate_only:
397
+ with torch.autocast("cuda", dtype=torch.bfloat16):
398
+ forward_outputs = self.gazing_model(video, gazing_info) # B * N
399
+ action_probs = forward_outputs.gaze_probs
400
+ task_loss_prediction = forward_outputs.task_loss_prediction
401
+ log_action_probs = torch.log(action_probs + 1e-8) # B * N
402
+ else:
403
+ log_action_probs = None
404
+ task_loss_prediction = None
405
+
406
+ # Generate (multi-scale) gazing masks for ease of visualization
407
+ mask = self.get_mask_from_gazing_pos(video, gazing_pos, if_padded_gazing)
408
+
409
+ to_return = {
410
+ 'gazing_pos': gazing_pos,
411
+ 'log_action_probs': log_action_probs,
412
+ 'gazing_mask': mask,
413
+ "scales": self.scales,
414
+ "frame_sampling_rate": self.frame_sampling_rate,
415
+ "num_vision_tokens_each_frame": self.num_vision_tokens_each_frame,
416
+ "num_gazing_each_frame": num_gazing_each_frame,
417
+ "if_padded_gazing": if_padded_gazing,
418
+ "task_loss_prediction": task_loss_prediction,
419
+ "has_task_loss_requirement": task_loss_requirement is not None,
420
+ "task_loss_requirement": task_loss_requirement,
421
+ "past_key_values": new_past_key_values if use_cache else None,
422
+ "past_inputs_embeds": new_past_inputs_embeds if use_cache else None,
423
+ "past_attention_mask": new_past_attention_mask if use_cache else None,
424
+ "past_conv_values": new_past_conv_values if use_cache else None,
425
+ }
426
+
427
+ # Postprocess the output to recover from resolution adaptation
428
+ if target_scales is not None and target_patch_size is not None:
429
+ if not (target_scales == self.scales and [(scale // target_patch_size) ** 2 for scale in target_scales] == self.num_vision_tokens_each_scale_each_frame):
430
+ to_return.update(self.recover_output_from_res_adapt(to_return, res_adapt_info))
431
+
432
+ return to_return
autogaze/models/autogaze/configuration_autogaze.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """AutoGaze model configuration"""
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+ from omegaconf import OmegaConf
7
+ from transformers.configuration_utils import PretrainedConfig
8
+ from transformers.modeling_rope_utils import rope_config_validation
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+
14
class GazeDecoderConfig(PretrainedConfig):
    r"""
    Configuration for the gaze decoder. Based on `LlamaConfig` from transformers.

    Accepts the standard LLaMA arguments, plus two AutoGaze-specific ones:

    Args:
        attn_mode (`str`, *optional*, defaults to `"sdpa"`):
            Attention implementation to use; stored as `_attn_implementation`
            (e.g. `"sdpa"` or `"flash_attention_2"`).
        num_multi_token_pred (`int`, *optional*, defaults to `1`):
            Number of tokens predicted in parallel at each decoding step.
    """
    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `LlamaModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        attn_mode="sdpa",
        num_multi_token_pred=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: MHA unless grouped-query attention is requested
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        self.num_multi_token_pred = num_multi_token_pred
        self._attn_implementation = attn_mode
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
102
+
103
+
104
class VisionModelConfig(PretrainedConfig):
    r"""
    Configuration for the vision model component of AutoGaze.

    Args:
        hidden_dim (`int`, *optional*, defaults to `192`):
            Hidden dimension of the vision model.
        out_dim (`int`, *optional*, defaults to `192`):
            Output dimension of the vision model.
        depth (`int`, *optional*, defaults to `1`):
            Depth of the vision model.
        kernel_size (`int`, *optional*, defaults to `16`):
            Kernel size for spatial convolution.
        temporal_patch_size (`int`, *optional*, defaults to `1`):
            Temporal patch size for video processing.
        trunk_temporal_kernel_size (`int`, *optional*, defaults to `3`):
            Temporal kernel size for trunk blocks.
        trunk_spatial_kernel_size (`int`, *optional*, defaults to `3`):
            Spatial kernel size for trunk blocks.
    """

    def __init__(
        self,
        hidden_dim=192,
        out_dim=192,
        depth=1,
        kernel_size=16,
        temporal_patch_size=1,
        trunk_temporal_kernel_size=3,
        trunk_spatial_kernel_size=3,
        **kwargs,
    ):
        # Pure value-holder: every architecture knob is recorded verbatim on
        # the instance before delegating remaining kwargs to PretrainedConfig.
        for attr_name, attr_value in (
            ("hidden_dim", hidden_dim),
            ("out_dim", out_dim),
            ("depth", depth),
            ("kernel_size", kernel_size),
            ("temporal_patch_size", temporal_patch_size),
            ("trunk_temporal_kernel_size", trunk_temporal_kernel_size),
            ("trunk_spatial_kernel_size", trunk_spatial_kernel_size),
        ):
            setattr(self, attr_name, attr_value)

        super().__init__(**kwargs)
144
+
145
+
146
class ConnectorConfig(PretrainedConfig):
    r"""
    Configuration for the connector component between vision encoder and gaze model.

    Args:
        hidden_dim (`int`, *optional*, defaults to `192`):
            Hidden dimension of the connector.
        num_tokens (`int`, *optional*, defaults to `196`):
            Number of tokens the connector handles per frame. `GazeModelConfig`
            overrides this to match the spatial patch grid of the vision model.
    """

    def __init__(
        self,
        hidden_dim=192,
        num_tokens=196,
        **kwargs,
    ):
        self.hidden_dim = hidden_dim
        self.num_tokens = num_tokens

        super().__init__(**kwargs)
165
+
166
+
167
class GazeModelConfig(PretrainedConfig):
    r"""
    Configuration for the gaze model, containing vision model, connector, and decoder configs.

    Args:
        input_img_size (`int`, *optional*, defaults to `224`):
            Input image size.
        vision_model_config (`dict`, *optional*):
            Keyword arguments for `VisionModelConfig`.
        connector_config (`dict`, *optional*):
            Keyword arguments for `ConnectorConfig`. Its `num_tokens` is always
            overridden to match the spatial patch grid.
        gaze_decoder_config (`dict`, *optional*):
            Keyword arguments for `GazeDecoderConfig` (LLaMA-based). Its
            `vocab_size`, `eos_token_id`, and `attn_mode` are always overridden.
        num_vision_tokens_each_frame (`int`, *optional*, defaults to `196`):
            Number of vision tokens per frame; also fixes the decoder vocabulary
            (one id per position plus one EOS id).
        attn_mode (`str`, *optional*, defaults to `"sdpa"`):
            Attention implementation forwarded to the gaze decoder.
    """

    def __init__(
        self,
        input_img_size=224,
        vision_model_config=None,
        connector_config=None,
        gaze_decoder_config=None,
        num_vision_tokens_each_frame=196,
        attn_mode="sdpa",
        **kwargs,
    ):
        # Copy the incoming dicts before the in-place `update`s below.
        # The previous version used mutable `{}` defaults and mutated them,
        # which leaked the overrides across instantiations and back into the
        # caller's dict. `None` defaults keep the call signature compatible.
        vision_model_config = dict(vision_model_config) if vision_model_config else {}
        connector_config = dict(connector_config) if connector_config else {}
        gaze_decoder_config = dict(gaze_decoder_config) if gaze_decoder_config else {}

        self.input_img_size = input_img_size
        self.vision_model_config = VisionModelConfig(**vision_model_config)

        # The connector must emit exactly one token per spatial patch.
        connector_config.update({
            "num_tokens": (input_img_size // self.vision_model_config.kernel_size) ** 2,
        })
        self.connector_config = ConnectorConfig(**connector_config)

        # Decoder vocabulary = one id per vision-token position + 1 EOS id.
        gaze_decoder_config.update({
            "vocab_size": num_vision_tokens_each_frame + 1,
            "eos_token_id": num_vision_tokens_each_frame,
            "attn_mode": attn_mode,
        })
        self.gaze_decoder_config = GazeDecoderConfig(**gaze_decoder_config)

        self.num_vision_tokens_each_frame = num_vision_tokens_each_frame

        super().__init__(**kwargs)
212
+
213
+
214
class AutoGazeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`AutoGaze`] model. It is used to instantiate an
    AutoGaze model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        gazing_ratio_config (`dict`, *optional*):
            Configuration for sampling gazing ratio during training and inference.
        scales (`str` or `int`, *optional*, defaults to `"224"`):
            Scales for the vision model. Can be a single scale or multiple scales separated by '+'.
        num_vision_tokens_each_frame (`int`, *optional*, defaults to `196`):
            Number of vision tokens per frame.
        gaze_model_config (`dict`, *optional*):
            Keyword arguments for `GazeModelConfig`, including vision_model_config,
            connector_config, and gaze_decoder_config.
        gazing_ratio_each_frame_config (`dict`, *optional*):
            Configuration for sampling gazing ratio for each frame.
        has_task_loss_requirement_during_training (`bool`, *optional*, defaults to `False`):
            Whether to use task loss requirement during training.
        has_task_loss_requirement_during_inference (`bool`, *optional*, defaults to `False`):
            Whether to use task loss requirement during inference.
        task_loss_requirement_config (`dict`, *optional*):
            Configuration for task loss requirement sampling.
        image_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            Image mean for normalization.
        image_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            Image std for normalization.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use flash attention.
        max_batch_size (`int`, *optional*):
            Maximum batch size.

    ```python
    >>> from autogaze.models.autogaze import AutoGaze, AutoGazeConfig

    >>> # Initializing an AutoGaze configuration
    >>> configuration = AutoGazeConfig()

    >>> # Initializing a model from the configuration
    >>> model = AutoGaze(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "autogaze"

    def __init__(
        self,
        gazing_ratio_config=None,
        scales="224",
        num_vision_tokens_each_frame=196,
        gaze_model_config=None,
        gazing_ratio_each_frame_config=None,
        has_task_loss_requirement_during_training=False,
        has_task_loss_requirement_during_inference=False,
        task_loss_requirement_config=None,
        image_mean=None,
        image_std=None,
        use_flash_attn=True,
        max_batch_size=None,
        **kwargs,
    ):
        self.gazing_ratio_config = gazing_ratio_config or {
            "sample_strategy_during_training": "fixed",
            "sample_strategy_during_inference": "fixed",
            "fixed": {"gazing_ratio": 0.5},
            "uniform": {"gazing_ratio_min": 0, "gazing_ratio_max": 1},
            "exponential": {"gazing_ratio_min": 0, "gazing_ratio_max": 1, "lambda": 10},
        }
        self.scales = scales
        self.num_vision_tokens_each_frame = num_vision_tokens_each_frame
        self.attn_mode = "flash_attention_2" if use_flash_attn else "sdpa"

        # Copy before the in-place update: the previous version used a mutable
        # `{}` default and mutated it, leaking the overrides across
        # instantiations and back into the caller's dict.
        gaze_model_config = dict(gaze_model_config) if gaze_model_config else {}
        gaze_model_config.update({
            "num_vision_tokens_each_frame": num_vision_tokens_each_frame,
            "attn_mode": self.attn_mode,
        })
        self.gaze_model_config = GazeModelConfig(**gaze_model_config)

        self.gazing_ratio_each_frame_config = gazing_ratio_each_frame_config or {
            "sample_strategy_during_training": "uniform",
            "sample_strategy_during_inference": "uniform",
            "uniform": {},
            "dirichlet": {"alpha": 0.5},
            "self": {},
        }
        self.has_task_loss_requirement_during_training = has_task_loss_requirement_during_training
        self.has_task_loss_requirement_during_inference = has_task_loss_requirement_during_inference
        self.task_loss_requirement_config = task_loss_requirement_config or {
            "sample_strategy_during_training": "fixed",
            "sample_strategy_during_inference": "fixed",
            "fixed": {"task_loss_requirement": 0.7},
            "uniform": {"task_loss_requirement_min": 0.6, "task_loss_requirement_max": 0.9},
        }
        # ImageNet normalization statistics by default; copy so instances never
        # share one mutable list (the previous defaults were shared lists).
        self.image_mean = list(image_mean) if image_mean is not None else [0.485, 0.456, 0.406]
        self.image_std = list(image_std) if image_std is not None else [0.229, 0.224, 0.225]
        self.use_flash_attn = use_flash_attn
        self.max_batch_size = max_batch_size

        super().__init__(**kwargs)
317
+
318
+
319
# Public API of this configuration module.
__all__ = [
    "AutoGazeConfig",
    "GazeModelConfig",
    "VisionModelConfig",
    "ConnectorConfig",
    "GazeDecoderConfig",
]
326
+
autogaze/models/autogaze/modeling_autogaze.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import Optional, Tuple
3
+ from dataclasses import dataclass
4
+ from einops import rearrange
5
+
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from timm.models.convnext import ConvNeXtBlock
11
+ from timm.layers import LayerNorm2d
12
+
13
+ from transformers.modeling_outputs import ModelOutput
14
+ from transformers import LogitsProcessor, LogitsProcessorList
15
+
16
+ from .configuration_autogaze import GazeModelConfig, VisionModelConfig, ConnectorConfig
17
+ from .modeling_llama_multi_token_pred import LlamaForCausalLM_MultiTokenPred
18
+
19
+
20
@dataclass
class AutoGazeOutput(ModelOutput):
    """Output container for the gaze model (a `ModelOutput` subclass, so fields
    support both attribute and dict-style access)."""

    # Raw logits over gaze positions.
    gaze_logits: Optional[torch.FloatTensor] = None
    # Probability of each generated gaze token (read by the caller as pi(a|s)
    # when computing log action probabilities).
    gaze_probs: Optional[torch.FloatTensor] = None
    # Optional training loss.
    loss: Optional[torch.FloatTensor] = None
    # Generic logits slot inherited by ModelOutput conventions.
    logits: torch.FloatTensor = None
    # KV cache from the decoder, for streaming/incremental generation.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Per-layer hidden states, when requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention maps, when requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Predicted task loss, used against the task-loss-requirement stopping rule.
    task_loss_prediction: Optional[torch.FloatTensor] = None
30
+
31
+
32
class NoRepeatTokensLogitsProcessor(LogitsProcessor):
    """Forbid re-selecting any token already present in `input_ids`.

    Each gaze position may be attended at most once per frame, so every
    previously generated id is masked to -inf in the score tensor (in place).
    """

    def __init__(self):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # input_ids: (batch_size, sequence_length)
        # scores: (batch_size, vocab_size) or (batch_size, num_multi_token_pred, vocab_size)
        batch_rows = torch.arange(scores.shape[0]).unsqueeze(-1)  # (B, 1), broadcasts against input_ids
        neg_inf = -float("inf")
        if scores.ndim == 3:
            # Mask the repeated ids across every multi-token prediction head.
            scores[batch_rows, :, input_ids] = neg_inf
        else:
            scores[batch_rows, input_ids] = neg_inf
        return scores
44
+
45
+
46
class NoEosTokenLogitsProcessor(LogitsProcessor):
    """Disallow the EOS token (last vocabulary index) during gazing.

    Mutates `scores` in place so the final vocab entry can never be sampled.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # input_ids: (batch_size, sequence_length)
        # scores: (batch_size, vocab_size) or (batch_size, num_multi_token_pred, vocab_size)
        # EOS sits at the last vocab index in either layout, so `...` covers both.
        scores[..., -1] = -float("inf")
        return scores
55
+
56
+
57
+ class AutoGazeModel(nn.Module):
58
    def __init__(self, gaze_model_config: GazeModelConfig):
        """Build the gaze model: a shallow video encoder, a connector, and a
        LLaMA-based autoregressive decoder over gaze-position ids.

        Args:
            gaze_model_config: `GazeModelConfig` holding the sub-configs for the
                vision model, connector, and gaze decoder.
        """
        super().__init__()

        # Cache the scalar hyperparameters this module reads repeatedly.
        self.num_vision_tokens_each_frame = gaze_model_config.num_vision_tokens_each_frame
        self.input_img_size = gaze_model_config.input_img_size
        # One decoder step covers `temporal_patch_size` raw video frames.
        self.frame_sampling_rate = gaze_model_config.vision_model_config.temporal_patch_size
        self.num_multi_token_pred = gaze_model_config.gaze_decoder_config.num_multi_token_pred
        self.gaze_decoder_config = gaze_model_config.gaze_decoder_config  # Store for reference

        # Create the vision model, connector, and gaze decoder
        self.vision_model = ShallowVideoConvNet(gaze_model_config.vision_model_config)
        self.connector = Connector(gaze_model_config.connector_config)
        self.gaze_decoder = LlamaForCausalLM_MultiTokenPred(gaze_model_config.gaze_decoder_config)

        # Add logits processors to prevent the model from repeating the same token and generating eos token during gazing.
        self.logits_processor = LogitsProcessorList()
        self.logits_processor.append(NoRepeatTokensLogitsProcessor())  # don't allow repeated gazing
        self.logits_processor.append(NoEosTokenLogitsProcessor())  # don't allow generating eos token during gazing
76
+
77
    def embed(self, video=None, gaze_pos_ids=None, use_cache=False, past_conv_values=None):
        """
        Embed video frames and/or gaze position ids into one interleaved sequence.

        inputs:
            video: (B x T x C x H x W).
            gaze_pos_ids: list of (B, N), N is the number of gazing positions in each frame. The length of the list is T // frame_sampling_rate.
        returns:
            embeds: a list of interleaved vision and gaze embeddings. list of (B, N, C)
            gaze_token_mask: a list of masks that indicate if the current embedding is a gaze embedding. (1 is gaze embedding, 0 is vision embedding). list of (N, )
            gaze_pred_source_relative: a list of (relative) source index of where the gaze prediction is coming from. For example, if the gaze prediction is coming from two tokens before it, the source index is -2. list of (N, ).
                For vision embeddings, there's no source prediction, so the source index is -1.
            attention_mask: a list of (B, N) that indicates if the current embedding should be masked out (for EOS token). 1 is not masked, 0 is masked.

        NOTE(review): despite the `video=None` default and the `video is None`
        guards below, the next line unpacks `video.shape` unconditionally, so
        calling with `video=None` would raise AttributeError — confirm whether
        that path is ever exercised.
        """
        B, T = video.shape[:2]
        assert (video is None or gaze_pos_ids is None) or video.shape[1] // self.frame_sampling_rate == len(gaze_pos_ids), \
            "The number of frames in the video (after subsampling) and in gaze position IDs must be the same, but got {} and {}".format(video.shape[1] // self.frame_sampling_rate, len(gaze_pos_ids))

        if video is not None:
            # Encode frames, then flatten the spatial grid into a token axis.
            vision_features, new_past_conv_values = self.vision_model(video, use_cache=use_cache, past_conv_values=past_conv_values)
            vision_features = vision_features.transpose(1, 2)
            vision_features = rearrange(vision_features, 'b t c h w -> b t (h w) c')
            vision_features = self.connector(vision_features)
            # Vision tokens are never padded, so their attention mask is all ones.
            vision_attention_mask = [torch.ones(B, vision_features.shape[2], device=vision_features.device).long() for _ in range(vision_features.shape[1])]

        if gaze_pos_ids is not None:
            # Concatenate the per-frame id lists so they can be embedded in one
            # pass, then split back into per-frame chunks.
            num_gazing_each_frame = [gaze_pos_ids[t].shape[1] for t in range(len(gaze_pos_ids))]
            gaze_pos_ids = torch.cat(gaze_pos_ids, dim=1)
            # EOS-valued positions are padding: mask them out of attention.
            gaze_attention_mask = (gaze_pos_ids != self.gaze_decoder_config.eos_token_id).to(torch.long)
            gaze_embeds = self.gaze_decoder.model.embed_tokens(gaze_pos_ids)
            gaze_embeds = list(gaze_embeds.split(num_gazing_each_frame, dim=1))
            gaze_attention_mask = list(gaze_attention_mask.split(num_gazing_each_frame, dim=1))

        # Interleave per (subsampled) frame: vision tokens first, then that
        # frame's gaze tokens.
        embeds = []
        gaze_token_mask = []
        gaze_pred_source_relative = []
        attention_mask = []
        for t in range(T // self.frame_sampling_rate):
            if video is not None:
                embeds.append(vision_features[:, t, :, :])
                gaze_token_mask.append(torch.zeros(vision_features.shape[2], device=vision_features.device).long())
                # Vision tokens carry no prediction source: sentinel -1.
                gaze_pred_source_relative.append(torch.zeros(vision_features.shape[2], device=vision_features.device).long() - 1)
                attention_mask.append(vision_attention_mask[t])
            if gaze_pos_ids is not None:
                embeds.append(gaze_embeds[t])
                gaze_token_mask.append(torch.ones(gaze_embeds[t].shape[1], device=gaze_embeds[t].device).long())
                # With multi-token prediction, token k in a group of size
                # `num_multi_token_pred` was predicted (k % group + 1) steps back.
                gaze_pred_source_relative.append(-(torch.arange(gaze_embeds[t].shape[1], device=gaze_embeds[t].device) % self.num_multi_token_pred + 1))
                attention_mask.append(gaze_attention_mask[t])
        return embeds, gaze_token_mask, gaze_pred_source_relative, attention_mask, new_past_conv_values if video is not None else None
124
+
125
    @torch.no_grad()
    def generate(
        self,
        video,
        max_gaze_tokens_each_frame=100,
        task_loss_requirement=None,
        use_cache=False,
        past_key_values=None,
        past_inputs_embeds=None,
        past_attention_mask=None,
        past_conv_values=None,
        **generation_kwargs,
    ):
        """
        Autoregressively generate gaze-position ids, frame by frame.

        Inputs:
            video: (B, T, C, H, W)
            max_gaze_tokens_each_frame: int or (T, ). Indicating the max gazing length for each frame. If is int, then all frames have the same max gazing length.
            task_loss_requirement (optional): (B, T). Indicating the task loss requirement for each frame.
            past_key_values (optional): The past key values for the gaze model. Can be used for streaming generation.
            past_inputs_embeds (optional): The past inputs embeds for the gaze model. Can be used for streaming generation.
            past_attention_mask (optional): The past attention mask for the gaze model. Can be used for streaming generation.

        Returns a dict with the concatenated gaze ids, per-frame counts, the
        padding mask, and (when use_cache) the streaming state for the next call.
        """
        # Streaming state is all-or-nothing: partial state would desynchronize
        # the embeds/mask/KV bookkeeping below.
        if past_key_values is not None or past_inputs_embeds is not None or past_attention_mask is not None or past_conv_values is not None:
            assert past_key_values is not None and past_inputs_embeds is not None and past_attention_mask is not None and past_conv_values is not None, \
                "If past_key_values, past_inputs_embeds, past_attention_mask, or past_conv_values is provided, then all four must be provided!"

        # Subsample frames and resize
        B, T = video.shape[:2]
        video = rearrange(video, 'b t c h w -> (b t) c h w')
        video = F.interpolate(video, size=(self.input_img_size, self.input_img_size), mode="bicubic", align_corners=False)
        video = rearrange(video, '(b t) c h w -> b t c h w', b=B)

        # Embed all the frames (only the embeds and conv cache are needed here).
        video_embeds, _, __, ___, past_conv_values = self.embed(video=video, use_cache=use_cache, past_conv_values=past_conv_values)

        # Generate gaze position IDs for each frame, resuming from any past
        # streaming state.
        gaze_pos_ids_list = []
        inputs_embeds = [] if past_inputs_embeds is None else past_inputs_embeds
        attention_mask = [] if past_attention_mask is None else past_attention_mask
        past_key_values = None if past_key_values is None else past_key_values  # no-op, kept for symmetry with the two lines above
        num_gazing_each_frame = []
        if_padded_gazing = []
        for t in range(len(video_embeds)):

            # Update inputs_embeds and attention mask for the new frame
            inputs_embeds.append(video_embeds[t])
            attention_mask.append(torch.ones(video_embeds[t].shape[0], video_embeds[t].shape[1], device=video_embeds[t].device).long())

            # Put task loss requirement into generation config
            generation_config = self.gaze_decoder.generation_config
            generation_config.task_loss_requirement = task_loss_requirement[:, t] if task_loss_requirement is not None else None

            # Get the max gazing length for the current frame
            assert isinstance(max_gaze_tokens_each_frame, int) or len(max_gaze_tokens_each_frame) == len(video_embeds), \
                "max_gaze_tokens_each_frame must be an int or a tensor of the same length as the video embeddings, but got {} and {}".format(max_gaze_tokens_each_frame, len(video_embeds))
            max_gaze_tokens = max_gaze_tokens_each_frame if isinstance(max_gaze_tokens_each_frame, int) else max_gaze_tokens_each_frame[t]

            # Generate gaze position IDs for the current frame. Gradient
            # checkpointing is incompatible with HF generate, so toggle it off
            # around the call and restore it after.
            is_gradient_checkpointing = self.gaze_decoder.is_gradient_checkpointing
            if is_gradient_checkpointing:
                self.gaze_decoder.gradient_checkpointing_disable()
            gaze_outputs = self.gaze_decoder.generate(
                inputs_embeds=torch.cat(inputs_embeds, dim=1),  # We need to pass the whole sequence of inputs_embeds (both current and past) to the model even when we use use_cache=True!!!
                attention_mask=torch.cat(attention_mask, dim=1),
                position_ids=torch.cat(attention_mask, dim=1).cumsum(dim=-1) - 1,  # positions skip masked (padded) tokens
                max_new_tokens=max_gaze_tokens,
                logits_processor=self.logits_processor,
                pad_token_id=self.gaze_decoder_config.eos_token_id,
                eos_token_id=self.gaze_decoder_config.eos_token_id,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict_in_generate=True,
                generation_config=generation_config,
                **generation_kwargs,
            )
            if is_gradient_checkpointing:
                self.gaze_decoder.gradient_checkpointing_enable()

            # Get the predicted gaze ids
            gaze_pos_ids = gaze_outputs.sequences  # B * N
            # Offset ids so each frame occupies its own id range in the output.
            gaze_pos_ids_list.append(gaze_pos_ids + self.num_vision_tokens_each_frame * t)

            # Update inputs_embeds for the next frame
            inputs_embeds.append(self.gaze_decoder.model.embed_tokens(gaze_pos_ids))

            # Update past_key_values for the next frame
            past_key_values = gaze_outputs.past_key_values

            # Update auxiliary information
            num_gazing_each_frame.append(gaze_pos_ids.shape[1])
            if_padded_gazing.append(gaze_pos_ids == self.gaze_decoder_config.eos_token_id)

            # Update attention mask (padded EOS positions are masked out)
            attention_mask.append((gaze_pos_ids != self.gaze_decoder_config.eos_token_id).to(torch.long))

        # Concatenate gaze position IDs from all frames
        gaze_pos_ids = torch.cat(gaze_pos_ids_list, dim=1)

        # Get auxiliary information
        num_gazing_each_frame = torch.tensor(num_gazing_each_frame, device=gaze_pos_ids.device).to(torch.long)
        if_padded_gazing = torch.cat(if_padded_gazing, dim=1)

        to_return = {
            "gazing_pos": gaze_pos_ids,  # In gaze_pos_ids, the padded gazing positions are not necessarily eos_token_id, so one needs to use if_padded_gazing to determine if the gazing position is padded!!!
            "num_gazing_each_frame": num_gazing_each_frame,
            "if_padded_gazing": if_padded_gazing,
            "task_loss_requirement": task_loss_requirement,
            "past_inputs_embeds": inputs_embeds if use_cache else None,
            "past_attention_mask": attention_mask if use_cache else None,
            "past_key_values": past_key_values if use_cache else None,
            "past_conv_values": past_conv_values if use_cache else None,
        }
        return to_return
238
+
239
    def forward(self, video, gazing_info, **kwargs):
        """Teacher-forced forward pass: embed the video together with the
        (already generated) gaze tokens, run the gaze decoder once over the
        whole sequence, and gather the probability the model assigned to each
        ground-truth gaze position.

        Args:
            video: video tensor of shape (B, T, C, H, W) — established by the
                `rearrange('b t c h w -> (b t) c h w')` call below.
            gazing_info: dict produced by the generation step, with keys
                "gazing_pos" (B x total gaze tokens, frame-offset position ids),
                "num_gazing_each_frame" (per-frame token counts used to split),
                and "if_padded_gazing" (bool mask of padded positions).
            **kwargs: forwarded verbatim to `self.gaze_decoder`.

        Returns:
            AutoGazeOutput with `gaze_probs` (B x N, probability of each
            ground-truth gaze token) and `task_loss_prediction` (B x N),
            plus the raw decoder outputs.
        """
        # Unpack gazing_info
        gaze_pos_ids = gazing_info["gazing_pos"]
        num_gazing_each_frame = gazing_info["num_gazing_each_frame"]
        if_padded_gazing = gazing_info["if_padded_gazing"]

        # Subsample frames and resize
        B, T = video.shape[:2]
        video = rearrange(video, 'b t c h w -> (b t) c h w')
        video = F.interpolate(video, size=(self.input_img_size, self.input_img_size), mode="bicubic", align_corners=False)
        video = rearrange(video, '(b t) c h w -> b t c h w', b=B)

        # Split the gaze frame-wise.  Position ids were globally offset by
        # `num_vision_tokens_each_frame * t` at generation time; undo that here
        # so each frame's ids are local again.
        gaze_pos_ids_split = list(gaze_pos_ids.split(num_gazing_each_frame.tolist(), dim=1))
        gaze_pos_ids_split = [gaze_pos_ids_split[t] - self.num_vision_tokens_each_frame * t for t in range(len(gaze_pos_ids_split))]
        if_padded_gazing_split = list(if_padded_gazing.split(num_gazing_each_frame.tolist(), dim=1))

        # Fill the padded gazing positions with eos_token_id
        # (padded slots in "gazing_pos" are NOT guaranteed to already hold eos).
        gaze_pos_ids_split = [gaze_pos * (~padded) + self.gaze_decoder_config.eos_token_id * padded for gaze_pos, padded in zip(gaze_pos_ids_split, if_padded_gazing_split)]

        # Embed the video and gaze position IDs
        inputs_embeds, gaze_token_mask, gaze_pred_source_relative, attention_mask, _ = self.embed(video=video, gaze_pos_ids=gaze_pos_ids_split)
        inputs_embeds = torch.cat(inputs_embeds, dim=1)  # B * N * C
        gaze_token_mask = torch.cat(gaze_token_mask, dim=0)  # N
        gaze_pred_source_relative = torch.cat(gaze_pred_source_relative, dim=0)  # N
        attention_mask = torch.cat(attention_mask, dim=1)  # B * N

        # Run model forward.  Position ids are derived from the attention mask
        # so that masked (padded) tokens do not advance the position counter.
        outputs = self.gaze_decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=attention_mask.cumsum(dim=-1) - 1,
            **kwargs,
        )

        # Get gaze logits and probs.  The decoder head emits K = num_multi_token_pred
        # vocab distributions per position, packed along the channel dim.
        logits_multi_token_pred = outputs.logits
        task_loss_prediction_multi_token_pred = outputs.task_loss_prediction  # B * N * num_multi_token_pred
        logits_multi_token_pred = rearrange(logits_multi_token_pred, 'b n (k c) -> b n k c', k=self.num_multi_token_pred)
        gaze_probs_all_multi_token_pred = F.softmax(logits_multi_token_pred, dim=-1)

        # Align each k-step-ahead prediction with the position it predicts:
        # head i at position p predicts position p + i + 1, so shift right by
        # i + 1 (probs) / i (task-loss head), zero-filling the front.
        shifted_probs = []
        shifted_task_loss_prediction = []
        for i in range(self.num_multi_token_pred):
            shifted_probs.append(F.pad(gaze_probs_all_multi_token_pred[:, :-(i + 1), i, :], (0, 0, i + 1, 0), value=0))
            shifted_task_loss_prediction.append(F.pad(task_loss_prediction_multi_token_pred[:, :task_loss_prediction_multi_token_pred.shape[1] - i, i], (i, 0), value=0))
        shifted_probs = torch.stack(shifted_probs, dim=2)  # B, N, K, C
        shifted_task_loss_prediction = torch.stack(shifted_task_loss_prediction, dim=2)  # B, N, K

        # For every position, pick the head whose source token actually produced
        # it (gaze_pred_source_relative = relative offset back to the source).
        gaze_probs_all = shifted_probs[:, torch.arange(logits_multi_token_pred.shape[1]), -gaze_pred_source_relative - 1]
        task_loss_prediction = shifted_task_loss_prediction[:, torch.arange(logits_multi_token_pred.shape[1]), (-gaze_pred_source_relative) % self.num_multi_token_pred]  # B, N

        # Keep only positions that are gaze tokens (drop vision tokens).
        gaze_input_token_pos = torch.nonzero(gaze_token_mask, as_tuple=True)[0]
        gaze_probs_all = gaze_probs_all[:, gaze_input_token_pos, :]
        task_loss_prediction = task_loss_prediction[:, gaze_input_token_pos]
        # NOTE(review): B is rebound here (same batch size, new N = #gaze tokens).
        B, N = gaze_probs_all.shape[:2]
        # Gather, per gaze token, the probability assigned to its ground-truth id.
        gaze_probs = gaze_probs_all.reshape(B * N, -1)[torch.arange(B * N), torch.cat(gaze_pos_ids_split, dim=1).flatten()].reshape(B, N)  # [B, T]


        outputs = AutoGazeOutput(
            gaze_probs=gaze_probs,
            loss=outputs.loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            task_loss_prediction=task_loss_prediction,
        )
        return outputs
308
+
309
+
310
+ ################ Shallow Vision Encoder #################
311
+
312
class Conv3dBlockForStreaming(nn.Module):
    """Conv3d + ReLU block with manual temporal padding so it can run either on
    a full clip or frame-by-frame (streaming) with a small temporal cache.

    The Conv3d itself applies no temporal padding; `forward` either zero-pads
    the past (fresh start) or prepends the cached last
    ``temporal_patch_size - 1`` frames from the previous call, so streaming
    outputs match full-clip outputs exactly.
    """

    def __init__(self, hidden_dim, temporal_patch_size, spatial_kernel_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.temporal_patch_size = temporal_patch_size
        self.spatial_kernel_size = spatial_kernel_size

        self.conv3d = nn.Conv3d(
            hidden_dim, hidden_dim,
            kernel_size=(temporal_patch_size, spatial_kernel_size, spatial_kernel_size),
            # "Same" spatial padding (odd kernels); the temporal dimension is
            # padded manually in forward() to support streaming.
            padding=(0, (spatial_kernel_size - 1) // 2, (spatial_kernel_size - 1) // 2),
            bias=True,
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, use_cache=False, past_conv_values=None):
        """Run the block over ``x`` of shape (B, C, T, H, W).

        Args:
            x: input features, (B, C, T, H, W).
            use_cache: if True and ``past_conv_values`` is given, prepend the
                cached frames instead of zero-padding (streaming mode).
            past_conv_values: cache of shape (B, C, temporal_patch_size - 1, H, W)
                returned by the previous call, or None on the first call.

        Returns:
            Tuple ``(out, new_past_conv_values)`` where the cache holds the last
            ``temporal_patch_size - 1`` (pre-conv) frames for the next call.
        """
        cache_len = self.temporal_patch_size - 1
        if use_cache and past_conv_values is not None:
            x = torch.cat([past_conv_values, x], dim=2)
        else:
            # Fresh start: zero-pad the past so output length equals input length.
            x = F.pad(x, (0, 0, 0, 0, cache_len, 0), value=0)
        # Bug fix: with temporal_patch_size == 1 the original slice
        # `x[:, :, -(temporal_patch_size - 1):]` is `[..., -0:]`, which returns
        # the WHOLE tensor — an ever-growing "cache" that corrupts the concat
        # path on the next streaming step.  An empty cache is correct there,
        # since a temporal kernel of 1 needs no past context.
        new_past_conv_values = x[:, :, x.shape[2] - cache_len:]

        x = self.conv3d(x)

        x = self.relu(x)

        return x, new_past_conv_values
339
+
340
+
341
class ShallowVideoConvNet(nn.Module):
    """
    A shallow video convolutional network for video gaze modeling, inspired by ViViT's patch embedding approach.
    Expects input of shape (B, T, C, H, W) or (B*T, C, H, W).
    """
    def __init__(self, config: VisionModelConfig):
        super().__init__()
        hidden_dim = config.hidden_dim
        out_dim = config.out_dim
        depth = config.depth
        kernel_size = config.kernel_size
        # Optional on the config; defaults to 1 (no temporal downsampling).
        self.temporal_patch_size = getattr(config, "temporal_patch_size", 1)

        # For video, first merge temporal and batch if needed, then apply 3D conv for temporal patching.
        # Stride == kernel size, so this is a non-overlapping patch embedding:
        # T' = T // temporal_patch_size, H'/W' = H/W // kernel_size.
        self.temporal_conv = nn.Conv3d(
            in_channels=3,  # RGB
            out_channels=hidden_dim,
            kernel_size=(self.temporal_patch_size, kernel_size, kernel_size),
            stride=(self.temporal_patch_size, kernel_size, kernel_size),
            bias=True,
        )
        self.norm = nn.LayerNorm(hidden_dim)

        self.trunk_temporal_kernel_size = config.trunk_temporal_kernel_size
        self.trunk_spatial_kernel_size = config.trunk_spatial_kernel_size
        # Streaming-capable trunk: each block keeps its own temporal cache.
        blocks = []
        for i in range(depth):
            blocks.append(
                Conv3dBlockForStreaming(
                    hidden_dim=hidden_dim,
                    temporal_patch_size=self.trunk_temporal_kernel_size,
                    spatial_kernel_size=self.trunk_spatial_kernel_size,
                )
            )
        self.blocks = nn.ModuleList(blocks)

        # 1x1x1 projection to the output channel dimension.
        self.out_proj = nn.Conv3d(
            hidden_dim, out_dim, kernel_size=1, stride=1, bias=True
        )

    def forward(self, x, use_cache=False, past_conv_values=None):
        """Encode a video clip.

        Args:
            x: (B, T, C, H, W) video, or (B*T, C, H, W) which is treated as
                B*T single-frame clips.
            use_cache: streaming mode — feed each block its cached past frames.
            past_conv_values: list of per-block caches from the previous call
                (one entry per trunk block), or None.

        Returns:
            Tuple of (features of shape (B, out_dim, T', H', W'),
            list of new per-block caches).
        """
        # x: (B, T, C, H, W) or (B*T, C, H, W)
        if x.dim() == 5:
            # (B, T, C, H, W) -> (B, C, T, H, W)
            x = x.permute(0, 2, 1, 3, 4)
        elif x.dim() == 4:
            # (B*T, C, H, W) -> (B*T, C, 1, H, W)
            x = x.unsqueeze(2)
        else:
            raise ValueError("Input must be 4D or 5D tensor")
        x = self.temporal_conv(x)  # (B, hidden_dim, T', H', W')
        # Collapse temporal dimension into batch for normalization and blocks
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W)  # (B*T, C, H, W)
        # Flatten spatial dims for norm: (B*T, C, H*W)
        x = x.view(B * T, C, -1).permute(0, 2, 1)  # (B*T, H*W, C)
        # LayerNorm over the channel dimension (last after the permute).
        x = self.norm(x)
        x = x.permute(0, 2, 1).contiguous().view(B * T, C, H, W)  # (B*T, C, H, W)
        # Reshape back to (B, C, T, H, W)
        x = x.view(B, T, C, H, W).permute(0, 2, 1, 3, 4)
        # Main trunk: thread the per-block streaming caches through.
        new_past_conv_values = []
        for i, block in enumerate(self.blocks):
            x, new_past_conv_values_i = block(
                x,
                use_cache=use_cache,
                past_conv_values=past_conv_values[i] if use_cache and past_conv_values is not None else None
            )
            new_past_conv_values.append(new_past_conv_values_i)
        x = self.out_proj(x)
        # Output shape: (B, out_dim, T', H', W')
        return x, new_past_conv_values
413
+
414
+
415
+ ################ Connector Between Vision Encoder and Gaze Model #################
416
+
417
class Connector(nn.Module):
    """Bridges the vision encoder and the gaze model by adding a learned
    positional embedding to each of the ``num_tokens`` per-frame tokens."""

    def __init__(self, config: "ConnectorConfig"):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.num_tokens = config.num_tokens
        # One learned embedding per token position, shared across batch and time.
        self.pos_embed = nn.Parameter(torch.randn(self.num_tokens, self.hidden_dim))

    def forward(self, x):
        """Add positional embeddings to ``x`` of shape (B, T, N, C) and return
        the result (same shape); pos_embed broadcasts over batch and time."""
        return x + self.pos_embed[None, None]
autogaze/models/autogaze/modeling_llama_multi_token_pred.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import os
21
+ from dataclasses import dataclass
22
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
23
+ from copy import deepcopy
24
+ from einops import rearrange
25
+ from importlib.metadata import version
26
+ from packaging.version import Version
27
+ import torch
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+
31
+ from transformers.cache_utils import Cache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ CausalLMOutputWithPast,
36
+ )
37
+ from transformers.utils import (
38
+ can_return_tuple,
39
+ )
40
+ from transformers.utils.deprecation import deprecate_kwarg
41
+ from transformers.models.llama.modeling_llama import (
42
+ LlamaModel,
43
+ LlamaPreTrainedModel,
44
+ )
45
+
46
+ from transformers.cache_utils import (
47
+ Cache,
48
+ )
49
+ from transformers.utils import (
50
+ ModelOutput,
51
+ )
52
+ from transformers.generation.configuration_utils import (
53
+ GenerationConfig,
54
+ )
55
+ from transformers.generation.logits_process import (
56
+ LogitsProcessorList,
57
+ )
58
+ from transformers.generation.stopping_criteria import (
59
+ StoppingCriteriaList,
60
+ )
61
+ from transformers.generation.utils import (
62
+ GenerateNonBeamOutput,
63
+ GenerateDecoderOnlyOutput,
64
+ GenerateEncoderDecoderOutput,
65
+ ALL_CACHE_NAMES,
66
+ )
67
+
68
+ if TYPE_CHECKING:
69
+ from transformers.modeling_utils import PreTrainedModel
70
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
71
+ from transformers.generation.streamers import BaseStreamer
72
+
73
# transformers < 4.52 has a different `_get_initial_cache_position` signature;
# `_sample` below branches on this flag to stay compatible with both.
LOW_TRANSFORMERS_VERSION = Version(version("transformers")) < Version("4.52.0")
74
+
75
+
76
# NOTE: this intentionally shadows `transformers.modeling_outputs.CausalLMOutputWithPast`
# (imported above) in order to add the extra `task_loss_prediction` field.
@dataclass
class CausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs, extended
    with a per-position task-loss prediction for multi-token gaze decoding.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        task_loss_prediction (`torch.FloatTensor`, *optional*):
            Output of the model's `task_loss_prediction_head`; predicted task
            loss for each of the `num_multi_token_pred` look-ahead positions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    task_loss_prediction: Optional[torch.FloatTensor] = None
111
+
112
+
113
+ class LlamaForCausalLM_MultiTokenPred(LlamaPreTrainedModel, GenerationMixin):
114
+ _tp_plan = {"lm_head": "colwise_rep"}
115
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
116
+
117
    def __init__(self, config):
        """Build a Llama causal LM whose head predicts `config.num_multi_token_pred`
        future tokens at once instead of one."""
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        # The LM head emits num_multi_token_pred vocab distributions per
        # position, packed along the output channel dimension (unpacked with
        # rearrange in the callers).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_multi_token_pred, bias=False)
        # Predicts a scalar task loss for each of the look-ahead positions.
        self.task_loss_prediction_head = nn.Linear(config.hidden_size, config.num_multi_token_pred, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
126
+
127
    # Standard HuggingFace accessor boilerplate (used by resize_token_embeddings,
    # weight tying, model surgery, etc.).

    def get_input_embeddings(self):
        # Token-id -> embedding table of the underlying LlamaModel.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # NOTE: the output "embedding" here is the multi-token head
        # (hidden_size -> vocab_size * num_multi_token_pred), so it is NOT
        # tie-able with the input embeddings.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model
144
+
145
+
146
    @can_return_tuple
    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        This function is copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.forward,
        with two changes: logits come from the multi-token head
        (vocab_size * num_multi_token_pred channels per position), and the
        output additionally carries `task_loss_prediction`.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        task_loss_prediction = self.task_loss_prediction_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # NOTE(review): `logits` has vocab_size * num_multi_token_pred channels
            # but `loss_function` is called with the single-head `vocab_size` —
            # confirm this labels path is actually exercised, or that the loss
            # function handles the packed layout upstream.
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            task_loss_prediction=task_loss_prediction,
        )
204
+
205
    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        """Override of GenerationMixin._update_model_kwargs_for_generation that
        extends masks and cache positions by `num_new_tokens` at once (the
        multi-token head appends several tokens per decoding step)."""
        # update past_key_values keeping its naming used in model code
        for possible_cache_name in ALL_CACHE_NAMES:
            if possible_cache_name in outputs:
                # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
                if possible_cache_name in ("past_buckets_states", "mems"):
                    cache_name = "past_key_values"
                else:
                    cache_name = possible_cache_name
                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
                break

        # update token_type_ids with last value, repeated once per new token
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            assert token_type_ids.dim() == 2
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1).repeat(1, num_new_tokens)], dim=-1)

        if not is_encoder_decoder:
            # update attention mask: the new tokens are all attended to
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
                )
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], num_new_tokens))],
                    dim=-1,
                )

        # Since we generate multiple tokens at once, the number of new tokens > 1 and all those tokens need to
        # be cached later.
        if model_kwargs.get("use_cache", True):
            # With cache: cache_position holds only the positions of the new tokens.
            model_kwargs["cache_position"] = torch.arange(
                model_kwargs["cache_position"][-1] + 1, model_kwargs["cache_position"][-1] + num_new_tokens + 1, dtype=model_kwargs["cache_position"].dtype
            ).to(model_kwargs["cache_position"].device)
        else:
            # Without cache: positions accumulate (full sequence is re-fed each step).
            past_positions = model_kwargs.pop("cache_position")
            new_positions = torch.arange(
                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
            ).to(past_positions.device)
            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
        return model_kwargs
258
+
259
    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """
        This function is copied from transformers.generation.utils.GenerationMixin._sample,
        modified so each decoding step emits up to `config.num_multi_token_pred`
        tokens at once from the multi-token head, with extra stopping logic:
        a per-sequence `task_loss_requirement` threshold, the max-new-tokens
        budget, and a no-repeat constraint within one multi-token step.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        do_sample = generation_config.do_sample
        # Custom field carried on the generation config: per-sequence task-loss
        # threshold under which generation stops (or None to disable).
        task_loss_requirement = generation_config.task_loss_requirement

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        # `_get_initial_cache_position` changed signature in transformers 4.52.
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) if LOW_TRANSFORMERS_VERSION else self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

        model_forward = self.__call__
        if isinstance(model_kwargs.get("past_key_values"), Cache):
            is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
            if getattr(self, "hf_quantizer", None) is not None:
                is_compileable &= self.hf_quantizer.is_compileable
            is_compileable = is_compileable and not generation_config.disable_compile
            if is_compileable and (
                self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
            ):
                os.environ["TOKENIZERS_PARALLELISM"] = "0"
                model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
            is_prefill = False
        else:
            is_prefill = True

        is_first_token = True
        final_past_key_values = None
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                num_new_tokens=self.config.num_multi_token_pred,
            )

            if synced_gpus and this_peer_finished:
                continue

            # Deepcopy the to-return past_key_values, such that it won't change during the dummy model forwards when this peer is finished but peers from other devices are still generating.
            # NOTE(review): deepcopy of the full KV cache every step is costly for long contexts.
            final_past_key_values = deepcopy(model_kwargs.get("past_key_values"))

            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits_all = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
            task_loss_prediction_all = outputs.task_loss_prediction[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)

            # Process all new tokens at once: unpack the packed multi-token head
            # into (batch, num_multi_token_pred, vocab).
            next_token_logits_all = rearrange(next_token_logits_all, 'b (n v) -> b n v', n=self.config.num_multi_token_pred, v=self.config.vocab_size)

            # pre-process distribution
            next_token_scores_all = logits_processor(input_ids, next_token_logits_all)

            # token selection: pick one token per look-ahead head, sequentially,
            # so later heads can be masked against earlier picks.
            next_tokens_all = []
            early_stopped = False
            for i in range(self.config.num_multi_token_pred):
                next_token_scores_i = next_token_scores_all[:, i, :]

                # exit early if meeting max number of token or not token left to choose
                if len(next_tokens_all) + input_ids.shape[1] >= generation_config.max_new_tokens:
                    early_stopped = True
                    break
                if torch.all(next_token_scores_i == -float("inf")):
                    early_stopped = True
                    break

                if do_sample:
                    probs = nn.functional.softmax(next_token_scores_i, dim=-1)
                    # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                    next_tokens_i = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    next_tokens_i = torch.argmax(next_token_scores_i, dim=-1)

                next_tokens_all.append(next_tokens_i)

                # avoid repeating gazing: mask the chosen token out of all later heads
                next_token_scores_all[torch.arange(next_tokens_i.shape[0]), i + 1:, next_tokens_i] = -float("inf")

            next_tokens_all = torch.stack(next_tokens_all, dim=1)

            # already finished sentences should have their next token be a padding token
            next_tokens_all = next_tokens_all * unfinished_sequences[..., None] + pad_token_id * (1 - unfinished_sequences[..., None])

            # Mark finished if task loss requirement is met (the predicted task
            # loss has dropped below the requested threshold).
            meet_task_loss_requirement = torch.zeros_like(next_tokens_all, dtype=torch.bool)
            if task_loss_requirement is not None:
                meet_task_loss_requirement = task_loss_prediction_all[:, :next_tokens_all.shape[1]] <= task_loss_requirement[..., None]
                if is_first_token:
                    # Always emit at least one token on the very first step.
                    meet_task_loss_requirement[:, 0] = False
                next_tokens_all = next_tokens_all * (~meet_task_loss_requirement) + pad_token_id * meet_task_loss_requirement

            # Truncate the next tokens to the max new tokens
            meet_max_new_tokens = False
            if next_tokens_all.shape[1] + input_ids.shape[1] >= generation_config.max_new_tokens:
                next_tokens_all = next_tokens_all[:, :generation_config.max_new_tokens - input_ids.shape[1]]
                meet_max_new_tokens = True

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens_all], dim=-1)
            if streamer is not None:
                for i in range(next_tokens_all.shape[1]):
                    streamer.put(next_tokens_all[:, i].cpu())

            # Update the finishing flags.
            # NOTE(review): `meet_max_new_tokens`/`early_stopped` are Python bools;
            # `tensor & ~bool` works here only because the values are 0/1
            # (1 & ~True == 1 & -2 == 0) — fragile but correct for this dtype.
            unfinished_sequences = unfinished_sequences & ~torch.any(meet_task_loss_requirement, dim=-1) & ~meet_max_new_tokens & ~early_stopped
            this_peer_finished = unfinished_sequences.max() == 0
            # NOTE(review): cur_len advances by 1 per step even though several
            # tokens may have been appended; it is unused after this point.
            cur_len += 1

            is_first_token = False

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_all,)
                if output_logits:
                    raw_logits += (next_token_logits_all,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )



            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=final_past_key_values,
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=final_past_key_values,
                )
        else:
            return input_ids
autogaze/tasks/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """AutoGaze tasks."""
autogaze/tasks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes). View file
 
autogaze/tasks/video_mae_reconstruction/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .task_video_mae_reconstruction import VideoMAEReconstruction
autogaze/tasks/video_mae_reconstruction/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (250 Bytes). View file
 
autogaze/tasks/video_mae_reconstruction/__pycache__/configuration_video_mae.cpython-310.pyc ADDED
Binary file (5.86 kB). View file
 
autogaze/tasks/video_mae_reconstruction/__pycache__/modeling_video_mae.cpython-310.pyc ADDED
Binary file (43.9 kB). View file
 
autogaze/tasks/video_mae_reconstruction/__pycache__/task_video_mae_reconstruction.cpython-310.pyc ADDED
Binary file (6.91 kB). View file
 
autogaze/tasks/video_mae_reconstruction/__pycache__/visualize_video_mae_reconstruction.cpython-310.pyc ADDED
Binary file (3.44 kB). View file
 
autogaze/tasks/video_mae_reconstruction/configuration_video_mae.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ViT MAE model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class ViTMAEConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate an ViT
27
+ MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration with
28
+ the defaults will yield a similar configuration to that of the ViT
29
+ [facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ hidden_size (`int`, *optional*, defaults to 768):
37
+ Dimensionality of the encoder layers and the pooler layer.
38
+ num_hidden_layers (`int`, *optional*, defaults to 12):
39
+ Number of hidden layers in the Transformer encoder.
40
+ num_attention_heads (`int`, *optional*, defaults to 12):
41
+ Number of attention heads for each attention layer in the Transformer encoder.
42
+ intermediate_size (`int`, *optional*, defaults to 3072):
43
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
47
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
48
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
49
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
50
+ The dropout ratio for the attention probabilities.
51
+ initializer_range (`float`, *optional*, defaults to 0.02):
52
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
53
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
54
+ The epsilon used by the layer normalization layers.
55
+ image_size (`int`, *optional*, defaults to 224):
56
+ The size (resolution) of each image.
57
+ patch_size (`int`, *optional*, defaults to 16):
58
+ The size (resolution) of each patch.
59
+ num_channels (`int`, *optional*, defaults to 3):
60
+ The number of input channels.
61
+ qkv_bias (`bool`, *optional*, defaults to `True`):
62
+ Whether to add a bias to the queries, keys and values.
63
+ decoder_num_attention_heads (`int`, *optional*, defaults to 16):
64
+ Number of attention heads for each attention layer in the decoder.
65
+ decoder_hidden_size (`int`, *optional*, defaults to 512):
66
+ Dimensionality of the decoder.
67
+ decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
68
+ Number of hidden layers in the decoder.
69
+ decoder_intermediate_size (`int`, *optional*, defaults to 2048):
70
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
71
+ mask_ratio (`float`, *optional*, defaults to 0.75):
72
+ The ratio of the number of masked tokens in the input sequence.
73
+ norm_pix_loss (`bool`, *optional*, defaults to `False`):
74
+ Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
75
+ representation quality in the experiments of the authors.
76
+
77
+ Example:
78
+
79
+ ```python
80
+ >>> from transformers import ViTMAEConfig, ViTMAEModel
81
+
82
+ >>> # Initializing a ViT MAE vit-mae-base style configuration
83
+ >>> configuration = ViTMAEConfig()
84
+
85
+ >>> # Initializing a model (with random weights) from the vit-mae-base style configuration
86
+ >>> model = ViTMAEModel(configuration)
87
+
88
+ >>> # Accessing the model configuration
89
+ >>> configuration = model.config
90
+ ```"""
91
+
92
+ model_type = "vit_mae"
93
+
94
+ def __init__(
95
+ self,
96
+ hidden_size=768,
97
+ num_hidden_layers=12,
98
+ num_attention_heads=12,
99
+ intermediate_size=3072,
100
+ hidden_act="gelu",
101
+ hidden_dropout_prob=0.0,
102
+ attention_probs_dropout_prob=0.0,
103
+ initializer_range=0.02,
104
+ layer_norm_eps=1e-12,
105
+ image_size=224,
106
+ patch_size=16,
107
+ num_channels=3,
108
+ qkv_bias=True,
109
+ decoder_num_attention_heads=16,
110
+ decoder_hidden_size=512,
111
+ decoder_num_hidden_layers=8,
112
+ decoder_intermediate_size=2048,
113
+ mask_ratio=0.75,
114
+ norm_pix_loss=False,
115
+ scales='224',
116
+ loss_type='l1',
117
+ loss_weights='1',
118
+ l1_loss_config=None,
119
+ dinov2_reg_loss_config=None,
120
+ siglip2_loss_config=None,
121
+ scale_embed=True,
122
+ max_num_frames=256,
123
+ time_embed=True,
124
+ causal=True,
125
+ **kwargs,
126
+ ):
127
+ super().__init__(**kwargs)
128
+
129
+ self.hidden_size = hidden_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.num_attention_heads = num_attention_heads
132
+ self.intermediate_size = intermediate_size
133
+ self.hidden_act = hidden_act
134
+ self.hidden_dropout_prob = hidden_dropout_prob
135
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
136
+ self.initializer_range = initializer_range
137
+ self.layer_norm_eps = layer_norm_eps
138
+ self.image_size = image_size
139
+ self.patch_size = patch_size
140
+ self.num_channels = num_channels
141
+ self.qkv_bias = qkv_bias
142
+ self.decoder_num_attention_heads = decoder_num_attention_heads
143
+ self.decoder_hidden_size = decoder_hidden_size
144
+ self.decoder_num_hidden_layers = decoder_num_hidden_layers
145
+ self.decoder_intermediate_size = decoder_intermediate_size
146
+ self.mask_ratio = mask_ratio
147
+ self.norm_pix_loss = norm_pix_loss
148
+ self.scales = scales
149
+ self.loss_type = loss_type
150
+ self.loss_weights = loss_weights
151
+ self.l1_loss_config = l1_loss_config
152
+ self.dinov2_reg_loss_config = dinov2_reg_loss_config
153
+ self.siglip2_loss_config = siglip2_loss_config
154
+ self.scale_embed = scale_embed
155
+ self.max_num_frames = max_num_frames
156
+ self.time_embed = time_embed
157
+ self.causal = causal
158
+
159
+ __all__ = ["ViTMAEConfig"]
autogaze/tasks/video_mae_reconstruction/modeling_video_mae.py ADDED
@@ -0,0 +1,1412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViT MAE (masked autoencoder) model."""
16
+
17
+ import collections.abc
18
+ from copy import deepcopy
19
+ from dataclasses import dataclass
20
+ from typing import Callable, Optional, Set, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ import torch.nn.functional as F
27
+ from einops import rearrange
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import BaseModelOutput
30
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
31
+ from transformers.pytorch_utils import prune_linear_layer
32
+ from transformers.utils import (
33
+ ModelOutput,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ torch_int,
39
+ )
40
+ from .configuration_video_mae import ViTMAEConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
47
+ """
48
+ Finds the heads and their indices taking `already_pruned_heads` into account.
49
+
50
+ Args:
51
+ heads (`Set[int]`): A set of head indices we want to prune.
52
+ n_heads (`int`): The number of heads in the model.
53
+ head_size (`int`): The size of each head.
54
+ already_pruned_heads (`Set[int]`): A set of already pruned heads.
55
+
56
+ Returns:
57
+ `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
58
+ """
59
+ mask = torch.ones(n_heads, head_size)
60
+ heads = set(heads) - already_pruned_heads
61
+ for head in heads:
62
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
63
+ mask[head] = 0
64
+ mask = mask.view(-1).contiguous().eq(1)
65
+ index = torch.arange(len(mask))[mask].long()
66
+ return heads, index
67
+
68
+
69
+ @dataclass
70
+ class ViTMAEModelOutput(ModelOutput):
71
+ """
72
+ Class for ViTMAEModel's outputs, with potential hidden states and attentions.
73
+
74
+ Args:
75
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
76
+ Sequence of hidden-states at the output of the last layer of the model.
77
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
78
+ Tensor indicating which patches are masked (1) and which are not (0).
79
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
80
+ Tensor containing the original index of the (shuffled) masked patches.
81
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
82
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
83
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
84
+ plus the initial embedding outputs.
85
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
86
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
87
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
88
+ the self-attention heads.
89
+ """
90
+
91
+ last_hidden_state: Optional[torch.FloatTensor] = None
92
+ mask: Optional[torch.LongTensor] = None
93
+ ids_restore: Optional[torch.LongTensor] = None
94
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
95
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
96
+
97
+
98
+ @dataclass
99
+ class ViTMAEDecoderOutput(ModelOutput):
100
+ """
101
+ Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
102
+
103
+ Args:
104
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
105
+ Pixel reconstruction logits.
106
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
107
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
108
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
109
+ plus the initial embedding outputs.
110
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
111
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
112
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
113
+ the self-attention heads.
114
+ """
115
+
116
+ logits: Optional[torch.FloatTensor] = None
117
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
118
+ num_decoded_tokens_each_frame: Optional[torch.LongTensor] = None
119
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
120
+
121
+
122
+ @dataclass
123
+ class ViTMAEForPreTrainingOutput(ModelOutput):
124
+ """
125
+ Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
126
+
127
+ Args:
128
+ loss_each_reconstruction_frame (`torch.FloatTensor` of shape `(batch_size, num_selected_frames)`):
129
+ Pixel reconstruction loss for each reconstruction frame.
130
+ loss_mean (`torch.FloatTensor` of shape `(1,)`):
131
+ Mean of the pixel reconstruction loss for each reconstruction frame.
132
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
133
+ Pixel reconstruction logits.
134
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
135
+ Tensor indicating which patches are masked (1) and which are not (0).
136
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
137
+ Tensor containing the original index of the (shuffled) masked patches.
138
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
139
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
140
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
141
+ plus the initial embedding outputs.
142
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
143
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
144
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
145
+ the self-attention heads.
146
+ """
147
+
148
+ loss_each_reconstruction_frame: Optional[torch.FloatTensor] = None
149
+ loss_mean: Optional[torch.FloatTensor] = None
150
+ reconstruction: Optional[torch.FloatTensor] = None
151
+ logits: Optional[torch.FloatTensor] = None
152
+ mask: Optional[torch.LongTensor] = None
153
+ ids_restore: Optional[torch.LongTensor] = None
154
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
155
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
156
+
157
+
158
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
159
+ """
160
+ Create 2D sin/cos positional embeddings.
161
+
162
+ Args:
163
+ embed_dim (`int`):
164
+ Embedding dimension.
165
+ grid_size (`int`):
166
+ The grid height and width.
167
+ add_cls_token (`bool`, *optional*, defaults to `False`):
168
+ Whether or not to add a classification (CLS) token.
169
+
170
+ Returns:
171
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
172
+ position embeddings (with or without classification token)
173
+ """
174
+ grid_h = np.arange(grid_size, dtype=np.float32)
175
+ grid_w = np.arange(grid_size, dtype=np.float32)
176
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
177
+ grid = np.stack(grid, axis=0)
178
+
179
+ grid = grid.reshape([2, 1, grid_size, grid_size])
180
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
181
+ if add_cls_token:
182
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
183
+ return pos_embed
184
+
185
+
186
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
187
+ if embed_dim % 2 != 0:
188
+ raise ValueError("embed_dim must be even")
189
+
190
+ # use half of dimensions to encode grid_h
191
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
192
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
193
+
194
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
195
+ return emb
196
+
197
+
198
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
199
+ """
200
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
201
+ """
202
+ if embed_dim % 2 != 0:
203
+ raise ValueError("embed_dim must be even")
204
+
205
+ omega = np.arange(embed_dim // 2, dtype=float)
206
+ omega /= embed_dim / 2.0
207
+ omega = 1.0 / 10000**omega # (D/2,)
208
+
209
+ pos = pos.reshape(-1) # (M,)
210
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
211
+
212
+ emb_sin = np.sin(out) # (M, D/2)
213
+ emb_cos = np.cos(out) # (M, D/2)
214
+
215
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
216
+ return emb
217
+
218
+
219
+ class ViTMAEEmbeddings(nn.Module):
220
+ """
221
+ Construct the CLS token, position and patch embeddings.
222
+
223
+ """
224
+
225
+ def __init__(self, config):
226
+ super().__init__()
227
+
228
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
229
+ self.patch_embeddings = ViTMAEPatchEmbeddings(config)
230
+ self.num_patches = self.patch_embeddings.num_patches
231
+ # fixed sin-cos embedding
232
+ self.position_embeddings = nn.Parameter(
233
+ torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
234
+ )
235
+ self.patch_size = config.patch_size
236
+ self.config = config
237
+
238
+ # multi-scale setting
239
+ self.scales = sorted([int(scale) for scale in config.scales.split('+')])
240
+ self.num_patch_each_scale = [(scale // config.patch_size)**2 for scale in self.scales]
241
+ if config.scale_embed:
242
+ self.scale_embed = nn.Parameter(torch.randn(len(self.scales), config.hidden_size) * 0)
243
+
244
+ # time embedding
245
+ if config.time_embed:
246
+ self.time_embed = nn.Parameter(torch.randn(config.max_num_frames, config.hidden_size) * 0)
247
+
248
+ def initialize_weights(self):
249
+ # initialize (and freeze) position embeddings by sin-cos embedding
250
+ pos_embed = get_2d_sincos_pos_embed(
251
+ self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
252
+ )
253
+ self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
254
+
255
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
256
+ w = self.patch_embeddings.projection.weight.data
257
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
258
+
259
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
260
+ torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
261
+
262
+ # initialize scale embed
263
+ if self.config.scale_embed:
264
+ torch.nn.init.normal_(self.scale_embed, std=self.config.initializer_range)
265
+
266
+ # initialize time embed
267
+ if self.config.time_embed:
268
+ torch.nn.init.normal_(self.time_embed, std=self.config.initializer_range)
269
+
270
+ # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
271
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
272
+ """
273
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
274
+ images. This method is also adapted to support torch.jit tracing.
275
+
276
+ Adapted from:
277
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
278
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
279
+ """
280
+
281
+ num_patches = embeddings.shape[1] - 1
282
+ num_positions = self.position_embeddings.shape[1] - 1
283
+
284
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
285
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
286
+ return self.position_embeddings
287
+
288
+ class_pos_embed = self.position_embeddings[:, :1]
289
+ patch_pos_embed = self.position_embeddings[:, 1:]
290
+
291
+ dim = embeddings.shape[-1]
292
+
293
+ new_height = height // self.patch_size
294
+ new_width = width // self.patch_size
295
+
296
+ sqrt_num_positions = torch_int(num_positions**0.5)
297
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
298
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
299
+
300
+ patch_pos_embed = nn.functional.interpolate(
301
+ patch_pos_embed,
302
+ size=(new_height, new_width),
303
+ mode="bicubic",
304
+ align_corners=False,
305
+ )
306
+
307
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
308
+
309
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
310
+
311
+ def mask_with_gazing(self, sequence, gazing_info):
312
+ """
313
+ Mask the sequence with the gazing information.
314
+ For the padded gazing, we select a dummy token to fill in the positions (the dummy token is currently the first token in each sequence).
315
+
316
+ Args:
317
+ sequence: The sequence to mask.
318
+ gazing_info:
319
+ gazing_pos: The gazing positions of each whole sequence. (B, N)
320
+ num_gazing_each_frame: The number of gazing positions for each frame, including the padded gazing. (T, )
321
+ if_padded_gazing: Whether the gazing is padded. (B, N)
322
+ """
323
+ gazing_pos = gazing_info['gazing_pos'].clone()
324
+ num_gazing_each_frame = gazing_info['num_gazing_each_frame'].clone()
325
+ if_padded_gazing = gazing_info['if_padded_gazing'].clone()
326
+
327
+ B, seq_length, dim = sequence.shape
328
+ gaze_length = gazing_pos.shape[1]
329
+ assert gaze_length == num_gazing_each_frame.sum()
330
+
331
+ # Record the original sequence length into gazing_info
332
+ gazing_info['original_seq_length'] = seq_length
333
+
334
+ # Pad the sequence with an additional token for padded gazing to select
335
+ sequence = torch.cat([sequence, sequence[:, :1]], dim=1)
336
+
337
+ # Change all the padded gazing id to the last token id
338
+ gazing_pos = gazing_pos.flatten()
339
+ gazing_pos[if_padded_gazing.flatten()] = seq_length
340
+ gazing_pos = gazing_pos.view(B, -1)
341
+
342
+ # Get the unmasked part of the sequence for MAE encoding
343
+ sequence_unmasked = sequence[torch.arange(B)[..., None], gazing_pos]
344
+
345
+ return sequence_unmasked
346
+
347
+ def forward(self, pixel_values, gazing_info=None, noise=None, interpolate_pos_encoding: bool = False):
348
+ """
349
+ pixel_values: (B, T, C, H, W)
350
+ """
351
+ B, T = pixel_values.shape[:2]
352
+ pixel_values = rearrange(pixel_values, 'b t c h w -> (b t) c h w')
353
+
354
+ embeddings = []
355
+ for i, scale in enumerate(self.scales):
356
+ pixel_values_cur_scale = F.interpolate(pixel_values, size=(scale, scale), mode="bicubic", align_corners=False)
357
+ embeddings_cur_scale = self.patch_embeddings(pixel_values_cur_scale, interpolate_pos_encoding=interpolate_pos_encoding)
358
+ if interpolate_pos_encoding:
359
+ position_embeddings_cur_scale = self.interpolate_pos_encoding(embeddings_cur_scale, scale, scale)
360
+ else:
361
+ position_embeddings_cur_scale = self.position_embeddings
362
+
363
+ # add position embeddings w/o cls token
364
+ embeddings_cur_scale = embeddings_cur_scale + position_embeddings_cur_scale[:, 1:, :]
365
+
366
+ # add scale embedding
367
+ if self.config.scale_embed:
368
+ scale_embeddings_cur_scale = self.scale_embed[i][None, None]
369
+ embeddings_cur_scale = embeddings_cur_scale + scale_embeddings_cur_scale
370
+
371
+ embeddings.append(embeddings_cur_scale)
372
+ embeddings = torch.cat(embeddings, dim=1) # (B * T) * N * C
373
+
374
+ # add time embedding
375
+ embeddings = rearrange(embeddings, '(b t) n c -> b t n c', b=B, t=T) # B * T * N * C
376
+ if self.config.time_embed:
377
+ time_embeddings = self.time_embed[None, :T, None, :] # 1 * T * 1 * C
378
+ embeddings = embeddings + time_embeddings
379
+
380
+ embeddings = rearrange(embeddings, 'b t n c -> b (t n) c') # B * (T * N) * C
381
+
382
+ # masking: length -> length * config.mask_ratio
383
+ embeddings = self.mask_with_gazing(embeddings, gazing_info)
384
+
385
+ # append cls token
386
+ cls_token = self.cls_token + self.position_embeddings[:, :1, :]
387
+ cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
388
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
389
+
390
+ return embeddings
391
+
392
+
393
+ class ViTMAEPatchEmbeddings(nn.Module):
394
+ """
395
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
396
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
397
+ Transformer.
398
+ """
399
+
400
+ def __init__(self, config):
401
+ super().__init__()
402
+ image_size, patch_size = config.image_size, config.patch_size
403
+ num_channels, hidden_size = config.num_channels, config.hidden_size
404
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
405
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
406
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
407
+ self.image_size = image_size
408
+ self.patch_size = patch_size
409
+ self.num_channels = num_channels
410
+ self.num_patches = num_patches
411
+
412
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
413
+
414
+ def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
415
+ batch_size, num_channels, height, width = pixel_values.shape
416
+ if num_channels != self.num_channels:
417
+ raise ValueError(
418
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
419
+ )
420
+
421
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
422
+ raise ValueError(
423
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
424
+ )
425
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
426
+ return x
427
+
428
+
429
+ # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
430
+ def eager_attention_forward(
431
+ module: nn.Module,
432
+ query: torch.Tensor,
433
+ key: torch.Tensor,
434
+ value: torch.Tensor,
435
+ attention_mask: Optional[torch.Tensor],
436
+ scaling: float,
437
+ dropout: float = 0.0,
438
+ **kwargs,
439
+ ):
440
+ # Take the dot product between "query" and "key" to get the raw attention scores.
441
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
442
+
443
+ # Normalize the attention scores to probabilities.
444
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
445
+
446
+ # This is actually dropping out entire tokens to attend to, which might
447
+ # seem a bit unusual, but is taken from the original Transformer paper.
448
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
449
+
450
+ # Mask heads if we want to
451
+ if attention_mask is not None:
452
+ attn_weights = attn_weights * attention_mask
453
+
454
+ attn_output = torch.matmul(attn_weights, value)
455
+ attn_output = attn_output.transpose(1, 2).contiguous()
456
+
457
+ return attn_output, attn_weights
458
+
459
+
460
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
461
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
class ViTMAESelfAttention(nn.Module):
    """Multi-head self-attention with a pluggable attention backend (eager / sdpa / flash / flex)."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        # A config carrying `embedding_size` is exempt from the divisibility check,
        # mirroring the upstream ViT implementation.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5  # 1/sqrt(d_head)
        self.is_causal = False  # ViT attention is bidirectional

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run attention; returns `(context,)` or `(context, attention_probs)` when `output_attentions`."""
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # Select the attention backend configured on the model; default is the eager path.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
                # NOTE(review): despite the warning text promising an eager fallback, this path
                # aborts instead; `assert` is also stripped under `python -O`. Confirm whether a
                # real fallback (or an explicit raise) is intended.
                assert False, "SDPA doesn't support output_attentions=True. If falling back to eager, please change the attention mask implementation."
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        # Merge heads back: (batch, seq, heads, head_dim) -> (batch, seq, hidden).
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
522
+
523
+
524
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
525
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
    """
    Output projection applied after self-attention. The residual connection lives in
    ViTMAELayer rather than here (unlike most models), because the block uses
    pre-layernorm ordering.
    """

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted only for interface parity; the residual add
        # happens in the caller (ViTMAELayer).
        return self.dropout(self.dense(hidden_states))
541
+
542
+
543
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
544
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
    """Self-attention plus output projection, with support for pruning attention heads."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.attention = ViTMAESelfAttention(config)
        self.output = ViTMAESelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given heads from the q/k/v and output projections and update bookkeeping."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Shrink the linear projections along the pruned-head dimension.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep derived sizes in sync and remember which heads are gone.
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads | set(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        attn_results = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(attn_results[0], hidden_states)
        # Re-attach the attention probabilities when they were requested.
        return (projected,) + attn_results[1:]
581
+
582
+
583
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
584
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
    """MLP up-projection followed by the configured activation function."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # The config may name the activation (looked up in ACT2FN) or supply a callable directly.
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
598
+
599
+
600
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
601
# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
class ViTMAEOutput(nn.Module):
    """MLP down-projection with dropout; also applies the block's second residual connection."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection back to the block input.
        return projected + input_tensor
614
+
615
+
616
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
617
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
    """One pre-norm transformer block (attention + MLP); corresponds to timm's Block class."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTMAEAttention(config)
        self.intermediate = ViTMAEIntermediate(config)
        self.output = ViTMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Pre-norm attention with residual: in ViTMAE, layernorm comes before self-attention.
        normed = self.layernorm_before(hidden_states)
        attn_out = self.attention(normed, head_mask, output_attentions=output_attentions)
        residual = attn_out[0] + hidden_states

        # Pre-norm MLP; the second residual connection is applied inside self.output.
        mlp_out = self.output(self.intermediate(self.layernorm_after(residual)), residual)

        # Append the attention probabilities when they were requested.
        return (mlp_out,) + attn_out[1:]
657
+
658
+
659
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
660
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module):
    """Stack of ViTMAELayer blocks with optional gradient checkpointing."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run every block in sequence, optionally collecting hidden states / attentions."""
        hidden_trace = () if output_hidden_states else None
        attn_trace = () if output_attentions else None

        for block in self.layer:
            if output_hidden_states:
                hidden_trace = hidden_trace + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory while training.
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    head_mask,
                    output_attentions,
                )
            else:
                block_outputs = block(hidden_states, head_mask, output_attentions)

            hidden_states = block_outputs[0]
            if output_attentions:
                attn_trace = attn_trace + (block_outputs[1],)

        if output_hidden_states:
            hidden_trace = hidden_trace + (hidden_states,)

        if not return_dict:
            return tuple(v for v in (hidden_states, hidden_trace, attn_trace) if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=hidden_trace,
            attentions=attn_trace,
        )
707
+
708
+
709
class ViTMAEPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTMAEConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    # Advertise the attention backends this model supports to the loading machinery.
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTMAEEmbeddings):
            # The embedding module owns its own (sin-cos / token) initialization.
            module.initialize_weights()
        elif isinstance(module, ViTMAEDecoder):
            # Decoder specials: mask token and (frozen) positional table start at zero;
            # the optional scale/time embeddings get a normal init when enabled in the config.
            module.mask_token.data.zero_()
            module.decoder_pos_embed.data.zero_()
            if self.config.scale_embed:
                torch.nn.init.normal_(module.decoder_scale_embed, std=self.config.initializer_range)
            if self.config.time_embed:
                torch.nn.init.normal_(module.time_embed, std=self.config.initializer_range)
743
+
744
+
745
class ViTMAEModel(ViTMAEPreTrainedModel):
    """Bare ViT-MAE encoder: (gaze-aware) embeddings -> transformer encoder -> final layernorm."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTMAEEmbeddings(config)
        self.encoder = ViTMAEEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module used to embed pixel inputs."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        gazing_info: Optional[dict] = None,
        noise: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, ViTMAEModelOutput]:
        """Encode `pixel_values` into a sequence of hidden states.

        Args:
            pixel_values: input images; required (raises ValueError when None).
            gazing_info: extra gaze metadata forwarded to the embedding module.
            noise: optional noise forwarded to the embedding module.
            head_mask: multiplicative attention head mask (1.0 keeps a head), or None.

        Returns:
            `ViTMAEModelOutput`, or a plain tuple when `return_dict=False`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # head_mask is multiplied into the attention probabilities (1.0 keeps a head).
        # Fix: the previous unconditional `.to(self.dtype)` crashed with AttributeError
        # on the documented default `head_mask=None`; downstream attention handles None.
        if head_mask is not None:
            head_mask = head_mask.to(self.dtype)

        embedding_output = self.embeddings(
            pixel_values, gazing_info=gazing_info, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            # Fix: the tensor must be wrapped in a tuple before concatenation;
            # `tensor + tuple` raises TypeError.
            return (sequence_output,) + encoder_outputs[1:]

        return ViTMAEModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
819
+
820
+
821
class ViTMAEDecoder(ViTMAEPreTrainedModel):
    """Multi-scale MAE decoder that scatters encoded (gazed) tokens back into the full
    token grid, fills the rest with a learned mask token, and reconstructs pixel patches
    for selected frames."""

    def __init__(self, config, num_patches):
        super().__init__(config)
        self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
        )  # fixed sin-cos embedding

        # Decoder reuses the encoder block definition with its own (smaller) dimensions.
        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
        self.decoder_pred = nn.Linear(
            config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
        )  # encoder to decoder
        self.gradient_checkpointing = False
        self.config = config

        # multi-scale setting: `config.scales` is a '+'-separated list of image scales;
        # each scale contributes (scale/patch_size)^2 patch tokens per frame.
        self.scales = sorted([int(scale) for scale in config.scales.split('+')])
        self.num_patch_each_frame_each_scale = [(scale // config.patch_size)**2 for scale in self.scales]
        if self.config.scale_embed:
            # Initialized to zeros here; _init_weights gives it a normal init.
            self.decoder_scale_embed = nn.Parameter(torch.randn(len(self.scales), config.decoder_hidden_size) * 0)

        # time embed
        if self.config.time_embed:
            self.time_embed = nn.Parameter(torch.randn(config.max_num_frames, config.decoder_hidden_size) * 0)

        self.num_token_each_frame = sum(self.num_patch_each_frame_each_scale)

        self.initialize_weights(num_patches)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
        allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
        resolution images.

        Adapted from:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # -1 removes the class dimension since we later append it without interpolation
        embeddings_positions = embeddings.shape[1] - 1

        # Separation of class token and patch tokens
        class_pos_embed = self.decoder_pos_embed[:, :1]
        patch_pos_embed = self.decoder_pos_embed[:, 1:]

        # To retain the final 3d tensor with the required dimensions
        dim = self.decoder_pos_embed.shape[-1]

        # Increasing a dimension to enable bicubic interpolation
        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)

        # permute to bring the dimension to be interpolated, to the last
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
        # we keep the second last dimension constant
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(patch_pos_embed.shape[-2], embeddings_positions),
            mode="bicubic",
            align_corners=False,
        )

        # Converting back to the original shape
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # Adding the class token back
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def initialize_weights(self, num_patches):
        """Fill the frozen positional table with 2D sin-cos values and init the mask token."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

    def forward(
        self,
        hidden_states,
        gazing_info=None,
        frame_idx_to_reconstruct=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        interpolate_pos_encoding: bool = False,
    ):
        """Decode encoded gaze tokens into per-patch pixel predictions.

        Args:
            hidden_states: encoder output, (B, 1 + num_gazed, hidden); position 0 is the CLS token.
            gazing_info: dict with 'gazing_pos' (flat token indices of gazed tokens),
                'num_gazing_each_frame', 'if_padded_gazing' (bool mask of padded entries)
                and 'original_seq_length' (total tokens over all frames). Required despite
                the None default — the first line indexes into it.
            frame_idx_to_reconstruct: iterable of frame indices whose full token grid is decoded.
            head_mask: per-head attention mask.
                NOTE(review): `head_mask.to(self.dtype)` below fails with AttributeError when
                head_mask is None despite the None default — confirm callers always pass one,
                or add a None guard.

        Returns:
            `ViTMAEDecoderOutput` with per-token patch logits (CLS removed), or a tuple
            when `return_dict=False`.
        """
        gazing_pos = gazing_info['gazing_pos']
        num_gazing_each_frame = gazing_info['num_gazing_each_frame']
        if_padded_gazing = gazing_info['if_padded_gazing']
        original_seq_length = gazing_info['original_seq_length']

        B = hidden_states.shape[0]
        gaze_length = gazing_pos.shape[1]
        assert gaze_length == num_gazing_each_frame.sum()
        T = len(num_gazing_each_frame)
        original_seq_length_each_frame = original_seq_length // T

        # embed tokens
        x = self.decoder_embed(hidden_states)

        # Take out cls token
        x_ = x[:, 1:, :]
        cls_token = x[:, :1, :]

        # Change all the padded gazing id to the last token id
        gazing_pos = gazing_pos.flatten()
        gazing_pos[if_padded_gazing.flatten()] = original_seq_length
        gazing_pos = gazing_pos.view(B, -1)

        # add mask tokens back to the sequence (temporarily append an additional token for padded gazing to select)
        full_seq = self.mask_token.repeat(x.shape[0], original_seq_length + 1, 1).to(x.dtype)
        full_seq[torch.arange(B)[..., None], gazing_pos] = x_
        full_seq = full_seq[:, :-1, :]

        # add pos embed and scale embed
        full_seq = rearrange(full_seq, 'b (t n) c -> (b t) n c', t=T)
        decoder_pos_embed = []
        decoder_scale_embed = []
        for i, scale in enumerate(self.scales):
            # Tokens of scale i occupy a contiguous slice within each frame.
            x_cur_scale = full_seq[:, sum(self.num_patch_each_frame_each_scale[:i]):sum(self.num_patch_each_frame_each_scale[:i+1])]
            if interpolate_pos_encoding:
                # Pad a dummy CLS slot so interpolate_pos_encoding sees the expected layout.
                decoder_pos_embed_cur_scale = self.interpolate_pos_encoding(F.pad(x_cur_scale, (0, 0, 1, 0)))[:, 1:]
            else:
                decoder_pos_embed_cur_scale = self.decoder_pos_embed
            decoder_pos_embed.append(decoder_pos_embed_cur_scale)
            if self.config.scale_embed:
                decoder_scale_embed.append(self.decoder_scale_embed[i][None, None].repeat(1, decoder_pos_embed_cur_scale.shape[1], 1))
        decoder_pos_embed = torch.cat(decoder_pos_embed, dim=1)
        decoder_scale_embed = torch.cat(decoder_scale_embed, dim=1) if self.config.scale_embed else 0
        full_seq = full_seq + decoder_pos_embed + decoder_scale_embed
        full_seq = rearrange(full_seq, '(b t) n c -> b (t n) c', t=T)

        # add time embed
        if self.config.time_embed:
            time_embed = self.time_embed[None, :T, None, :]
            full_seq = rearrange(full_seq, 'b (t n) c -> b t n c', t=T)
            full_seq = full_seq + time_embed
            full_seq = rearrange(full_seq, 'b t n c -> b (t n) c', t=T)

        # Get the index of tokens to feed into decoder (encoded tokens + mask tokens for selected frames)
        idx_to_decode = gazing_pos.clone()
        idx_to_decode = list(idx_to_decode.split(num_gazing_each_frame.tolist(), dim=-1))
        for frame_idx in frame_idx_to_reconstruct:
            # For reconstructed frames, decode every token of the frame, not just gazed ones.
            idx_to_decode[frame_idx] = torch.arange(original_seq_length_each_frame, device=gazing_pos.device)[None].repeat(B, 1) + original_seq_length_each_frame * frame_idx
        idx_to_decode = torch.cat(idx_to_decode, dim=-1)

        # Get the tokens to decode
        # (append one extra slot so indices equal to original_seq_length — the padded-gaze
        # sentinel — select a harmless token)
        full_seq = torch.cat([full_seq, full_seq[:, :1]], dim=1)
        hidden_states = full_seq[torch.arange(B)[..., None], idx_to_decode]

        # add cls token
        cls_token = cls_token + self.decoder_pos_embed[:, :1]
        hidden_states = torch.cat([cls_token, hidden_states], dim=1)

        # apply Transformer layers (blocks)
        head_mask = head_mask.to(self.dtype)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, head_mask=head_mask, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.decoder_norm(hidden_states)

        # predictor projection
        logits = self.decoder_pred(hidden_states)

        # remove cls token
        logits = logits[:, 1:, :]

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return ViTMAEDecoderOutput(
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
1030
+
1031
+
1032
+ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
1033
def __init__(self, config):
    """Build encoder + decoder and register the configured reconstruction losses."""
    super().__init__(config)
    self.config = config

    self.vit = ViTMAEModel(config)
    self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)

    # multi-scale setting: '+'-separated scales; each contributes (scale/patch)^2 patches.
    self.scales = sorted(int(part) for part in config.scales.split('+'))
    self.num_patch_each_scale = [(s // config.patch_size) ** 2 for s in self.scales]
    self.num_token_each_frame = sum(self.num_patch_each_scale)

    # loss setting: parallel '+'-separated lists of loss names and their weights.
    self.loss_type = [str(part) for part in config.loss_type.split('+')]
    self.loss_weights = [float(part) for part in config.loss_weights.split('+')]
    self.transform = None  # will be initialized in the outer
    self.loss_fns = []
    for loss in self.loss_type:
        if loss == 'l1':
            self.loss_fns.append(self.l1_loss)
        elif loss == 'dinov2_reg':
            self.dinov2_reg = None  # will be initialized in the outer
            self.dinov2_reg_transform = None  # will be initialized in the outer
            self.loss_fns.append(self.dinov2_reg_loss)
        elif loss == 'siglip2':
            self.siglip2 = None  # will be initialized in the outer
            self.siglip2_transform = None  # will be initialized in the outer
            self.loss_fns.append(self.siglip2_loss)
        else:
            raise ValueError(f"Loss type {loss} not supported")

    # Initialize weights and apply final processing
    self.post_init()
1066
+
1067
def get_input_embeddings(self):
    """Expose the encoder's patch-embedding module (HF accessor convention)."""
    vit_embeddings = self.vit.embeddings
    return vit_embeddings.patch_embeddings
1069
+
1070
def _prune_heads(self, heads_to_prune):
    """
    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
    class PreTrainedModel
    """
    for layer, heads in heads_to_prune.items():
        # Fix: this class stores the encoder on `self.vit` (see __init__) and has no
        # `self.encoder` attribute — the old `self.encoder.layer[...]` raised AttributeError.
        self.vit.encoder.layer[layer].attention.prune_heads(heads)
1077
+
1078
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
    """
    Split images into flattened non-overlapping patches.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values.
        interpolate_pos_encoding (`bool`, *optional*, default `False`):
            interpolation flag passed during the forward pass; when set, the
            square-size check is skipped.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
            Patchified pixel values.
    """
    patch_size, num_channels = self.config.patch_size, self.config.num_channels
    # sanity checks
    if not interpolate_pos_encoding and (
        pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
    ):
        raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
    if pixel_values.shape[1] != num_channels:
        raise ValueError(
            "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
        )

    batch_size, _, height, width = pixel_values.shape
    grid_h, grid_w = height // patch_size, width // patch_size

    # (n, c, h*p, w*q) -> (n, c, h, p, w, q) -> (n, h, w, p, q, c) -> (n, h*w, p*q*c)
    patches = pixel_values.reshape(batch_size, num_channels, grid_h, patch_size, grid_w, patch_size)
    patches = patches.permute(0, 2, 4, 3, 5, 1)
    return patches.reshape(batch_size, grid_h * grid_w, patch_size**2 * num_channels)
1113
+
1114
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
    """
    Reassemble flattened patches back into images (inverse of `patchify`).

    Args:
        patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
            Patchified pixel values.
        original_image_size (`Tuple[int, int]`, *optional*):
            Original image size; defaults to the square configured `image_size`.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
            Pixel values.
    """
    patch_size, num_channels = self.config.patch_size, self.config.num_channels
    if original_image_size is None:
        original_image_size = (self.config.image_size, self.config.image_size)
    original_height, original_width = original_image_size
    grid_h = original_height // patch_size
    grid_w = original_width // patch_size
    # sanity check
    if grid_h * grid_w != patchified_pixel_values.shape[1]:
        raise ValueError(
            f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {grid_h}*{grid_w}"
        )

    batch_size = patchified_pixel_values.shape[0]
    # (n, h*w, p*q*c) -> (n, h, w, p, q, c) -> (n, c, h, p, w, q) -> (n, c, h*p, w*q)
    patches = patchified_pixel_values.reshape(batch_size, grid_h, grid_w, patch_size, patch_size, num_channels)
    patches = patches.permute(0, 5, 1, 3, 2, 4)
    return patches.reshape(batch_size, num_channels, grid_h * patch_size, grid_w * patch_size)
1159
+
1160
def retransform(self, image, source_transform, target_transform):
    """Undo `source_transform`'s rescale/normalize and apply `target_transform`'s instead.

    image: (B, C, H, W). The transforms are image-processor-like objects exposing
    `do_normalize`, `do_rescale`, `image_mean`, `image_std`, `rescale_factor` and
    optionally `offset`. Channel statistics are applied in channels-last layout.
    """
    def _mean_std(transform, ref):
        mean = torch.tensor(transform.image_mean, device=ref.device, dtype=ref.dtype)
        std = torch.tensor(transform.image_std, device=ref.device, dtype=ref.dtype)
        return mean, std

    # Revert the source transform (normalize first, then rescale, mirroring its inverse).
    image = rearrange(image, 'b c h w -> b h w c')
    if source_transform.do_normalize:
        mean, std = _mean_std(source_transform, image)
        image = image * std + mean
    if source_transform.do_rescale:
        if getattr(source_transform, 'offset', False):
            image = image + 1
        image = image / source_transform.rescale_factor
    image = rearrange(image, 'b h w c -> b c h w')

    # Apply the target transform (rescale, then normalize).
    image = rearrange(image, 'b c h w -> b h w c')
    if target_transform.do_rescale:
        image = image * target_transform.rescale_factor
        if getattr(target_transform, 'offset', False):
            image = image - 1
    if target_transform.do_normalize:
        mean, std = _mean_std(target_transform, image)
        image = (image - mean) / std
    image = rearrange(image, 'b h w c -> b c h w')

    return image
1182
+
1183
def l1_loss(self, pred, target):
    """Per-sample mean absolute error.

    pred, target: (B, C, H, W); returns a tensor of shape (B,).
    """
    abs_err = torch.abs(pred - target)
    return abs_err.mean(dim=(-1, -2, -3))
1188
+
1189
def dinov2_reg_loss(self, pred, target):
    """Perceptual loss in DINOv2-with-registers feature space.

    pred, target: (B, C, H, W); returns a per-sample loss of shape (B,).
    """
    def _extract(images):
        # Move images into the DINOv2 preprocessing space, then gather the
        # patch-token features of the last four hidden layers.
        images = self.retransform(images, self.transform, self.dinov2_reg_transform)
        hidden = self.dinov2_reg(images, output_hidden_states=True).hidden_states
        skip = self.dinov2_reg.config.num_register_tokens + 1  # drop CLS + register tokens
        return torch.cat([layer[:, skip:] for layer in hidden[-4:]], dim=-1)

    # Average squared distance between the two feature stacks, per sample.
    return (_extract(pred) - _extract(target)).pow(2).mean(dim=(-1, -2))
1206
+
1207
def siglip2_loss(self, pred, target):
    """Perceptual loss in SigLIP2 feature space.

    pred, target: (B, C, H, W); returns a per-sample loss of shape (B,).
    """
    def _extract(images):
        # Re-normalize into the SigLIP2 preprocessing space and concatenate
        # the last four hidden layers' features.
        images = self.retransform(images, self.transform, self.siglip2_transform)
        hidden = self.siglip2(images, output_hidden_states=True).hidden_states
        return torch.cat(hidden[-4:], dim=-1)

    # Average squared distance between the two feature stacks, per sample.
    return (_extract(pred) - _extract(target)).pow(2).mean(dim=(-1, -2))
1224
+
1225
def forward_loss(self, pixel_values, pred, interpolate_pos_encoding: bool = False):
    """
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, T, num_channels, height, width)`):
            Pixel values.
        pred (`torch.FloatTensor` of shape `(batch_size, T, num_patches, patch_size**2 * num_channels)`:
            Predicted pixel values.
        interpolate_pos_encoding (`bool`, *optional*, default `False`):
            interpolation flag passed during the forward pass.
            NOTE(review): currently unused in this method — confirm before removing.

    Returns:
        Tuple `(loss, mean_loss)`:
            loss: `(batch_size, T)` reconstruction loss per frame.
            mean_loss: `(batch_size,)` loss averaged over frames.
    """
    B, T = pixel_values.shape[:2]
    pixel_values = pixel_values.flatten(0, 1) # (B * T), C, H, W
    pred = pred.flatten(0, 1) # (B * T), N, C

    # Turn patch-level predictions back into full-resolution images.
    pred = self.unpatchify(pred, original_image_size=(pixel_values.shape[2], pixel_values.shape[3]))

    # Weighted sum of the configured per-sample loss functions
    # (e.g. l1_loss / dinov2_reg_loss / siglip2_loss), each returning (B * T,).
    loss = 0
    for loss_fn, loss_weight in zip(self.loss_fns, self.loss_weights):
        loss += loss_weight * loss_fn(pred, pixel_values)

    loss = rearrange(loss, '(b t) -> b t', b=B, t=T)
    mean_loss = loss.mean(dim=-1)

    return loss, mean_loss
1252
+
1253
def get_reconstructed_image(self, pixel_values, pred, interpolate_pos_encoding: bool = False):
    """
    Convert patch predictions into full reconstructed frames.

    pixel_values: (B, T, C, H, W) — only used for its spatial size.
    pred: (B, T, N, C) patch predictions.

    Returns:
        (B, T, C, H, W) reconstructed frames.
    """
    B, T = pixel_values.shape[:2]
    pixel_values = pixel_values.flatten(0, 1) # (B * T), C, H, W
    pred = pred.flatten(0, 1) # (B * T), N, C

    # Fold the patch dimension back into an image at the original resolution.
    pred = self.unpatchify(pred, original_image_size=(pixel_values.shape[2], pixel_values.shape[3]))

    # Restore the separate batch/time dimensions.
    pred = rearrange(pred, '(b t) c h w -> b t c h w', b=B, t=T)

    return pred
1267
+
1268
def get_causal_mask(self, num_tokens_each_frame, num_layers, batch_size, num_heads, token_mask=None, cls_token=True):
    """
    Assume a input of shape B * N * C, where N contains tokens from several frames.
    Each frame has num_tokens_each_frame[t] tokens.
    Create a block-causal attention mask such that each token can only attend to tokens from either previous frames or the same frame.
    Additionally, mask any tokens indicated by token_mask (e.g., the tokens at padded gazing positions)

    Inputs:
        num_tokens_each_frame: (T)
        num_layers: unused here; kept for interface compatibility with callers.
        token_mask: (B, N) — True marks tokens that must not be attended to.
        cls_token: whether to include the cls token in the mask
    Return:
        mask: batch x num_heads x seq_length x seq_length
    """
    T = len(num_tokens_each_frame)
    N = num_tokens_each_frame.sum()
    device = num_tokens_each_frame.device

    # Create a causal mask
    mask = torch.tril(torch.ones(batch_size, N, N, device=device))

    # Precompute frame boundaries once (the previous repeated sum() over the
    # prefix made this loop O(T^2)).
    boundaries = [0]
    for n in num_tokens_each_frame.tolist():
        boundaries.append(boundaries[-1] + n)

    # Make the tokens inside each frame attend to each other
    for t in range(T):
        start, end = boundaries[t], boundaries[t + 1]
        mask[:, start:end, start:end] = 1

    # Mask out tokens indicated by token_mask (zero the columns of masked
    # tokens for every query; broadcasting replaces the explicit repeat).
    if token_mask is not None:
        mask = mask * (~token_mask).float().unsqueeze(1)

    # Add mask for cls token
    if cls_token:
        mask_ = mask.clone()
        mask = torch.tril(torch.ones(batch_size, N + 1, N + 1, device=device))
        mask[:, 1:, 1:] = mask_

    # Each token must be able to attend to itself.
    # Bug fix: use the final mask size — the original indexed arange(N), which
    # with cls_token=True skipped the last content token (row/col N) and could
    # leave a fully-masked row for a padded token at the sequence end.
    total = mask.shape[-1]
    diag = torch.arange(total, device=device)
    mask[:, diag, diag] = 1

    # According to different attention implementations, the mask values are different.
    if self.config._attn_implementation == "flex_attention" or self.config._attn_implementation == "sdpa":
        # mask is a float tensor that will be added to the attention scores. This means the tokens to be attended should have mask value of 0, and the rest should have mask value of -inf.
        mask = torch.where(mask == 1, 0, -torch.inf)
    elif self.config._attn_implementation == "flash_attention_2":
        raise NotImplementedError("Flash attention 2 doesn't support custom attention mask. Please use attention_implementation='flex_attention'.")
    elif self.config._attn_implementation == "eager":
        # mask is a float tensor that will be multiplied to the attn prob after softmax. This means the tokens to be attended should have mask value of 1, and the rest should have mask value of 0.
        pass

    mask = mask.unsqueeze(1).repeat(1, num_heads, 1, 1)

    return mask.to(num_tokens_each_frame.device)
1320
+
1321
def forward(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    gazing_info: Optional[dict] = None,
    frame_idx_to_reconstruct: Optional[torch.LongTensor] = None,
    noise: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
    """
    Encode gazed tokens causally, then decode and reconstruct selected frames.

    pixel_values: (B, T, C, H, W)
    gazing_info:
        gazing_pos: The gazing positions of each whole sequence. (B, N)
        num_gazing_each_frame: The number of gazing positions for each frame, including the padded gazing. (T, )
        if_padded_gazing: Whether the gazing is padded. (B, N)
    frame_idx_to_reconstruct: (num_selected_frames, )
    """
    B, T = pixel_values.shape[:2]

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Get the encoder attention mask (block-causal over frames; padded gazing
    # tokens are masked out).
    encoder_attn_mask = self.get_causal_mask(gazing_info['num_gazing_each_frame'], self.config.num_hidden_layers, B, self.config.num_attention_heads, token_mask=gazing_info['if_padded_gazing'], cls_token=True) if self.config.causal else None

    # Get the encoder outputs
    outputs = self.vit(
        pixel_values,
        gazing_info=gazing_info,
        noise=noise,
        head_mask=encoder_attn_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )
    latent = outputs.last_hidden_state # B * N * C

    # Get the number of tokens to decode for each frame: frames selected for
    # reconstruction are decoded at full resolution (num_token_each_frame),
    # the rest keep only their gazing tokens.
    num_decoded_tokens_each_frame = gazing_info['num_gazing_each_frame'].clone()
    num_decoded_tokens_each_frame[frame_idx_to_reconstruct] = self.num_token_each_frame

    # Get the gazing padding mask for decoder: for reconstructed frames all
    # decoded tokens are real (not padded).
    if_padded_gazing_decoder = gazing_info['if_padded_gazing'].clone()
    if_padded_gazing_decoder = list(if_padded_gazing_decoder.split(gazing_info['num_gazing_each_frame'].tolist(), dim=-1))
    for frame_idx in frame_idx_to_reconstruct:
        if_padded_gazing_decoder[frame_idx] = torch.zeros(B, self.num_token_each_frame).to(gazing_info['if_padded_gazing'].device).to(torch.bool)
    if_padded_gazing_decoder = torch.cat(if_padded_gazing_decoder, dim=-1)

    # Get the decoder attention mask
    decoder_attn_mask = self.get_causal_mask(num_decoded_tokens_each_frame, self.config.decoder_num_hidden_layers, B, self.config.decoder_num_attention_heads, token_mask=if_padded_gazing_decoder, cls_token=True) if self.config.causal else None

    # Get the decoder outputs
    decoder_outputs = self.decoder(
        latent,
        gazing_info=gazing_info,
        frame_idx_to_reconstruct=frame_idx_to_reconstruct,
        head_mask=decoder_attn_mask,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )
    logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)

    # Only keep the predictions for the selected frames
    decoded_token_idx_to_keep = []
    for frame_idx in frame_idx_to_reconstruct:
        decoded_token_idx_to_keep.append(torch.arange(sum(num_decoded_tokens_each_frame[:frame_idx]), sum(num_decoded_tokens_each_frame[:frame_idx+1])))
    decoded_token_idx_to_keep = torch.cat(decoded_token_idx_to_keep, dim=0)
    logits = logits[:, decoded_token_idx_to_keep]
    logits = rearrange(logits, 'b (t n) c -> b t n c', t=len(frame_idx_to_reconstruct)) # B * num_selected_frames * N * C

    # throw away the reconstruction and masks for smaller scales
    # (only the largest scale's patches are reconstructed as pixels)
    logits = logits[:, :, sum(self.num_patch_each_scale[:-1]):, :]

    loss_each_reconstruction_frame, loss_mean = self.forward_loss(pixel_values[:, frame_idx_to_reconstruct], logits, interpolate_pos_encoding=interpolate_pos_encoding)
    reconstruction = self.get_reconstructed_image(pixel_values[:, frame_idx_to_reconstruct], logits, interpolate_pos_encoding=interpolate_pos_encoding) # B * num_selected_frames * C * H * W

    if not return_dict:
        output = (logits, reconstruction) + outputs[2:]
        return ((loss_each_reconstruction_frame, loss_mean) + output) if loss_each_reconstruction_frame is not None else output

    return ViTMAEForPreTrainingOutput(
        loss_each_reconstruction_frame=loss_each_reconstruction_frame,
        loss_mean=loss_mean,
        reconstruction=reconstruction,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
1410
+
1411
+
1412
+ __all__ = ["ViTMAEForPreTraining", "ViTMAELayer", "ViTMAEModel", "ViTMAEPreTrainedModel"]
autogaze/tasks/video_mae_reconstruction/task_video_mae_reconstruction.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from transformers import AutoModel, AutoImageProcessor, VivitImageProcessor
6
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
7
+ from transformers.models.siglip2.modeling_siglip2 import Siglip2VisionModel
8
+
9
+ from .modeling_video_mae import ViTMAEForPreTraining
10
+ from .visualize_video_mae_reconstruction import VisualizeReconstruction
11
+
12
+
13
class VideoMAEReconstruction(nn.Module):
    """Video-MAE reconstruction task conditioned on gaze positions.

    Wraps a ViTMAE reconstruction model plus optional frozen perceptual-loss
    encoders (DINOv2-with-registers, SigLIP2). Provides loss / reward / metric
    hooks used by the gaze-model training loop.
    """

    def __init__(self, recon_model, recon_model_config, scales, recon_sample_rate, attn_mode):
        super().__init__()

        # Create model
        self.scales = sorted([int(scale) for scale in str(scales).split("+")])
        self.transform = VivitImageProcessor.from_pretrained(recon_model, size=self.scales[-1]) # use mae image preprocessor config to initialize video preprocessor
        self.mae = ViTMAEForPreTraining.from_pretrained(recon_model, attn_implementation="sdpa", scales=str(scales), **OmegaConf.to_container(recon_model_config))
        self.mae.transform = self.transform
        # Frozen DINOv2-with-registers encoder for perceptual loss (if configured).
        if "dinov2_reg" in self.mae.loss_type:
            self.mae.dinov2_reg = AutoModel.from_pretrained(recon_model_config.dinov2_reg_loss_config.model, attn_implementation=attn_mode)
            self.mae.dinov2_reg_transform = AutoImageProcessor.from_pretrained(recon_model_config.dinov2_reg_loss_config.model)
            for param in self.mae.dinov2_reg.parameters():
                param.requires_grad = False
            self.mae.dinov2_reg.eval()
        # Frozen SigLIP2 encoder for perceptual loss (if configured).
        if "siglip2" in self.mae.loss_type:
            if "naflex" in recon_model_config.siglip2_loss_config.model:
                self.mae.siglip2 = Siglip2VisionModel.from_pretrained(recon_model_config.siglip2_loss_config.model, attn_implementation=attn_mode)
            else:
                self.mae.siglip2 = SiglipVisionModel.from_pretrained(recon_model_config.siglip2_loss_config.model, attn_implementation=attn_mode)
            self.mae.siglip2_transform = AutoImageProcessor.from_pretrained(recon_model_config.siglip2_loss_config.model)
            for param in self.mae.siglip2.parameters():
                param.requires_grad = False
            self.mae.siglip2.eval()

        # Sampling strategy for reconstruction
        self.recon_sample_rate = recon_sample_rate

        # Create visualization methods
        self.visualize_methods = [VisualizeReconstruction()]

        # kwargs for the gaze model input. Will be passed to the gaze model during training.
        self.gaze_model_kwargs = {
            "target_scales": self.scales,
            "target_patch_size": self.mae.config.patch_size,
            "target_image_mean": self.transform.image_mean,
            "target_image_std": self.transform.image_std,
        }

    @torch.autocast("cuda", dtype=torch.bfloat16)
    def forward_output(self, inputs, gaze_outputs, frame_idx_to_reconstruct=None):
        """
        Get all the outputs from the inputs
        """
        video = inputs['video']
        gazing_pos = gaze_outputs['gazing_pos']
        num_gazing_each_frame = gaze_outputs['num_gazing_each_frame']
        if_padded_gazing = gaze_outputs['if_padded_gazing']
        frame_sampling_rate = gaze_outputs['frame_sampling_rate']
        num_vision_tokens_each_frame = gaze_outputs['num_vision_tokens_each_frame']

        assert frame_sampling_rate == 1, "If frame_sampling_rate > 1, we can downsample the video here but ideally we don't want to do that"
        assert num_vision_tokens_each_frame == sum([(scale // self.mae.config.patch_size) ** 2 for scale in self.scales]), "The number of vision tokens in each frame is not consistent between gaze model and MAE model"

        # Frame sampling strategy for reconstruction: pick a random subset of
        # frames unless the caller fixed the selection.
        B, T = video.shape[:2]
        if frame_idx_to_reconstruct is None:
            frame_idx_to_reconstruct = torch.randperm(T)[:int(T * self.recon_sample_rate)].to(video.device)

        # Reconstruct the video
        gazing_info = {
            'gazing_pos': gazing_pos,
            'num_gazing_each_frame': num_gazing_each_frame,
            'if_padded_gazing': if_padded_gazing,
        }
        recon_output = self.mae(video, gazing_info=gazing_info, frame_idx_to_reconstruct=frame_idx_to_reconstruct, interpolate_pos_encoding=True)
        recon_loss_mean = recon_output.loss_mean
        recon_loss_each_reconstruction_frame = recon_output.loss_each_reconstruction_frame
        # Trajectory length (number of gazing tokens) seen before each
        # reconstructed frame — used as the per-reward trajectory length.
        num_gazing_before_each_reconstruction_frame = torch.stack([num_gazing_each_frame[:frame_idx+1].sum(dim=-1) for frame_idx in frame_idx_to_reconstruct], dim=0)
        num_non_padded_gazing_at_each_reconstruction_frame = [(~if_padded_gazing)[:, num_gazing_each_frame[:frame_idx].sum():num_gazing_each_frame[:frame_idx+1].sum()].sum(dim=-1) for frame_idx in frame_idx_to_reconstruct]
        num_non_padded_gazing_at_each_reconstruction_frame = torch.stack(num_non_padded_gazing_at_each_reconstruction_frame, dim=-1) # B * num_reconstruction_frames

        # Organize the recon loss at each gazing token: the frame loss is
        # broadcast onto that frame's gazing-token slots (shifted by one, with
        # the last slot always active), zero elsewhere.
        if_padded_gazing_each_frame = list(if_padded_gazing.split(num_gazing_each_frame.tolist(), dim=-1))
        reconstruction_loss_each_gazing_token = [torch.zeros(*if_padded_gazing_each_frame[t].shape, dtype=gazing_pos.dtype, device=gazing_pos.device) for t in range(len(num_gazing_each_frame))]
        reconstruction_loss_each_gazing_token_mask = [torch.zeros(*if_padded_gazing_each_frame[t].shape, dtype=gazing_pos.dtype, device=gazing_pos.device) for t in range(len(num_gazing_each_frame))]
        for i, frame_idx in enumerate(frame_idx_to_reconstruct):
            cur_mask = F.pad(if_padded_gazing_each_frame[frame_idx][:, 1:], (0, 1), value=True).to(torch.float)
            reconstruction_loss_each_gazing_token[frame_idx] = recon_loss_each_reconstruction_frame[:, i:i+1] * cur_mask
            reconstruction_loss_each_gazing_token_mask[frame_idx] = cur_mask
        reconstruction_loss_each_gazing_token = torch.cat(reconstruction_loss_each_gazing_token, dim=-1) # B * N
        reconstruction_loss_each_gazing_token_mask = torch.cat(reconstruction_loss_each_gazing_token_mask, dim=-1) # B * N

        outputs = {
            "reconstruction": recon_output.reconstruction,
            "reconstruction_loss": recon_loss_mean,
            "reconstruction_loss_each_reconstruction_frame": recon_loss_each_reconstruction_frame,
            "reconstruction_loss_each_gazing_token": reconstruction_loss_each_gazing_token,
            "reconstruction_loss_each_gazing_token_mask": reconstruction_loss_each_gazing_token_mask,
            "num_gazing_before_each_reconstruction_frame": num_gazing_before_each_reconstruction_frame,
            "num_non_padded_gazing_at_each_reconstruction_frame": num_non_padded_gazing_at_each_reconstruction_frame,
            "frame_idx_to_reconstruct": frame_idx_to_reconstruct,
            "image_mean": self.transform.image_mean,
            "image_std": self.transform.image_std,
            "rescale_factor": self.transform.rescale_factor,
            "scales": self.scales,
        }
        return outputs

    def loss(self, inputs, gaze_outputs, outputs):
        """
        Compute the loss of the outputs. Used for training the task itself.
        """
        reconstruction_loss = outputs['reconstruction_loss']
        reconstruction_loss_each_gazing_token = outputs['reconstruction_loss_each_gazing_token']
        reconstruction_loss_each_gazing_token_mask = outputs['reconstruction_loss_each_gazing_token_mask']
        return reconstruction_loss, reconstruction_loss_each_gazing_token, reconstruction_loss_each_gazing_token_mask

    def reward(self, inputs, gaze_outputs, outputs):
        """
        Compute the reward of the outputs. Used for training the gazing model.
        Reward is the negative (detached) reconstruction loss per frame.
        """
        reconstruction_loss_each_reconstruction_frame = outputs['reconstruction_loss_each_reconstruction_frame']
        rewards = -reconstruction_loss_each_reconstruction_frame.detach()

        # Gazing length before each reward
        traj_len_each_reward = outputs['num_gazing_before_each_reconstruction_frame']

        return rewards, traj_len_each_reward

    def metric(self, inputs, gaze_outputs, outputs):
        """
        Compute the metric used for recording during validation.
        """
        # Reconstruction loss
        reconstruction_loss, _, __ = self.loss(inputs, gaze_outputs, outputs)
        reconstruction_loss = reconstruction_loss.mean()

        # Average gazing ratio per frame (fraction of vision tokens gazed)
        bs, num_frames = inputs['video'].shape[:2]
        num_vision_tokens_each_frame = gaze_outputs['num_vision_tokens_each_frame']
        num_gazing_total = (~gaze_outputs['if_padded_gazing']).sum()
        avg_gazing_ratio = num_gazing_total / (bs *num_frames * num_vision_tokens_each_frame)

        metrics = {
            'reconstruction_loss': reconstruction_loss,
            'avg_gazing_ratio_per_frame': avg_gazing_ratio,
        }
        return metrics

    def visualize(self, inputs, gaze_outputs, task_outputs, rl_outputs=None):
        """
        Visualize the outputs.
        """
        for method in self.visualize_methods:
            method(inputs, gaze_outputs, task_outputs, rl_outputs)

    def forward(self, inputs, gaze_outputs):
        """
        Compute the outputs and the loss, reward, and metric of the outputs.
        inputs:
            image: B, C, H, W
        gaze_outputs:
            gazing_pos: B, N
        """
        outputs = self.forward_output(inputs, gaze_outputs)
        loss, reconstruction_loss_each_gazing_token, reconstruction_loss_each_gazing_token_mask = self.loss(inputs, gaze_outputs, outputs)
        reward, traj_len_each_reward = self.reward(inputs, gaze_outputs, outputs)
        metric = self.metric(inputs, gaze_outputs, outputs)

        to_return = {
            'outputs': outputs,
            'loss': loss,
            'reward': reward,
            'traj_len_each_reward': traj_len_each_reward,
            'task_losses': reconstruction_loss_each_gazing_token,
            'task_losses_mask': reconstruction_loss_each_gazing_token_mask,
            'metrics': metric,
        }
        return to_return
autogaze/tasks/video_mae_reconstruction/visualize_video_mae_reconstruction.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import wandb
5
+ import matplotlib.pyplot as plt
6
+ from autogaze.utils import UnNormalize
7
+
8
class VisualizeReconstruction:
    """Log a grid of original frames, per-scale gazing masks, and reconstructed
    frames to Weights & Biases (first instance of the batch only)."""

    def __init__(self, **kwargs):

        self.visualize_step = 0
        if wandb.run is not None:
            # define our custom x axis metric
            wandb.define_metric("visualize_gaze/visualize_step")
            # define which metrics will be plotted against it
            wandb.define_metric("visualize_gaze/*", step_metric="visualize_gaze/visualize_step")

    @torch.no_grad()
    def __call__(self, inputs, gaze_outputs, task_outputs, rl_outputs):
        # Get all information for visualization
        videos = inputs['video']
        gazing_mask = gaze_outputs['gazing_mask'] # containing multi-scale masks; list of B * T * N_each_scale
        frame_sampling_rate = gaze_outputs['frame_sampling_rate']
        scales = task_outputs['outputs']['scales']
        reconstruction = task_outputs['outputs']['reconstruction']
        frame_idx_to_reconstruct = task_outputs['outputs']['frame_idx_to_reconstruct']
        image_mean = task_outputs['outputs']['image_mean']
        image_std = task_outputs['outputs']['image_std']
        rescale_factor = task_outputs['outputs']['rescale_factor']
        num_scales = len(scales)

        # sample the frames to visualize
        videos = videos[:, ::frame_sampling_rate]
        assert videos.shape[1] == gazing_mask[0].shape[1]

        # only visualize the first instance
        video = videos[0]
        gazing_mask = [m[0] for m in gazing_mask]
        reconstruction = reconstruction[0]

        unnormalize = UnNormalize(image_mean, image_std, rescale_factor)

        video = unnormalize(video)
        reconstruction = unnormalize(reconstruction)
        video = video.cpu().float().numpy()
        reconstruction = reconstruction.cpu().float().numpy()

        # complete the reconstruction by filling the unselected frames
        reconstruction_full = np.zeros_like(video)
        reconstruction_full[frame_idx_to_reconstruct.cpu().numpy()] = reconstruction
        reconstruction = reconstruction_full

        # Create a figure with subplots: original video frames and one row for each scale's masked video frames
        T = video.shape[0] # Number of frames
        fig, axes = plt.subplots(num_scales + 2, T, figsize=(3 * T, 3 * (num_scales + 2)))

        # Plot original video frames
        for t in range(T):
            frame = video[t].transpose(1, 2, 0) # C * H * W -> H * W * C
            axes[0, t].imshow(frame)
            axes[0, t].set_title(f'Original Frame {t+1}')
            axes[0, t].axis('off')

        # Visualize masked video for each scale
        for scale_idx in range(num_scales):
            scale_mask = gazing_mask[scale_idx] # T * N

            for t in range(T):
                frame_mask = scale_mask[t] # N

                # Reshape if it's flattened
                if frame_mask.dim() == 1:
                    h = w = int(frame_mask.shape[0] ** 0.5)
                    frame_mask = frame_mask.reshape(h, w)

                # Resize mask to match current scale
                frame_mask = F.interpolate(
                    frame_mask.unsqueeze(0).unsqueeze(0),
                    size=(scales[scale_idx], scales[scale_idx]),
                    mode='nearest'
                ).squeeze()

                frame_mask = frame_mask.cpu().float().numpy()

                # Resize frame to match mask dimensions
                frame = video[t] # C * H * W
                scale_frame = F.interpolate(
                    torch.from_numpy(frame).unsqueeze(0),
                    size=(scales[scale_idx], scales[scale_idx]),
                    mode='bicubic',
                    align_corners=False
                ).squeeze().clamp(0, 1).numpy()

                # Dim non-gazed regions to 20% brightness.
                masked_frame = scale_frame * (0.8 * frame_mask[None, :, :] + 0.2)

                # Plot this frame's masked image
                axes[scale_idx + 1, t].imshow(masked_frame.transpose(1, 2, 0))

                # Add red borders around gazed patches
                original_mask = gazing_mask[scale_idx][t]
                if original_mask.dim() == 1:
                    patch_grid_size = int(original_mask.shape[0] ** 0.5)
                    original_mask = original_mask.reshape(patch_grid_size, patch_grid_size)
                else:
                    # Bug fix: patch_grid_size was previously only assigned in the
                    # 1-D branch, raising NameError for already-2-D masks.
                    patch_grid_size = original_mask.shape[0]

                patch_size = scales[scale_idx] // patch_grid_size
                for i in range(patch_grid_size):
                    for j in range(patch_grid_size):
                        if original_mask[i, j] > 0.5: # If this patch is gazed at
                            rect = plt.Rectangle((j * patch_size - 0.5, i * patch_size - 0.5),
                                                 patch_size, patch_size,
                                                 linewidth=1, edgecolor='red', facecolor='none')
                            axes[scale_idx + 1, t].add_patch(rect)

                axes[scale_idx + 1, t].set_title(f'Scale {scales[scale_idx]} Frame {t+1}')
                axes[scale_idx + 1, t].axis('off')

        # plot the reconstruction
        for t in range(T):
            frame = reconstruction[t].transpose(1, 2, 0) # C * H * W -> H * W * C
            axes[num_scales + 1, t].imshow(frame)
            axes[num_scales + 1, t].set_title(f'Reconstructed Frame {t+1}')
            axes[num_scales + 1, t].axis('off')

        # Adjust layout and log to wandb
        plt.tight_layout()
        wandb.log({
            "visualize_gaze/visualize_step": self.visualize_step,
            "visualize_gaze/visualize_gaze": wandb.Image(plt)
        })

        # Close the figure to free memory
        plt.close(fig)

        self.visualize_step += 1
autogaze/utils.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import builtins
2
+
3
+ from omegaconf import OmegaConf
4
+ from loguru import logger
5
+ import sys
6
+ import os
7
+ import torch
8
+ import numpy as np
9
+ import wandb
10
+ import random
11
+ from torch.nn.parallel import DistributedDataParallel as DDP
12
+
13
+
14
class UnNormalize(object):
    """Invert per-channel normalization (x * std + mean) and clamp to [0, 1].

    Accepts a (C, H, W) or (B, C, H, W) tensor. If a rescale factor is given
    and differs from the standard 1/255 (e.g. 1/127.5, which maps pixels to
    [-1, 1]), an extra /2 + 0.5 correction is applied.
    """

    def __init__(self, mean, std, rescale_factor=None):
        self.mean = mean
        self.std = std
        self.rescale_factor = rescale_factor

    def __call__(self, image):
        out = torch.clone(image)
        unbatched = len(out.shape) == 3
        if unbatched:
            out = out.unsqueeze(0)

        # Bring channels to the front so we can walk them alongside mean/std.
        out = out.permute(1, 0, 2, 3)
        for channel, m, s in zip(out, self.mean, self.std):
            channel.mul_(s).add_(m)
        out = out.permute(1, 0, 2, 3)

        if unbatched:
            out = out.squeeze(0)

        if self.rescale_factor is not None:
            # Processors using a factor like 1/127.5 produce [-1, 1] values;
            # map them back into [0, 1].
            if abs(self.rescale_factor - 1.0 / 255.0) > 1e-6:
                out = out / 2.0 + 0.5

        return torch.clamp(out, 0, 1)
39
+
40
+
41
class AverageScalarMeter(object):
    """Running mean of scalar values over a sliding window of at most
    `window_size` samples (approximated batch-wise)."""

    def __init__(self, window_size):
        self.window_size = window_size
        self.current_size = 0
        self.mean = 0

    def update(self, values):
        """Fold a 1-D tensor of new values into the windowed mean."""
        n_new = values.size()[0]
        if n_new == 0:
            return
        batch_mean = torch.mean(values.float(), dim=0).cpu().numpy().item()
        # One batch can at most fill the entire window.
        n_new = np.clip(n_new, 0, self.window_size)
        # Number of old samples that still fit beside the new ones.
        n_old = min(self.window_size - n_new, self.current_size)
        total = n_old + n_new
        self.current_size = total
        self.mean = (self.mean * n_old + batch_mean * n_new) / total

    def clear(self):
        self.current_size = 0
        self.mean = 0

    def __len__(self):
        return self.current_size

    def get_mean(self):
        return self.mean
67
+
68
+
69
def plot_grad_norms(named_parameters, name_prefix=''):
    """Log the L2 norm of each parameter's gradient to wandb.

    Args:
        named_parameters: iterable of (name, parameter) pairs, e.g.
            `model.named_parameters()`.
        name_prefix: optional prefix prepended to each logged metric name.
    """
    for name, param in named_parameters:
        # Parameters without gradients (frozen or unused) are skipped.
        if param.grad is not None:
            norm = torch.linalg.vector_norm(param.grad, 2.0).item()
            wandb.log({f'{name_prefix}{name}': norm})
74
+
75
+
76
def suppress_print():
    """Suppresses printing from the current process.

    Replaces ``builtins.print`` with a no-op accepting any arguments.
    Bug fix: the previous no-op declared keyword parameters named ``_sep``,
    ``_end``, ``_file`` and ``_flush``, so any real call using the builtin's
    keywords (e.g. ``print(x, sep=",")``) raised TypeError. A ``**kwargs``
    catch-all matches the builtin's interface.
    """
    def ignore(*args, **kwargs):
        pass
    builtins.print = ignore
81
+
82
+
83
def suppress_wandb():
    """Suppresses wandb logging from the current process.

    Replaces every public callable attribute of the wandb module with a no-op.
    """
    # Store original functions
    original_functions = {}
    for attr_name in dir(wandb):
        attr = getattr(wandb, attr_name)
        if callable(attr) and not attr_name.startswith('__'):
            original_functions[attr_name] = attr

    # Replace with no-op function
    def make_noop(name):
        def noop(*args, **kwargs):
            pass
        return noop

    # Bug fix: setattr was previously called once after the collection loop,
    # so only the last visited attribute was actually replaced. Replace every
    # collected callable.
    for attr_name in original_functions:
        setattr(wandb, attr_name, make_noop(attr_name))
99
+
100
+
101
def suppress_logging():
    """Suppresses loguru logging from the current process.

    All existing sinks are removed and replaced with a sink that discards
    every message.
    """
    logger.remove() # Remove all handlers
    logger.add(lambda _: None) # Add a no-op handler
105
+
106
+
107
def dump_cfg(cfg, logdir):
    """Serialize an OmegaConf config to ``<logdir>/config.yaml``.

    Args:
        cfg: OmegaConf configuration object.
        logdir: directory in which ``config.yaml`` will be written.
    """
    config_path = os.path.join(logdir, "config.yaml")
    with open(config_path, "w") as f:
        f.write(OmegaConf.to_yaml(cfg))
    print("Wrote config to: {}".format(config_path))
112
+
113
+
114
def get_scheduled_temperature(step, total_steps, temp_schedule_args):
    """Temperature at `step` under the configured schedule.

    In 'exp' mode the temperature moves geometrically from `temp_start` to
    `temp_end` as `step` goes from 0 to `total_steps`.

    Raises:
        ValueError: if the schedule mode is not recognized.
    """
    if temp_schedule_args['mode'] != 'exp':
        raise ValueError(f"Unknown temp_schedule_args: {temp_schedule_args}")
    exp_args = temp_schedule_args['exp']
    t_start = exp_args['temp_start']
    t_end = exp_args['temp_end']
    progress = step / total_steps
    return t_start * (t_end / t_start) ** progress
121
+
122
+
123
def seed_everything(seed: int):
    """Seed every RNG used by the project: Python, hash seed, NumPy, PyTorch
    (CPU, CUDA, MPS).

    Also forces deterministic cuDNN kernels and disables benchmarking, which
    may slow training but makes runs reproducible.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed all visible GPUs, not only the current device
        # (the previous `torch.cuda.manual_seed` covered a single device).
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    if hasattr(torch, 'mps') and torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
134
+
135
+
136
def seed_worker(worker_id):
    """DataLoader `worker_init_fn`: derive per-worker RNG seeds from torch's
    base seed so dataloading is reproducible across workers."""
    base_seed = torch.initial_seed() % 2**32
    np.random.seed(base_seed)
    random.seed(base_seed)
    # Offset by worker_id so each worker's torch stream differs.
    torch.manual_seed(base_seed + worker_id)
141
+
142
+
143
def format_kwargs(cfg, optional_args):
    """Collect optional keyword arguments from a config object.

    Args:
        cfg: object with attribute sections (e.g. an OmegaConf config).
        optional_args: iterable of (arg_name, section, attr) triples.

    Returns:
        dict mapping arg_name -> cfg.<section>.<attr>, including only the
        entries whose section and attribute both exist on cfg.
    """
    kwargs = {}
    for arg_name, section, attr in optional_args:
        if not hasattr(cfg, section):
            continue
        section_cfg = getattr(cfg, section)
        if hasattr(section_cfg, attr):
            kwargs[arg_name] = getattr(section_cfg, attr)
    return kwargs
149
+
150
+
151
def move_inputs_to_cuda(inputs):
    """Recursively move every tensor in a (possibly nested) dict to the GPU.

    NOTE: mutates `inputs` in place and also returns it for convenience.
    Non-tensor, non-dict values are left untouched.
    """
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.cuda()
        elif isinstance(v, dict):
            # Recurse into nested dicts of tensors.
            inputs[k] = move_inputs_to_cuda(v)
    return inputs
158
+
159
+
160
def unwrap_model(model):
    """Return the underlying module if `model` is DDP-wrapped, else `model`."""
    return model.module if isinstance(model, DDP) else model
165
+
166
+
167
def get_gazing_pos_from_gazing_mask(gazing_mask: torch.Tensor) -> tuple:
    """
    Get the gazing positions from the gazing mask.
    inputs:
        gazing_mask: (B, N). 1 means gazed, 0 means not gazed.
    outputs:
        gazing_pos: (B, K). K is the maximum number of gazed tokens per instance. If the instance has less than K gazed tokens, the remaining positions are padded with -1.
        if_padded_gazing: (B, K). True means padded, False means not padded.
    """
    # x: (B, N) with 0/1 values (float/bool/int all fine)
    gazing_mask = gazing_mask.to(torch.long)
    B, N = gazing_mask.shape

    # Indices per row
    idx = torch.arange(N, device=gazing_mask.device).expand(B, N)

    # Sort key: put ones first, keep original order among ones/zeros
    # - ones get key = idx (0..N-1)
    # - zeros get key = N + idx (pushed after all ones)
    key = (1 - gazing_mask) * N + idx
    order = key.argsort(dim=1, stable=True)  # (B, N)
    sorted_idx = idx.gather(1, order)  # ones first, then zeros

    # Max number of ones (K) and per-row counts
    counts = gazing_mask.sum(dim=1)  # (B,)
    K = int(counts.max().item()) if B > 0 else 0

    if K == 0:
        # Bug fix: previously returned a single tensor here, while all other
        # paths (and the docstring) promise a (gazing_pos, if_padded_gazing)
        # pair — tuple-unpacking callers would crash on an all-zero mask.
        empty = sorted_idx[:, :0]  # (B, 0)
        return empty, torch.zeros_like(empty, dtype=torch.bool)

    topk = sorted_idx[:, :K]  # (B, K)
    pos = torch.arange(K, device=gazing_mask.device).expand(B, K)
    mask = pos < counts.unsqueeze(1)  # True where a real "1" exists

    # Pad with -1 where the row has fewer than K ones
    gazing_pos = topk.masked_fill(~mask, -1)
    if_padded_gazing = (gazing_pos == -1)

    return gazing_pos, if_padded_gazing
demo_utils.py CHANGED
@@ -1,5 +1,10 @@
1
- import sys
2
- import os
 
 
 
 
 
3
  import torch
4
  import torch.nn.functional as F
5
  import numpy as np
@@ -9,19 +14,11 @@ from transformers import VivitImageProcessor
9
  from PIL import Image, ImageDraw, ImageFont
10
  from omegaconf import OmegaConf
11
  from einops import rearrange
12
-
13
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'gengaze'))
14
  from autogaze.models.autogaze import AutoGaze
15
  from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch
16
  from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction
17
  from autogaze.utils import UnNormalize
18
- from tqdm import trange
19
-
20
- try:
21
- import spaces
22
- ZEROGPU_AVAILABLE = True
23
- except ImportError:
24
- ZEROGPU_AVAILABLE = False
25
 
26
 
27
  def image_to_video(image_path, output_path, fps):
@@ -218,24 +215,29 @@ def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.
218
  progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
219
  yield None
220
 
221
- # Extract mini-batch from CPU and move to GPU: (batch_size, nt, 16, C, H, W)
222
  spatial_batch = video_chunks[start_idx:end_idx].to(device)
223
- # Flatten to (batch_size * nt, 16, C, H, W) for model
224
  spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
225
  print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')
226
 
227
  # Run AutoGaze on this mini-batch
228
  batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)
229
 
 
 
 
230
  # Free GPU memory after forward pass
231
  del spatial_batch
232
 
233
  # Count gazing tokens for this batch
234
  if_padded = batch_gaze_output.get('if_padded_gazing')
235
  if if_padded is not None:
236
- total_gazing_tokens += (~if_padded).sum().item()
 
 
237
  else:
238
- total_gazing_tokens += (batch_gaze_output['gazing_pos'] < (196 * 16)).sum().item()
 
 
239
 
240
  # Store the output
241
  all_gaze_outputs.append(batch_gaze_output)
@@ -283,7 +285,7 @@ def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.
283
  # Clean up mini-batch outputs
284
  del all_gaze_outputs
285
 
286
- total_possible_tokens = 196 * 16 * num_chunks
287
 
288
  # Extract gazing masks for later visualization (already in batched form)
289
  gazing_masks_batched = gaze_output['gazing_mask'] # List of 4 scales, each (num_chunks, 16, num_patches)
 
1
+ # IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
2
+ try:
3
+ import spaces
4
+ ZEROGPU_AVAILABLE = True
5
+ except ImportError:
6
+ ZEROGPU_AVAILABLE = False
7
+
8
  import torch
9
  import torch.nn.functional as F
10
  import numpy as np
 
14
  from PIL import Image, ImageDraw, ImageFont
15
  from omegaconf import OmegaConf
16
  from einops import rearrange
17
+ from tqdm import trange
 
18
  from autogaze.models.autogaze import AutoGaze
19
  from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch
20
  from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction
21
  from autogaze.utils import UnNormalize
 
 
 
 
 
 
 
22
 
23
 
24
  def image_to_video(image_path, output_path, fps):
 
215
  progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
216
  yield None
217
 
 
218
  spatial_batch = video_chunks[start_idx:end_idx].to(device)
 
219
  spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
220
  print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')
221
 
222
  # Run AutoGaze on this mini-batch
223
  batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)
224
 
225
+ num_gazing_each_frame = batch_gaze_output['num_gazing_each_frame'][:T]
226
+ num_gazing_total = num_gazing_each_frame.sum().item()
227
+
228
  # Free GPU memory after forward pass
229
  del spatial_batch
230
 
231
  # Count gazing tokens for this batch
232
  if_padded = batch_gaze_output.get('if_padded_gazing')
233
  if if_padded is not None:
234
+ print(f'shape of if_padded: {if_padded.shape}')
235
+ if_padded = if_padded[:, :min(num_gazing_total, if_padded.shape[1])]
236
+ new_gazing_tokens = (~if_padded).sum().item()
237
  else:
238
+ new_gazing_tokens = (batch_gaze_output['gazing_pos'] < (196 * T)).sum().item()
239
+ total_gazing_tokens += new_gazing_tokens
240
+ print(f'Batch {batch_idx+1}: Gazing tokens = {new_gazing_tokens}, Total gazing tokens so far = {total_gazing_tokens}')
241
 
242
  # Store the output
243
  all_gaze_outputs.append(batch_gaze_output)
 
285
  # Clean up mini-batch outputs
286
  del all_gaze_outputs
287
 
288
+ total_possible_tokens = 196 * min(T, 16) * num_chunks
289
 
290
  # Extract gazing masks for later visualization (already in batched form)
291
  gazing_masks_batched = gaze_output['gazing_mask'] # List of 4 scales, each (num_chunks, 16, num_patches)
packages.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git
2
+ git-lfs
3
+ ffmpeg
4
+ pkg-config
5
+ libavcodec-dev
6
+ libavformat-dev
7
+ libavutil-dev
8
+ libswscale-dev
9
+ libswresample-dev
10
+ libavdevice-dev
11
+ libavfilter-dev
12
+ libsm6
13
+ libxext6
14
+ cmake
15
+ rsync
16
+ libgl1
requirements.txt CHANGED
@@ -11,5 +11,5 @@ tqdm==4.67.1
11
  transformers==4.53.0
12
  omegaconf==2.3.0
13
  einops==0.8.1
14
- av==14.4.0
15
- imageio==2.37.0
 
11
  transformers==4.53.0
12
  omegaconf==2.3.0
13
  einops==0.8.1
14
+ av
15
+ imageio[ffmpeg]==2.37.0