Spaces:
Running on Zero
Running on Zero
update
Browse files- Dockerfile +35 -0
- README.md +4 -6
- app.py +18 -8
- autogaze/__init__.py +1 -0
- autogaze/__pycache__/__init__.cpython-310.pyc +0 -0
- autogaze/__pycache__/utils.cpython-310.pyc +0 -0
- autogaze/datasets/__init__.py +1 -0
- autogaze/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- autogaze/datasets/__pycache__/video_utils.cpython-310.pyc +0 -0
- autogaze/datasets/video_utils.py +133 -0
- autogaze/models/__init__.py +1 -0
- autogaze/models/__pycache__/__init__.cpython-310.pyc +0 -0
- autogaze/models/autogaze/__init__.py +17 -0
- autogaze/models/autogaze/__pycache__/__init__.cpython-310.pyc +0 -0
- autogaze/models/autogaze/__pycache__/autogaze.cpython-310.pyc +0 -0
- autogaze/models/autogaze/__pycache__/configuration_autogaze.cpython-310.pyc +0 -0
- autogaze/models/autogaze/__pycache__/modeling_autogaze.cpython-310.pyc +0 -0
- autogaze/models/autogaze/__pycache__/modeling_llama_multi_token_pred.cpython-310.pyc +0 -0
- autogaze/models/autogaze/autogaze.py +432 -0
- autogaze/models/autogaze/configuration_autogaze.py +326 -0
- autogaze/models/autogaze/modeling_autogaze.py +431 -0
- autogaze/models/autogaze/modeling_llama_multi_token_pred.py +471 -0
- autogaze/tasks/__init__.py +1 -0
- autogaze/tasks/__pycache__/__init__.cpython-310.pyc +0 -0
- autogaze/tasks/video_mae_reconstruction/__init__.py +1 -0
- autogaze/tasks/video_mae_reconstruction/__pycache__/__init__.cpython-310.pyc +0 -0
- autogaze/tasks/video_mae_reconstruction/__pycache__/configuration_video_mae.cpython-310.pyc +0 -0
- autogaze/tasks/video_mae_reconstruction/__pycache__/modeling_video_mae.cpython-310.pyc +0 -0
- autogaze/tasks/video_mae_reconstruction/__pycache__/task_video_mae_reconstruction.cpython-310.pyc +0 -0
- autogaze/tasks/video_mae_reconstruction/__pycache__/visualize_video_mae_reconstruction.cpython-310.pyc +0 -0
- autogaze/tasks/video_mae_reconstruction/configuration_video_mae.py +159 -0
- autogaze/tasks/video_mae_reconstruction/modeling_video_mae.py +1412 -0
- autogaze/tasks/video_mae_reconstruction/task_video_mae_reconstruction.py +182 -0
- autogaze/tasks/video_mae_reconstruction/visualize_video_mae_reconstruction.py +134 -0
- autogaze/utils.py +205 -0
- demo_utils.py +18 -16
- packages.txt +16 -0
- requirements.txt +2 -2
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
git \
|
| 7 |
+
git-lfs \
|
| 8 |
+
ffmpeg \
|
| 9 |
+
pkg-config \
|
| 10 |
+
libavcodec-dev \
|
| 11 |
+
libavformat-dev \
|
| 12 |
+
libavutil-dev \
|
| 13 |
+
libswscale-dev \
|
| 14 |
+
libswresample-dev \
|
| 15 |
+
libavdevice-dev \
|
| 16 |
+
libavfilter-dev \
|
| 17 |
+
libsm6 \
|
| 18 |
+
libxext6 \
|
| 19 |
+
cmake \
|
| 20 |
+
rsync \
|
| 21 |
+
libgl1 \
|
| 22 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 23 |
+
&& git lfs install
|
| 24 |
+
|
| 25 |
+
COPY requirements.txt .
|
| 26 |
+
|
| 27 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 28 |
+
|
| 29 |
+
COPY . .
|
| 30 |
+
|
| 31 |
+
EXPOSE 7860
|
| 32 |
+
|
| 33 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
| 34 |
+
|
| 35 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,14 +1,12 @@
|
|
| 1 |
---
|
| 2 |
title: AutoGaze
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
-
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
-
short_description: AutoGaze can remove redundant patches in any video.
|
| 12 |
---
|
| 13 |
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: AutoGaze
|
| 3 |
+
emoji: 👀
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.5.1
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,3 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import tempfile
|
| 3 |
import os
|
|
@@ -8,13 +16,6 @@ import av
|
|
| 8 |
from PIL import Image
|
| 9 |
import numpy as np
|
| 10 |
|
| 11 |
-
try:
|
| 12 |
-
import spaces
|
| 13 |
-
ZEROGPU_AVAILABLE = True
|
| 14 |
-
except ImportError:
|
| 15 |
-
ZEROGPU_AVAILABLE = False
|
| 16 |
-
print("Warning: spaces module not available. Running without ZeroGPU support.")
|
| 17 |
-
|
| 18 |
model_cache = {}
|
| 19 |
|
| 20 |
def get_model(device):
|
|
@@ -22,7 +23,16 @@ def get_model(device):
|
|
| 22 |
model_cache[device] = load_model(device=device)
|
| 23 |
return model_cache[device]
|
| 24 |
|
| 25 |
-
device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def cleanup_gpu():
|
| 28 |
"""Clean up GPU memory."""
|
|
|
|
| 1 |
+
# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
|
| 2 |
+
try:
|
| 3 |
+
import spaces
|
| 4 |
+
ZEROGPU_AVAILABLE = True
|
| 5 |
+
except ImportError:
|
| 6 |
+
ZEROGPU_AVAILABLE = False
|
| 7 |
+
print("Warning: spaces module not available. Running without ZeroGPU support.")
|
| 8 |
+
|
| 9 |
import gradio as gr
|
| 10 |
import tempfile
|
| 11 |
import os
|
|
|
|
| 16 |
from PIL import Image
|
| 17 |
import numpy as np
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
model_cache = {}
|
| 20 |
|
| 21 |
def get_model(device):
|
|
|
|
| 23 |
model_cache[device] = load_model(device=device)
|
| 24 |
return model_cache[device]
|
| 25 |
|
| 26 |
+
# Determine device: use CUDA if available locally or if ZeroGPU will provide it
|
| 27 |
+
if ZEROGPU_AVAILABLE:
|
| 28 |
+
device = "cuda" # ZeroGPU will provide GPU
|
| 29 |
+
print("Using ZeroGPU (CUDA device will be allocated on demand)")
|
| 30 |
+
elif torch.cuda.is_available():
|
| 31 |
+
device = "cuda"
|
| 32 |
+
print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
|
| 33 |
+
else:
|
| 34 |
+
device = "cpu"
|
| 35 |
+
print("No GPU available, using CPU")
|
| 36 |
|
| 37 |
def cleanup_gpu():
|
| 38 |
"""Clean up GPU memory."""
|
autogaze/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""AutoGaze package for video patch reduction."""
|
autogaze/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
autogaze/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (7.09 kB). View file
|
|
|
autogaze/datasets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""AutoGaze datasets and utilities."""
|
autogaze/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
autogaze/datasets/__pycache__/video_utils.cpython-310.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
autogaze/datasets/video_utils.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Common utilities for video loading and processing."""
|
| 2 |
+
|
| 3 |
+
import av
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_relative_video_path(path):
|
| 9 |
+
"""
|
| 10 |
+
Get the last three levels of the path as the relative path to the video.
|
| 11 |
+
Args:
|
| 12 |
+
path (str): Path to get the last three levels of.
|
| 13 |
+
Returns:
|
| 14 |
+
last_three (str): Last three levels of the path.
|
| 15 |
+
"""
|
| 16 |
+
parts = path.replace("\\", "/").split("/")
|
| 17 |
+
return "/".join(parts[-3:]) if len(parts) >= 3 else path
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def read_video_pyav(container, indices):
|
| 21 |
+
"""
|
| 22 |
+
Decode the video with PyAV decoder.
|
| 23 |
+
Args:
|
| 24 |
+
container (`av.container.input.InputContainer`): PyAV container.
|
| 25 |
+
indices (`List[int]`): List of frame indices to decode.
|
| 26 |
+
Returns:
|
| 27 |
+
result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
| 28 |
+
"""
|
| 29 |
+
frames = []
|
| 30 |
+
container.seek(0)
|
| 31 |
+
start_index = indices[0]
|
| 32 |
+
end_index = indices[-1]
|
| 33 |
+
for i, frame in enumerate(container.decode(video=0)):
|
| 34 |
+
if i > end_index:
|
| 35 |
+
break
|
| 36 |
+
if i >= start_index and i in indices:
|
| 37 |
+
frames.append(frame)
|
| 38 |
+
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def sample_frame_indices(clip_len, frame_sample_rate, seg_len, random_sample_frame=False):
|
| 42 |
+
"""
|
| 43 |
+
Sample a given number of frame indices from the video.
|
| 44 |
+
Args:
|
| 45 |
+
clip_len (`int`): Total number of frames to sample.
|
| 46 |
+
frame_sample_rate (`int`): Sample every n-th frame.
|
| 47 |
+
seg_len (`int`): Maximum allowed index of sample's last frame.
|
| 48 |
+
Returns:
|
| 49 |
+
indices (`List[int]`): List of sampled frame indices
|
| 50 |
+
"""
|
| 51 |
+
converted_len = int(clip_len * frame_sample_rate)
|
| 52 |
+
if seg_len <= converted_len:
|
| 53 |
+
# Not enough frames, just return the first clip_len frames (or as many as possible)
|
| 54 |
+
indices = np.arange(min(clip_len, seg_len))
|
| 55 |
+
indices = np.pad(indices, (0, max(0, clip_len - len(indices))), mode="edge")
|
| 56 |
+
return indices.astype(np.int64)
|
| 57 |
+
if random_sample_frame:
|
| 58 |
+
end_idx = np.random.randint(converted_len, seg_len)
|
| 59 |
+
start_idx = end_idx - converted_len
|
| 60 |
+
else:
|
| 61 |
+
start_idx = 0
|
| 62 |
+
end_idx = converted_len
|
| 63 |
+
indices = np.linspace(start_idx, end_idx, num=clip_len)
|
| 64 |
+
indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
| 65 |
+
return indices
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def process_video_frames(video, clip_len):
|
| 69 |
+
"""
|
| 70 |
+
Process video frames to ensure correct shape and length.
|
| 71 |
+
Args:
|
| 72 |
+
video (np.ndarray): Video frames of shape (num_frames, H, W, 3)
|
| 73 |
+
clip_len (int): Target number of frames
|
| 74 |
+
Returns:
|
| 75 |
+
video (np.ndarray): Processed video of shape (clip_len, H, W, 3)
|
| 76 |
+
"""
|
| 77 |
+
# Ensure video has shape (clip_len, H, W, 3)
|
| 78 |
+
if video.shape[0] != clip_len:
|
| 79 |
+
# Pad or repeat last frame if needed
|
| 80 |
+
if video.shape[0] < clip_len:
|
| 81 |
+
pad_frames = clip_len - video.shape[0]
|
| 82 |
+
last_frame = video[-1:]
|
| 83 |
+
video = np.concatenate(
|
| 84 |
+
[video, np.repeat(last_frame, pad_frames, axis=0)], axis=0
|
| 85 |
+
)
|
| 86 |
+
else:
|
| 87 |
+
video = video[:clip_len]
|
| 88 |
+
|
| 89 |
+
assert video.shape[0] == clip_len, (
|
| 90 |
+
f"Video has {video.shape[0]} frames, expected {clip_len}"
|
| 91 |
+
)
|
| 92 |
+
assert video.ndim == 4 and video.shape[-1] == 3, (
|
| 93 |
+
f"Video shape is {video.shape}, expected (clip_len, H, W, 3)"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return video
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def transform_video_for_pytorch(video, transform=None):
|
| 100 |
+
"""
|
| 101 |
+
Transform video frames and convert to PyTorch format.
|
| 102 |
+
Args:
|
| 103 |
+
video (np.ndarray): Video frames of shape (clip_len, H, W, 3)
|
| 104 |
+
transform: Optional transform to apply
|
| 105 |
+
Returns:
|
| 106 |
+
img (np.ndarray): Transformed video of shape (clip_len, C, H, W)
|
| 107 |
+
"""
|
| 108 |
+
if transform is not None:
|
| 109 |
+
imgs = transform(list(video)).pixel_values
|
| 110 |
+
if isinstance(imgs[0], list): # frames are wrapped in a python list
|
| 111 |
+
img = imgs[0]
|
| 112 |
+
else:
|
| 113 |
+
img = imgs # frames are not wrapped in a python list
|
| 114 |
+
img = np.stack(img)
|
| 115 |
+
else:
|
| 116 |
+
img = video # fallback: return raw video
|
| 117 |
+
|
| 118 |
+
# Ensure output is (clip_len, C, H, W) for pytorch
|
| 119 |
+
if img.shape[1] == 3 and img.shape[-1] != 3:
|
| 120 |
+
# Already (clip_len, C, H, W)
|
| 121 |
+
pass
|
| 122 |
+
elif img.shape[-1] == 3:
|
| 123 |
+
# (clip_len, H, W, 3) -> (clip_len, 3, H, W)
|
| 124 |
+
img = np.transpose(img, (0, 3, 1, 2))
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"Unexpected image shape after transform: {img.shape}")
|
| 127 |
+
|
| 128 |
+
clip_len = img.shape[0]
|
| 129 |
+
assert img.shape[0] == clip_len and img.shape[1] == 3, (
|
| 130 |
+
f"Output img shape: {img.shape}"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return torch.tensor(img)
|
autogaze/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""AutoGaze models."""
|
autogaze/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (178 Bytes). View file
|
|
|
autogaze/models/autogaze/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .autogaze import AutoGaze
|
| 2 |
+
from .configuration_autogaze import (
|
| 3 |
+
AutoGazeConfig,
|
| 4 |
+
GazeModelConfig,
|
| 5 |
+
VisionModelConfig,
|
| 6 |
+
ConnectorConfig,
|
| 7 |
+
GazeDecoderConfig,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"AutoGaze",
|
| 12 |
+
"AutoGazeConfig",
|
| 13 |
+
"GazeModelConfig",
|
| 14 |
+
"VisionModelConfig",
|
| 15 |
+
"ConnectorConfig",
|
| 16 |
+
"GazeDecoderConfig",
|
| 17 |
+
]
|
autogaze/models/autogaze/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (420 Bytes). View file
|
|
|
autogaze/models/autogaze/__pycache__/autogaze.cpython-310.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
autogaze/models/autogaze/__pycache__/configuration_autogaze.cpython-310.pyc
ADDED
|
Binary file (9.83 kB). View file
|
|
|
autogaze/models/autogaze/__pycache__/modeling_autogaze.cpython-310.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
autogaze/models/autogaze/__pycache__/modeling_llama_multi_token_pred.cpython-310.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
autogaze/models/autogaze/autogaze.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
from autogaze.utils import get_gazing_pos_from_gazing_mask
|
| 13 |
+
from .modeling_autogaze import AutoGazeModel
|
| 14 |
+
from .configuration_autogaze import AutoGazeConfig
|
| 15 |
+
|
| 16 |
+
class AutoGaze(PreTrainedModel):
|
| 17 |
+
config_class = AutoGazeConfig
|
| 18 |
+
|
| 19 |
+
def __init__(self, config: AutoGazeConfig):
|
| 20 |
+
super().__init__(config)
|
| 21 |
+
|
| 22 |
+
self.config = config
|
| 23 |
+
self.gazing_ratio_config = config.gazing_ratio_config
|
| 24 |
+
self.gazing_ratio_each_frame_config = config.gazing_ratio_each_frame_config
|
| 25 |
+
self.scales = sorted([int(scale) for scale in str(config.scales).split('+')])
|
| 26 |
+
self.num_vision_tokens_each_frame = config.num_vision_tokens_each_frame
|
| 27 |
+
self.num_vision_tokens_each_scale_each_frame = [int(scale**2 / sum([scale**2 for scale in self.scales]) * self.num_vision_tokens_each_frame) for scale in self.scales]
|
| 28 |
+
self.frame_sampling_rate = config.gaze_model_config.vision_model_config.temporal_patch_size
|
| 29 |
+
self.image_mean = config.image_mean
|
| 30 |
+
self.image_std = config.image_std
|
| 31 |
+
self.attn_mode = config.attn_mode
|
| 32 |
+
|
| 33 |
+
# Create the gazing model
|
| 34 |
+
self.gazing_model = AutoGazeModel(config.gaze_model_config)
|
| 35 |
+
|
| 36 |
+
# Task loss requirement
|
| 37 |
+
self.has_task_loss_requirement_during_training = config.has_task_loss_requirement_during_training
|
| 38 |
+
self.has_task_loss_requirement_during_inference = config.has_task_loss_requirement_during_inference
|
| 39 |
+
self.task_loss_requirement_config = config.task_loss_requirement_config
|
| 40 |
+
|
| 41 |
+
def get_gazing_ratio(self, sync_across_ranks=True):
|
| 42 |
+
"""
|
| 43 |
+
Sample the gazing ratio for the whole video according to the config.
|
| 44 |
+
"""
|
| 45 |
+
sample_strategy = self.gazing_ratio_config['sample_strategy_during_training'] if self.training else self.gazing_ratio_config['sample_strategy_during_inference']
|
| 46 |
+
if sample_strategy == 'fixed':
|
| 47 |
+
ratio = self.gazing_ratio_config['fixed']['gazing_ratio']
|
| 48 |
+
elif sample_strategy == 'uniform':
|
| 49 |
+
ratio = random.uniform(self.gazing_ratio_config['uniform']['gazing_ratio_min'], self.gazing_ratio_config['uniform']['gazing_ratio_max'])
|
| 50 |
+
elif sample_strategy == 'exponential':
|
| 51 |
+
ratio = random.expovariate(self.gazing_ratio_config['exponential']['lambda'])
|
| 52 |
+
while ratio < self.gazing_ratio_config['exponential']['gazing_ratio_min'] or ratio > self.gazing_ratio_config['exponential']['gazing_ratio_max']:
|
| 53 |
+
ratio = random.expovariate(self.gazing_ratio_config['exponential']['lambda'])
|
| 54 |
+
|
| 55 |
+
if sync_across_ranks:
|
| 56 |
+
ratio = torch.tensor(ratio).cuda()
|
| 57 |
+
if torch.distributed.is_initialized():
|
| 58 |
+
torch.distributed.broadcast(ratio, src=0) # Make every rank use the same gazing ratio. Otherwise, each rank will have different gazing ratio, and the train/inference time is bounded by the slowest rank (with highest gazing ratio).
|
| 59 |
+
ratio = ratio.item()
|
| 60 |
+
|
| 61 |
+
return ratio
|
| 62 |
+
|
| 63 |
+
def get_gazing_ratio_each_frame(self, inputs, video, gazing_ratio_mean, num_frames, temperature, use_cache):
|
| 64 |
+
"""
|
| 65 |
+
Sample the gazing ratio for each frame according to the config.
|
| 66 |
+
"""
|
| 67 |
+
sample_strategy = self.gazing_ratio_each_frame_config['sample_strategy_during_training'] if self.training else self.gazing_ratio_each_frame_config['sample_strategy_during_inference']
|
| 68 |
+
if sample_strategy == 'uniform':
|
| 69 |
+
gazing_ratio_each_frame = torch.ones(num_frames) * gazing_ratio_mean
|
| 70 |
+
elif sample_strategy == 'dirichlet':
|
| 71 |
+
gazing_ratio_agg = gazing_ratio_mean * num_frames
|
| 72 |
+
alpha = self.gazing_ratio_each_frame_config['dirichlet']['alpha']
|
| 73 |
+
if isinstance(alpha, str):
|
| 74 |
+
alpha = torch.tensor([float(a) for a in alpha.split(',')])
|
| 75 |
+
assert len(alpha) == num_frames, "The number of alpha values must be equal to the number of frames"
|
| 76 |
+
gazing_ratio_each_frame = torch.distributions.dirichlet.Dirichlet(torch.ones(num_frames) * alpha).sample() * gazing_ratio_agg
|
| 77 |
+
gazing_ratio_each_frame = gazing_ratio_each_frame.clamp(min=0, max=1)
|
| 78 |
+
elif sample_strategy == 'self':
|
| 79 |
+
assert use_cache == False, "using cache is not supported for self-predicted gazing ratio"
|
| 80 |
+
|
| 81 |
+
# Only preserve one sample for each group
|
| 82 |
+
if "group_size" in inputs:
|
| 83 |
+
video = rearrange(video, '(g b) t c h w -> g b t c h w', g=inputs["group_size"])[0]
|
| 84 |
+
|
| 85 |
+
assert video.shape[0] == 1, "Currently only batch_size=1 is supported because otherwise we need to support different gazing ratio constraints in the same batch in model.generate()"
|
| 86 |
+
|
| 87 |
+
# Max gazing ratio for each frame
|
| 88 |
+
max_gazing_ratio_each_frame = torch.ones(num_frames) * gazing_ratio_mean
|
| 89 |
+
max_num_gaze_tokens_each_frame = (max_gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
|
| 90 |
+
|
| 91 |
+
# Sample task loss requirement
|
| 92 |
+
task_loss_requirement = self.get_task_loss_requirement(video, force_sampling=True)
|
| 93 |
+
|
| 94 |
+
# Sample the gazing
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
if self.training:
|
| 97 |
+
gazing_info = self.gazing_model.generate(
|
| 98 |
+
video,
|
| 99 |
+
max_gaze_tokens_each_frame=max_num_gaze_tokens_each_frame,
|
| 100 |
+
task_loss_requirement=task_loss_requirement,
|
| 101 |
+
do_sample=True,
|
| 102 |
+
temperature=temperature,
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
gazing_info = self.gazing_model.generate(
|
| 106 |
+
video,
|
| 107 |
+
max_gaze_tokens_each_frame=max_num_gaze_tokens_each_frame,
|
| 108 |
+
task_loss_requirement=task_loss_requirement,
|
| 109 |
+
do_sample=False,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if_padded_gazing = gazing_info["if_padded_gazing"]
|
| 113 |
+
num_gazing_each_frame = gazing_info["num_gazing_each_frame"]
|
| 114 |
+
if_padded_gazing = if_padded_gazing.split(num_gazing_each_frame.tolist(), dim=1)
|
| 115 |
+
num_non_padded_gazing_each_frame = torch.stack([(~if_padded_gazing[i]).sum(dim=-1) for i in range(len(if_padded_gazing))], dim=1) # (B, num_frames)
|
| 116 |
+
|
| 117 |
+
gazing_ratio_each_frame = num_non_padded_gazing_each_frame[0] / self.num_vision_tokens_each_frame
|
| 118 |
+
else:
|
| 119 |
+
raise NotImplementedError(f"Sample strategy {sample_strategy} not implemented.")
|
| 120 |
+
|
| 121 |
+
return gazing_ratio_each_frame
|
| 122 |
+
|
| 123 |
+
def get_task_loss_requirement(self, video, sync_across_ranks=True, force_sampling=False):
|
| 124 |
+
"""
|
| 125 |
+
Sample the task loss requirement for each frame according to the config.
|
| 126 |
+
|
| 127 |
+
inputs:
|
| 128 |
+
video: tensor of shape (B, T, C, H, W)
|
| 129 |
+
returns:
|
| 130 |
+
task_loss_requirement: tensor of shape (B, T // frame_sampling_rate), representing the task loss requirement for each frame of each video. None if no task loss requirement is used.
|
| 131 |
+
"""
|
| 132 |
+
has_task_loss_requirement = self.has_task_loss_requirement_during_training if self.training else self.has_task_loss_requirement_during_inference
|
| 133 |
+
if not has_task_loss_requirement and not force_sampling:
|
| 134 |
+
return None
|
| 135 |
+
|
| 136 |
+
B, T = video.shape[:2]
|
| 137 |
+
sample_strategy = self.task_loss_requirement_config['sample_strategy_during_training'] if self.training else self.task_loss_requirement_config['sample_strategy_during_inference']
|
| 138 |
+
if sample_strategy == 'fixed':
|
| 139 |
+
task_loss_requirement = self.task_loss_requirement_config['fixed']['task_loss_requirement']
|
| 140 |
+
task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
|
| 141 |
+
elif sample_strategy == 'uniform':
|
| 142 |
+
task_loss_requirement_min = self.task_loss_requirement_config['uniform']['task_loss_requirement_min']
|
| 143 |
+
task_loss_requirement_max = self.task_loss_requirement_config['uniform']['task_loss_requirement_max']
|
| 144 |
+
task_loss_requirement = random.uniform(task_loss_requirement_min, task_loss_requirement_max)
|
| 145 |
+
task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
|
| 146 |
+
else:
|
| 147 |
+
raise NotImplementedError(f"Task loss requirement sample strategy {self.task_loss_requirement_config['sample_strategy']} not implemented")
|
| 148 |
+
|
| 149 |
+
if sync_across_ranks:
|
| 150 |
+
if torch.distributed.is_initialized():
|
| 151 |
+
torch.distributed.broadcast(task_loss_requirement, src=0) # Make every rank use the same gazing ratio. Otherwise, each rank will have different gazing ratio, and the train/inference time is bounded by the slowest rank (with highest gazing ratio).
|
| 152 |
+
|
| 153 |
+
return task_loss_requirement
|
| 154 |
+
|
| 155 |
+
def get_mask_from_gazing_pos(self, video, gazing_pos, if_padded_gazing):
|
| 156 |
+
"""
|
| 157 |
+
Create the video gazing mask from the gazing positions.
|
| 158 |
+
|
| 159 |
+
inputs:
|
| 160 |
+
video: B, T, C, H, W
|
| 161 |
+
gazing_pos: B, N
|
| 162 |
+
if_padded_gazing: B, N
|
| 163 |
+
returns:
|
| 164 |
+
mask: list of B * T * N_each_scale
|
| 165 |
+
"""
|
| 166 |
+
B, T = video.shape[:2]
|
| 167 |
+
mask = torch.zeros(B, self.num_vision_tokens_each_frame * (T // self.frame_sampling_rate) + 1, device=video.device) # +1 for the padded gazing positions
|
| 168 |
+
tmp_gazing_pos = gazing_pos.clone()
|
| 169 |
+
tmp_gazing_pos[if_padded_gazing] = mask.shape[1] - 1 # Set the padded gazing positions to the last position
|
| 170 |
+
mask[torch.arange(B)[:, None], tmp_gazing_pos] = 1
|
| 171 |
+
mask = mask[:, :-1] # Remove the last position (padded gazing positions)
|
| 172 |
+
mask = mask.reshape(B, T // self.frame_sampling_rate, self.num_vision_tokens_each_frame)
|
| 173 |
+
mask = [mask[:, :, sum(self.num_vision_tokens_each_scale_each_frame[:i]):sum(self.num_vision_tokens_each_scale_each_frame[:i+1])] for i in range(len(self.scales))] # list of B * T * N_each_scale
|
| 174 |
+
|
| 175 |
+
return mask
|
| 176 |
+
|
| 177 |
+
def input_res_adapt(self, pixel_values, target_scales, target_patch_size):
|
| 178 |
+
"""
|
| 179 |
+
Preprocess the input to adapt to the target scales and patch size.
|
| 180 |
+
|
| 181 |
+
inputs:
|
| 182 |
+
pixel_values: B, T, C, H, W
|
| 183 |
+
returns:
|
| 184 |
+
pixel_values: B, T, C, H, W
|
| 185 |
+
res_adapt_info: dict, the information of resolution adaptation, for future recovery.
|
| 186 |
+
"""
|
| 187 |
+
B, T, C, H, W = pixel_values.shape
|
| 188 |
+
assert H == W == target_scales[-1], "Now we need the input video to be the same size as the largest scale of the vision model" # FIXME: in the future we should use relative resize ratio as the scales, e.g., 0.125+0.25+0.5+1. In this way we can also support naflex ViT.
|
| 189 |
+
assert len(self.scales) == len(target_scales), "The scales of the gaze model and the vision model must be the same"
|
| 190 |
+
tile_feature_map_size_each_scale = [int(self.num_vision_tokens_each_scale_each_frame[i] ** 0.5) for i in range(len(self.scales))]
|
| 191 |
+
original_feature_map_height_each_scale = [target_scales[i] // target_patch_size for i in range(len(target_scales))]
|
| 192 |
+
original_feature_map_width_each_scale = [target_scales[i] // target_patch_size for i in range(len(target_scales))]
|
| 193 |
+
num_tiles_height = math.ceil(original_feature_map_height_each_scale[-1] / tile_feature_map_size_each_scale[-1])
|
| 194 |
+
num_tiles_width = math.ceil(original_feature_map_width_each_scale[-1] / tile_feature_map_size_each_scale[-1])
|
| 195 |
+
pad_H = num_tiles_height * tile_feature_map_size_each_scale[-1] * target_patch_size - H
|
| 196 |
+
pad_W = num_tiles_width * tile_feature_map_size_each_scale[-1] * target_patch_size - W
|
| 197 |
+
pixel_values = F.pad(pixel_values, (0, pad_W, 0, pad_H))
|
| 198 |
+
pixel_values = rearrange(pixel_values, 'b t c (nh sh) (nw sw) -> (b nh nw) t c sh sw', nh=num_tiles_height, nw=num_tiles_width)
|
| 199 |
+
res_adapt_info = {
|
| 200 |
+
'tile_feature_map_size_each_scale': tile_feature_map_size_each_scale,
|
| 201 |
+
'original_feature_map_height_each_scale': original_feature_map_height_each_scale,
|
| 202 |
+
'original_feature_map_width_each_scale': original_feature_map_width_each_scale,
|
| 203 |
+
'num_tiles_height': num_tiles_height,
|
| 204 |
+
'num_tiles_width': num_tiles_width,
|
| 205 |
+
'pad_H': pad_H,
|
| 206 |
+
'pad_W': pad_W,
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
return pixel_values, res_adapt_info
|
| 210 |
+
|
| 211 |
+
def recover_output_from_res_adapt(self, gaze_outputs, res_adapt_info):
|
| 212 |
+
"""
|
| 213 |
+
Postprocess the output to recover from resolution adaptation.
|
| 214 |
+
|
| 215 |
+
inputs:
|
| 216 |
+
gaze_outputs: dict, the outputs of the gazing model.
|
| 217 |
+
res_adapt_info: dict, the information of resolution adaptation.
|
| 218 |
+
returns:
|
| 219 |
+
gaze_outputs: dict, the outputs of the gazing model.
|
| 220 |
+
"""
|
| 221 |
+
num_tiles_height = res_adapt_info['num_tiles_height']
|
| 222 |
+
num_tiles_width = res_adapt_info['num_tiles_width']
|
| 223 |
+
tile_feature_map_size_each_scale = res_adapt_info['tile_feature_map_size_each_scale']
|
| 224 |
+
original_feature_map_height_each_scale = res_adapt_info['original_feature_map_height_each_scale']
|
| 225 |
+
original_feature_map_width_each_scale = res_adapt_info['original_feature_map_width_each_scale']
|
| 226 |
+
|
| 227 |
+
# Recover the gazing mask. Remove the gazing for the padded regions.
|
| 228 |
+
new_gazing_mask = []
|
| 229 |
+
for scale_idx in range(len(gaze_outputs['scales'])):
|
| 230 |
+
cur_gazing_mask = gaze_outputs['gazing_mask'][scale_idx]
|
| 231 |
+
cur_gazing_mask = rearrange(cur_gazing_mask, '(b nh nw) t (sh sw) -> b t (nh sh) (nw sw)', nh=num_tiles_height, nw=num_tiles_width, sh=tile_feature_map_size_each_scale[scale_idx], sw=tile_feature_map_size_each_scale[scale_idx])
|
| 232 |
+
cur_gazing_mask = cur_gazing_mask[:, :, :original_feature_map_height_each_scale[scale_idx], :original_feature_map_width_each_scale[scale_idx]]
|
| 233 |
+
cur_gazing_mask = cur_gazing_mask.flatten(-2, -1) # (b t (nh sh) (nw sw)) -> (b t (nh sh * nw sw))
|
| 234 |
+
new_gazing_mask.append(cur_gazing_mask)
|
| 235 |
+
|
| 236 |
+
# Recover the num_gazing_each_frame and num_vision_tokens_each_frame
|
| 237 |
+
new_num_vision_tokens_each_frame = sum([mask.shape[-1] for mask in new_gazing_mask])
|
| 238 |
+
|
| 239 |
+
# Recover the gazing pos, if_padded_gazing, and num_gazing_each_frame, by inderring from the gazing mask. Note this will lose the original order of the gazing!
|
| 240 |
+
new_gazing_mask_all_scales = torch.cat(new_gazing_mask, dim=-1) # B, T, N
|
| 241 |
+
B, T = new_gazing_mask_all_scales.shape[:2]
|
| 242 |
+
new_gazing_pos, new_if_padded_gazing = get_gazing_pos_from_gazing_mask(new_gazing_mask_all_scales.flatten(0, 1))
|
| 243 |
+
new_gazing_pos, new_if_padded_gazing = rearrange(new_gazing_pos, '(b t) n -> b t n', b=B, t=T), rearrange(new_if_padded_gazing, '(b t) n -> b t n', b=B, t=T)
|
| 244 |
+
max_num_gazing_each_frame = (~new_if_padded_gazing).sum(dim=-1).max(dim=0)[0]
|
| 245 |
+
assert all([torch.all(new_if_padded_gazing[:, t, num:] == True) for t, num in enumerate(max_num_gazing_each_frame)]), "The removed gazing should all be padded."
|
| 246 |
+
new_gazing_pos = [new_gazing_pos[:, t, :num] for t, num in enumerate(max_num_gazing_each_frame)]
|
| 247 |
+
new_if_padded_gazing = [new_if_padded_gazing[:, t, :num] for t, num in enumerate(max_num_gazing_each_frame)]
|
| 248 |
+
new_gazing_pos = [gazing_pos + new_num_vision_tokens_each_frame * t for t, gazing_pos in enumerate(new_gazing_pos)]
|
| 249 |
+
new_gazing_pos, new_if_padded_gazing = torch.cat(new_gazing_pos, dim=1), torch.cat(new_if_padded_gazing, dim=1)
|
| 250 |
+
new_num_gazing_each_frame = max_num_gazing_each_frame
|
| 251 |
+
|
| 252 |
+
# Update the outputs
|
| 253 |
+
gaze_outputs['gazing_pos'] = new_gazing_pos
|
| 254 |
+
gaze_outputs['gazing_mask'] = new_gazing_mask
|
| 255 |
+
gaze_outputs['frame_sampling_rate'] = gaze_outputs['frame_sampling_rate']
|
| 256 |
+
gaze_outputs['num_vision_tokens_each_frame'] = new_num_vision_tokens_each_frame
|
| 257 |
+
gaze_outputs['num_gazing_each_frame'] = new_num_gazing_each_frame
|
| 258 |
+
gaze_outputs['if_padded_gazing'] = new_if_padded_gazing
|
| 259 |
+
|
| 260 |
+
# Currently we haven't reordered actions probs and task loss prediction based on the new gazing pos, so delete it for now for safety.
|
| 261 |
+
del(gaze_outputs['log_action_probs'])
|
| 262 |
+
del(gaze_outputs['task_loss_prediction'])
|
| 263 |
+
|
| 264 |
+
return gaze_outputs
|
| 265 |
+
|
| 266 |
+
#FIXME: separate forward and generate functions
|
| 267 |
+
def forward(
|
| 268 |
+
self,
|
| 269 |
+
inputs,
|
| 270 |
+
target_scales=None,
|
| 271 |
+
target_patch_size=None,
|
| 272 |
+
target_image_mean=None,
|
| 273 |
+
target_image_std=None,
|
| 274 |
+
gazing_info=None,
|
| 275 |
+
temperature=1,
|
| 276 |
+
gazing_ratio=None,
|
| 277 |
+
task_loss_requirement=None,
|
| 278 |
+
generate_only=False,
|
| 279 |
+
use_cache=False,
|
| 280 |
+
past_key_values=None,
|
| 281 |
+
past_inputs_embeds=None,
|
| 282 |
+
past_attention_mask=None,
|
| 283 |
+
past_conv_values=None,
|
| 284 |
+
):
|
| 285 |
+
"""
|
| 286 |
+
inputs:
|
| 287 |
+
video: B, T, C, H, W
|
| 288 |
+
target_scales: list of scales for downstream vision model. If None, then use the scales in the gaze model.
|
| 289 |
+
target_patch_size: patch size for downstream vision model. If None, then use the patch size in the gaze model.
|
| 290 |
+
target_image_mean: image mean for downstream vision model. If None, then use the image mean in the gaze model.
|
| 291 |
+
target_image_std: image std for downstream vision model. If None, then use the image std in the gaze model.
|
| 292 |
+
gazing_info: dict, the ground truth gazing information for NTP pre-training. If None, then run the gazing model to predict gazing positions.
|
| 293 |
+
temperature: temperature for generating gazing.
|
| 294 |
+
gazing_ratio: gazing ratio for the gazing model. If None, then sample the gazing ratio according to the config.
|
| 295 |
+
task_loss_requirement: task loss requirement for the gazing model. If None, then sample the task loss requirement according to the config.
|
| 296 |
+
generate_only: whether to only generate the gazing positions, or to also calculate the probability of taking such gaze.
|
| 297 |
+
use_cache: whether to use the cache for the gazing model.
|
| 298 |
+
past_key_values: the past key values for the gazing model.
|
| 299 |
+
past_inputs_embeds: the past inputs embeds for the gazing model.
|
| 300 |
+
past_attention_mask: the past attention mask for the gazing model.
|
| 301 |
+
past_conv_values: the past conv values for the gazing model.
|
| 302 |
+
returns:
|
| 303 |
+
to_return: dict, the outputs of the gazing model.
|
| 304 |
+
"""
|
| 305 |
+
if not generate_only:
|
| 306 |
+
assert past_key_values is None and past_inputs_embeds is None and past_attention_mask is None and past_conv_values is None, \
|
| 307 |
+
"If not in generate-only mode, we don't support past_key_values, past_inputs_embeds, past_attention_mask, and past_conv_values yet."
|
| 308 |
+
|
| 309 |
+
video = inputs['video']
|
| 310 |
+
|
| 311 |
+
# Preprocess the input to fix the image mean and std
|
| 312 |
+
if target_image_mean is not None and target_image_std is not None:
|
| 313 |
+
video = rearrange(video, 'b t c h w -> b t h w c')
|
| 314 |
+
video = video * torch.tensor(target_image_std, device=video.device, dtype=video.dtype) + torch.tensor(target_image_mean, device=video.device, dtype=video.dtype)
|
| 315 |
+
video = video * 2 - 1 # Vivit preprocesssor has a rescaling factor of 1/127.5 instead of 1/255, and it has an offset of -1.
|
| 316 |
+
video = (video - torch.tensor(self.image_mean, device=video.device, dtype=video.dtype)) / torch.tensor(self.image_std, device=video.device, dtype=video.dtype)
|
| 317 |
+
video = rearrange(video, 'b t h w c -> b t c h w')
|
| 318 |
+
|
| 319 |
+
# Preprocess the input for resolution adaptation
|
| 320 |
+
if target_scales is not None and target_patch_size is not None:
|
| 321 |
+
if not (target_scales == self.scales and [(scale // target_patch_size) ** 2 for scale in target_scales] == self.num_vision_tokens_each_scale_each_frame):
|
| 322 |
+
video, res_adapt_info = self.input_res_adapt(video, target_scales, target_patch_size)
|
| 323 |
+
|
| 324 |
+
B, T = video.shape[:2]
|
| 325 |
+
|
| 326 |
+
# If gazing_pos is already provided, then directly calculate the probability of taking such gaze. Usually in the cases of calculating pi(a|s) in PPO/GRPO/etc.
|
| 327 |
+
# Otherwise, run the gazing model first to predict gazing positions.
|
| 328 |
+
if gazing_info is None or len(gazing_info) == 0:
|
| 329 |
+
with torch.autocast("cuda", dtype=torch.bfloat16) if self.attn_mode == "flash_attention_2" else nullcontext():
|
| 330 |
+
|
| 331 |
+
if gazing_ratio is not None and task_loss_requirement is not None:
|
| 332 |
+
# If the user specifies the gazing ratio and task loss requirement, then use gazing ratio as the max gazing ratio and use task loss requirement to control when to stop
|
| 333 |
+
if isinstance(gazing_ratio, list):
|
| 334 |
+
assert len(gazing_ratio) == T // self.frame_sampling_rate, "The number of gazing ratios must be equal to the number of frames"
|
| 335 |
+
gazing_ratio = torch.tensor(gazing_ratio)
|
| 336 |
+
gazing_ratio_each_frame = torch.ones(T // self.frame_sampling_rate) * gazing_ratio
|
| 337 |
+
num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
|
| 338 |
+
task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
|
| 339 |
+
elif gazing_ratio is not None:
|
| 340 |
+
# If the user specifies the gazing ratio, then turn off the task loss requirement
|
| 341 |
+
if isinstance(gazing_ratio, list):
|
| 342 |
+
assert len(gazing_ratio) == T // self.frame_sampling_rate, "The number of gazing ratios must be equal to the number of frames"
|
| 343 |
+
gazing_ratio = torch.tensor(gazing_ratio)
|
| 344 |
+
gazing_ratio_each_frame = torch.ones(T // self.frame_sampling_rate) * gazing_ratio
|
| 345 |
+
num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
|
| 346 |
+
task_loss_requirement = None
|
| 347 |
+
elif task_loss_requirement is not None:
|
| 348 |
+
# If the user specifies the task loss requirement, then turn off the gazing ratio limit
|
| 349 |
+
gazing_ratio = 1
|
| 350 |
+
gazing_ratio_each_frame = torch.ones(T // self.frame_sampling_rate) * gazing_ratio
|
| 351 |
+
num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
|
| 352 |
+
task_loss_requirement = torch.ones(B, T // self.frame_sampling_rate, device=video.device) * task_loss_requirement
|
| 353 |
+
else:
|
| 354 |
+
gazing_ratio = self.get_gazing_ratio()
|
| 355 |
+
gazing_ratio_each_frame = self.get_gazing_ratio_each_frame(inputs, video, gazing_ratio, T // self.frame_sampling_rate, temperature, use_cache)
|
| 356 |
+
num_gaze_tokens_each_frame = (gazing_ratio_each_frame * self.num_vision_tokens_each_frame).to(torch.long).clamp(min=1)
|
| 357 |
+
task_loss_requirement = self.get_task_loss_requirement(video)
|
| 358 |
+
|
| 359 |
+
if self.training:
|
| 360 |
+
gazing_info = self.gazing_model.generate(
|
| 361 |
+
video,
|
| 362 |
+
max_gaze_tokens_each_frame=num_gaze_tokens_each_frame,
|
| 363 |
+
task_loss_requirement=task_loss_requirement,
|
| 364 |
+
do_sample=True,
|
| 365 |
+
temperature=temperature,
|
| 366 |
+
use_cache=use_cache,
|
| 367 |
+
past_key_values=past_key_values,
|
| 368 |
+
past_inputs_embeds=past_inputs_embeds,
|
| 369 |
+
past_attention_mask=past_attention_mask,
|
| 370 |
+
past_conv_values=past_conv_values,
|
| 371 |
+
)
|
| 372 |
+
else:
|
| 373 |
+
gazing_info = self.gazing_model.generate(
|
| 374 |
+
video,
|
| 375 |
+
max_gaze_tokens_each_frame=num_gaze_tokens_each_frame,
|
| 376 |
+
task_loss_requirement=task_loss_requirement,
|
| 377 |
+
do_sample=False,
|
| 378 |
+
use_cache=use_cache,
|
| 379 |
+
past_key_values=past_key_values,
|
| 380 |
+
past_inputs_embeds=past_inputs_embeds,
|
| 381 |
+
past_attention_mask=past_attention_mask,
|
| 382 |
+
past_conv_values=past_conv_values,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Unpack gazing_info
|
| 386 |
+
gazing_pos = gazing_info["gazing_pos"]
|
| 387 |
+
num_gazing_each_frame = gazing_info["num_gazing_each_frame"]
|
| 388 |
+
if_padded_gazing = gazing_info["if_padded_gazing"]
|
| 389 |
+
task_loss_requirement = gazing_info.get("task_loss_requirement", None)
|
| 390 |
+
new_past_key_values = gazing_info.get("past_key_values", None)
|
| 391 |
+
new_past_inputs_embeds = gazing_info.get("past_inputs_embeds", None)
|
| 392 |
+
new_past_attention_mask = gazing_info.get("past_attention_mask", None)
|
| 393 |
+
new_past_conv_values = gazing_info.get("past_conv_values", None)
|
| 394 |
+
|
| 395 |
+
# Get the log probablity of taking such gaze (log_action_probs)
|
| 396 |
+
if not generate_only:
|
| 397 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 398 |
+
forward_outputs = self.gazing_model(video, gazing_info) # B * N
|
| 399 |
+
action_probs = forward_outputs.gaze_probs
|
| 400 |
+
task_loss_prediction = forward_outputs.task_loss_prediction
|
| 401 |
+
log_action_probs = torch.log(action_probs + 1e-8) # B * N
|
| 402 |
+
else:
|
| 403 |
+
log_action_probs = None
|
| 404 |
+
task_loss_prediction = None
|
| 405 |
+
|
| 406 |
+
# Generate (multi-scale) gazing masks for ease of visualization
|
| 407 |
+
mask = self.get_mask_from_gazing_pos(video, gazing_pos, if_padded_gazing)
|
| 408 |
+
|
| 409 |
+
to_return = {
|
| 410 |
+
'gazing_pos': gazing_pos,
|
| 411 |
+
'log_action_probs': log_action_probs,
|
| 412 |
+
'gazing_mask': mask,
|
| 413 |
+
"scales": self.scales,
|
| 414 |
+
"frame_sampling_rate": self.frame_sampling_rate,
|
| 415 |
+
"num_vision_tokens_each_frame": self.num_vision_tokens_each_frame,
|
| 416 |
+
"num_gazing_each_frame": num_gazing_each_frame,
|
| 417 |
+
"if_padded_gazing": if_padded_gazing,
|
| 418 |
+
"task_loss_prediction": task_loss_prediction,
|
| 419 |
+
"has_task_loss_requirement": task_loss_requirement is not None,
|
| 420 |
+
"task_loss_requirement": task_loss_requirement,
|
| 421 |
+
"past_key_values": new_past_key_values if use_cache else None,
|
| 422 |
+
"past_inputs_embeds": new_past_inputs_embeds if use_cache else None,
|
| 423 |
+
"past_attention_mask": new_past_attention_mask if use_cache else None,
|
| 424 |
+
"past_conv_values": new_past_conv_values if use_cache else None,
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
# Postprocess the output to recover from resolution adaptation
|
| 428 |
+
if target_scales is not None and target_patch_size is not None:
|
| 429 |
+
if not (target_scales == self.scales and [(scale // target_patch_size) ** 2 for scale in target_scales] == self.num_vision_tokens_each_scale_each_frame):
|
| 430 |
+
to_return.update(self.recover_output_from_res_adapt(to_return, res_adapt_info))
|
| 431 |
+
|
| 432 |
+
return to_return
|
autogaze/models/autogaze/configuration_autogaze.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""AutoGaze model configuration"""
|
| 3 |
+
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.utils import logging
|
| 6 |
+
from omegaconf import OmegaConf
|
| 7 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 8 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GazeDecoderConfig(PretrainedConfig):
|
| 15 |
+
r"""
|
| 16 |
+
Based on LLamaConfig from transformers.
|
| 17 |
+
```"""
|
| 18 |
+
model_type = "llama"
|
| 19 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 20 |
+
# Default tensor parallel plan for base model `LlamaModel`
|
| 21 |
+
base_model_tp_plan = {
|
| 22 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 23 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 24 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 25 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 26 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 27 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 28 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 29 |
+
}
|
| 30 |
+
base_model_pp_plan = {
|
| 31 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 32 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 33 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
vocab_size=32000,
|
| 39 |
+
hidden_size=4096,
|
| 40 |
+
intermediate_size=11008,
|
| 41 |
+
num_hidden_layers=32,
|
| 42 |
+
num_attention_heads=32,
|
| 43 |
+
num_key_value_heads=None,
|
| 44 |
+
hidden_act="silu",
|
| 45 |
+
max_position_embeddings=2048,
|
| 46 |
+
initializer_range=0.02,
|
| 47 |
+
rms_norm_eps=1e-6,
|
| 48 |
+
use_cache=True,
|
| 49 |
+
pad_token_id=None,
|
| 50 |
+
bos_token_id=1,
|
| 51 |
+
eos_token_id=2,
|
| 52 |
+
pretraining_tp=1,
|
| 53 |
+
tie_word_embeddings=False,
|
| 54 |
+
rope_theta=10000.0,
|
| 55 |
+
rope_scaling=None,
|
| 56 |
+
attention_bias=False,
|
| 57 |
+
attention_dropout=0.0,
|
| 58 |
+
mlp_bias=False,
|
| 59 |
+
head_dim=None,
|
| 60 |
+
attn_mode="sdpa",
|
| 61 |
+
num_multi_token_pred=1,
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
self.vocab_size = vocab_size
|
| 65 |
+
self.max_position_embeddings = max_position_embeddings
|
| 66 |
+
self.hidden_size = hidden_size
|
| 67 |
+
self.intermediate_size = intermediate_size
|
| 68 |
+
self.num_hidden_layers = num_hidden_layers
|
| 69 |
+
self.num_attention_heads = num_attention_heads
|
| 70 |
+
|
| 71 |
+
# for backward compatibility
|
| 72 |
+
if num_key_value_heads is None:
|
| 73 |
+
num_key_value_heads = num_attention_heads
|
| 74 |
+
|
| 75 |
+
self.num_key_value_heads = num_key_value_heads
|
| 76 |
+
self.hidden_act = hidden_act
|
| 77 |
+
self.initializer_range = initializer_range
|
| 78 |
+
self.rms_norm_eps = rms_norm_eps
|
| 79 |
+
self.pretraining_tp = pretraining_tp
|
| 80 |
+
self.use_cache = use_cache
|
| 81 |
+
self.rope_theta = rope_theta
|
| 82 |
+
self.rope_scaling = rope_scaling
|
| 83 |
+
self.attention_bias = attention_bias
|
| 84 |
+
self.attention_dropout = attention_dropout
|
| 85 |
+
self.mlp_bias = mlp_bias
|
| 86 |
+
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
| 87 |
+
self.num_multi_token_pred = num_multi_token_pred
|
| 88 |
+
self._attn_implementation = attn_mode
|
| 89 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 90 |
+
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
| 91 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 92 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 93 |
+
rope_config_validation(self)
|
| 94 |
+
|
| 95 |
+
super().__init__(
|
| 96 |
+
pad_token_id=pad_token_id,
|
| 97 |
+
bos_token_id=bos_token_id,
|
| 98 |
+
eos_token_id=eos_token_id,
|
| 99 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 100 |
+
**kwargs,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class VisionModelConfig(PretrainedConfig):
|
| 105 |
+
r"""
|
| 106 |
+
Configuration for the vision model component of AutoGaze.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
hidden_dim (`int`, *optional*, defaults to `192`):
|
| 110 |
+
Hidden dimension of the vision model.
|
| 111 |
+
out_dim (`int`, *optional*, defaults to `192`):
|
| 112 |
+
Output dimension of the vision model.
|
| 113 |
+
depth (`int`, *optional*, defaults to `1`):
|
| 114 |
+
Depth of the vision model.
|
| 115 |
+
kernel_size (`int`, *optional*, defaults to `16`):
|
| 116 |
+
Kernel size for spatial convolution.
|
| 117 |
+
temporal_patch_size (`int`, *optional*, defaults to `1`):
|
| 118 |
+
Temporal patch size for video processing.
|
| 119 |
+
trunk_temporal_kernel_size (`int`, *optional*, defaults to `3`):
|
| 120 |
+
Temporal kernel size for trunk blocks.
|
| 121 |
+
trunk_spatial_kernel_size (`int`, *optional*, defaults to `3`):
|
| 122 |
+
Spatial kernel size for trunk blocks.
|
| 123 |
+
"""
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
hidden_dim=192,
|
| 127 |
+
out_dim=192,
|
| 128 |
+
depth=1,
|
| 129 |
+
kernel_size=16,
|
| 130 |
+
temporal_patch_size=1,
|
| 131 |
+
trunk_temporal_kernel_size=3,
|
| 132 |
+
trunk_spatial_kernel_size=3,
|
| 133 |
+
**kwargs,
|
| 134 |
+
):
|
| 135 |
+
self.hidden_dim = hidden_dim
|
| 136 |
+
self.out_dim = out_dim
|
| 137 |
+
self.depth = depth
|
| 138 |
+
self.kernel_size = kernel_size
|
| 139 |
+
self.temporal_patch_size = temporal_patch_size
|
| 140 |
+
self.trunk_temporal_kernel_size = trunk_temporal_kernel_size
|
| 141 |
+
self.trunk_spatial_kernel_size = trunk_spatial_kernel_size
|
| 142 |
+
|
| 143 |
+
super().__init__(**kwargs)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class ConnectorConfig(PretrainedConfig):
|
| 147 |
+
r"""
|
| 148 |
+
Configuration for the connector component between vision encoder and gaze model.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
hidden_dim (`int`, *optional*, defaults to `192`):
|
| 152 |
+
Hidden dimension of the connector.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
def __init__(
|
| 156 |
+
self,
|
| 157 |
+
hidden_dim=192,
|
| 158 |
+
num_tokens=196,
|
| 159 |
+
**kwargs,
|
| 160 |
+
):
|
| 161 |
+
self.hidden_dim = hidden_dim
|
| 162 |
+
self.num_tokens = num_tokens
|
| 163 |
+
|
| 164 |
+
super().__init__(**kwargs)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class GazeModelConfig(PretrainedConfig):
|
| 168 |
+
r"""
|
| 169 |
+
Configuration for the gaze model, containing vision model, connector, and decoder configs.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
num_multi_token_pred (`int`, *optional*, defaults to `1`):
|
| 173 |
+
Number of tokens to predict in parallel.
|
| 174 |
+
input_img_size (`int`, *optional*, defaults to `224`):
|
| 175 |
+
Input image size.
|
| 176 |
+
vision_model_config (`VisionModelConfig` or `dict`, *optional*):
|
| 177 |
+
Configuration for the vision model.
|
| 178 |
+
connector_config (`ConnectorConfig` or `dict`, *optional*):
|
| 179 |
+
Configuration for the connector.
|
| 180 |
+
gaze_decoder_config (`GazeDecoderConfig` or `dict`, *optional*):
|
| 181 |
+
Configuration for the gaze decoder (LLaMA-based).
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
input_img_size=224,
|
| 187 |
+
vision_model_config={},
|
| 188 |
+
connector_config={},
|
| 189 |
+
gaze_decoder_config={},
|
| 190 |
+
num_vision_tokens_each_frame=196,
|
| 191 |
+
attn_mode="sdpa",
|
| 192 |
+
**kwargs,
|
| 193 |
+
):
|
| 194 |
+
self.input_img_size = input_img_size
|
| 195 |
+
self.vision_model_config = VisionModelConfig(**vision_model_config)
|
| 196 |
+
|
| 197 |
+
connector_config.update({
|
| 198 |
+
"num_tokens": (input_img_size // self.vision_model_config.kernel_size)**2,
|
| 199 |
+
})
|
| 200 |
+
self.connector_config = ConnectorConfig(**connector_config)
|
| 201 |
+
|
| 202 |
+
gaze_decoder_config.update({
|
| 203 |
+
"vocab_size": num_vision_tokens_each_frame + 1,
|
| 204 |
+
"eos_token_id": num_vision_tokens_each_frame,
|
| 205 |
+
"attn_mode": attn_mode,
|
| 206 |
+
})
|
| 207 |
+
self.gaze_decoder_config = GazeDecoderConfig(**gaze_decoder_config)
|
| 208 |
+
|
| 209 |
+
self.num_vision_tokens_each_frame = num_vision_tokens_each_frame
|
| 210 |
+
|
| 211 |
+
super().__init__(**kwargs)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class AutoGazeConfig(PretrainedConfig):
|
| 215 |
+
r"""
|
| 216 |
+
This is the configuration class to store the configuration of an [`AutoGaze`] model. It is used to instantiate an
|
| 217 |
+
AutoGaze model according to the specified arguments, defining the model architecture.
|
| 218 |
+
|
| 219 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 220 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
gazing_ratio_config (`dict`, *optional*):
|
| 224 |
+
Configuration for sampling gazing ratio during training and inference.
|
| 225 |
+
scales (`str` or `int`, *optional*, defaults to `"224"`):
|
| 226 |
+
Scales for the vision model. Can be a single scale or multiple scales separated by '+'.
|
| 227 |
+
num_vision_tokens_each_frame (`int`, *optional*, defaults to `196`):
|
| 228 |
+
Number of vision tokens per frame.
|
| 229 |
+
gaze_model_config (`GazeModelConfig` or `dict`, *optional*):
|
| 230 |
+
Configuration for the gaze model, including vision_model_config, connector_config, and gaze_decoder_config.
|
| 231 |
+
gazing_ratio_each_frame_config (`dict`, *optional*):
|
| 232 |
+
Configuration for sampling gazing ratio for each frame.
|
| 233 |
+
has_task_loss_requirement_during_training (`bool`, *optional*, defaults to `False`):
|
| 234 |
+
Whether to use task loss requirement during training.
|
| 235 |
+
has_task_loss_requirement_during_inference (`bool`, *optional*, defaults to `False`):
|
| 236 |
+
Whether to use task loss requirement during inference.
|
| 237 |
+
task_loss_requirement_config (`dict`, *optional*):
|
| 238 |
+
Configuration for task loss requirement sampling.
|
| 239 |
+
image_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
| 240 |
+
Image mean for normalization.
|
| 241 |
+
image_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
| 242 |
+
Image std for normalization.
|
| 243 |
+
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
| 244 |
+
Whether to use flash attention.
|
| 245 |
+
max_batch_size (`int`, *optional*):
|
| 246 |
+
Maximum batch size.
|
| 247 |
+
|
| 248 |
+
```python
|
| 249 |
+
>>> from autogaze.models.autogaze import AutoGaze, AutoGazeConfig
|
| 250 |
+
|
| 251 |
+
>>> # Initializing an AutoGaze configuration
|
| 252 |
+
>>> configuration = AutoGazeConfig()
|
| 253 |
+
|
| 254 |
+
>>> # Initializing a model from the configuration
|
| 255 |
+
>>> model = AutoGaze(configuration)
|
| 256 |
+
|
| 257 |
+
>>> # Accessing the model configuration
|
| 258 |
+
>>> configuration = model.config
|
| 259 |
+
```"""
|
| 260 |
+
|
| 261 |
+
model_type = "autogaze"
|
| 262 |
+
|
| 263 |
+
def __init__(
|
| 264 |
+
self,
|
| 265 |
+
gazing_ratio_config=None,
|
| 266 |
+
scales="224",
|
| 267 |
+
num_vision_tokens_each_frame=196,
|
| 268 |
+
gaze_model_config={},
|
| 269 |
+
gazing_ratio_each_frame_config=None,
|
| 270 |
+
has_task_loss_requirement_during_training=False,
|
| 271 |
+
has_task_loss_requirement_during_inference=False,
|
| 272 |
+
task_loss_requirement_config=None,
|
| 273 |
+
image_mean=[0.485, 0.456, 0.406],
|
| 274 |
+
image_std=[0.229, 0.224, 0.225],
|
| 275 |
+
use_flash_attn=True,
|
| 276 |
+
max_batch_size=None,
|
| 277 |
+
**kwargs,
|
| 278 |
+
):
|
| 279 |
+
self.gazing_ratio_config = gazing_ratio_config or {
|
| 280 |
+
"sample_strategy_during_training": "fixed",
|
| 281 |
+
"sample_strategy_during_inference": "fixed",
|
| 282 |
+
"fixed": {"gazing_ratio": 0.5},
|
| 283 |
+
"uniform": {"gazing_ratio_min": 0, "gazing_ratio_max": 1},
|
| 284 |
+
"exponential": {"gazing_ratio_min": 0, "gazing_ratio_max": 1, "lambda": 10},
|
| 285 |
+
}
|
| 286 |
+
self.scales = scales
|
| 287 |
+
self.num_vision_tokens_each_frame = num_vision_tokens_each_frame
|
| 288 |
+
self.attn_mode = "flash_attention_2" if use_flash_attn else "sdpa"
|
| 289 |
+
|
| 290 |
+
gaze_model_config.update({
|
| 291 |
+
"num_vision_tokens_each_frame": num_vision_tokens_each_frame,
|
| 292 |
+
"attn_mode": self.attn_mode,
|
| 293 |
+
})
|
| 294 |
+
self.gaze_model_config = GazeModelConfig(**gaze_model_config)
|
| 295 |
+
|
| 296 |
+
self.gazing_ratio_each_frame_config = gazing_ratio_each_frame_config or {
|
| 297 |
+
"sample_strategy_during_training": "uniform",
|
| 298 |
+
"sample_strategy_during_inference": "uniform",
|
| 299 |
+
"uniform": {},
|
| 300 |
+
"dirichlet": {"alpha": 0.5},
|
| 301 |
+
"self": {},
|
| 302 |
+
}
|
| 303 |
+
self.has_task_loss_requirement_during_training = has_task_loss_requirement_during_training
|
| 304 |
+
self.has_task_loss_requirement_during_inference = has_task_loss_requirement_during_inference
|
| 305 |
+
self.task_loss_requirement_config = task_loss_requirement_config or {
|
| 306 |
+
"sample_strategy_during_training": "fixed",
|
| 307 |
+
"sample_strategy_during_inference": "fixed",
|
| 308 |
+
"fixed": {"task_loss_requirement": 0.7},
|
| 309 |
+
"uniform": {"task_loss_requirement_min": 0.6, "task_loss_requirement_max": 0.9},
|
| 310 |
+
}
|
| 311 |
+
self.image_mean = image_mean
|
| 312 |
+
self.image_std = image_std
|
| 313 |
+
self.use_flash_attn = use_flash_attn
|
| 314 |
+
self.max_batch_size = max_batch_size
|
| 315 |
+
|
| 316 |
+
super().__init__(**kwargs)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
__all__ = [
|
| 320 |
+
"AutoGazeConfig",
|
| 321 |
+
"GazeModelConfig",
|
| 322 |
+
"VisionModelConfig",
|
| 323 |
+
"ConnectorConfig",
|
| 324 |
+
"GazeDecoderConfig",
|
| 325 |
+
]
|
| 326 |
+
|
autogaze/models/autogaze/modeling_autogaze.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from timm.models.convnext import ConvNeXtBlock
|
| 11 |
+
from timm.layers import LayerNorm2d
|
| 12 |
+
|
| 13 |
+
from transformers.modeling_outputs import ModelOutput
|
| 14 |
+
from transformers import LogitsProcessor, LogitsProcessorList
|
| 15 |
+
|
| 16 |
+
from .configuration_autogaze import GazeModelConfig, VisionModelConfig, ConnectorConfig
|
| 17 |
+
from .modeling_llama_multi_token_pred import LlamaForCausalLM_MultiTokenPred
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class AutoGazeOutput(ModelOutput):
|
| 22 |
+
gaze_logits: Optional[torch.FloatTensor] = None
|
| 23 |
+
gaze_probs: Optional[torch.FloatTensor] = None
|
| 24 |
+
loss: Optional[torch.FloatTensor] = None
|
| 25 |
+
logits: torch.FloatTensor = None
|
| 26 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
| 27 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 28 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 29 |
+
task_loss_prediction: Optional[torch.FloatTensor] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class NoRepeatTokensLogitsProcessor(LogitsProcessor):
|
| 33 |
+
def __init__(self):
|
| 34 |
+
super().__init__()
|
| 35 |
+
|
| 36 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 37 |
+
# input_ids: (batch_size, sequence_length)
|
| 38 |
+
# scores: (batch_size, vocab_size) or (batch_size, num_multi_token_pred, vocab_size)
|
| 39 |
+
if scores.ndim == 3:
|
| 40 |
+
scores[torch.arange(scores.shape[0])[..., None], :, input_ids] = -float("inf")
|
| 41 |
+
else:
|
| 42 |
+
scores[torch.arange(scores.shape[0])[..., None], input_ids] = -float("inf")
|
| 43 |
+
return scores
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class NoEosTokenLogitsProcessor(LogitsProcessor):
|
| 47 |
+
def __init__(self):
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 51 |
+
# input_ids: (batch_size, sequence_length)
|
| 52 |
+
# scores: (batch_size, vocab_size) or (batch_size, num_multi_token_pred, vocab_size)
|
| 53 |
+
scores[..., -1] = -float("inf")
|
| 54 |
+
return scores
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class AutoGazeModel(nn.Module):
|
| 58 |
+
def __init__(self, gaze_model_config: GazeModelConfig):
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
self.num_vision_tokens_each_frame = gaze_model_config.num_vision_tokens_each_frame
|
| 62 |
+
self.input_img_size = gaze_model_config.input_img_size
|
| 63 |
+
self.frame_sampling_rate = gaze_model_config.vision_model_config.temporal_patch_size
|
| 64 |
+
self.num_multi_token_pred = gaze_model_config.gaze_decoder_config.num_multi_token_pred
|
| 65 |
+
self.gaze_decoder_config = gaze_model_config.gaze_decoder_config # Store for reference
|
| 66 |
+
|
| 67 |
+
# Create the vision model, connector, and gaze decoder
|
| 68 |
+
self.vision_model = ShallowVideoConvNet(gaze_model_config.vision_model_config)
|
| 69 |
+
self.connector = Connector(gaze_model_config.connector_config)
|
| 70 |
+
self.gaze_decoder = LlamaForCausalLM_MultiTokenPred(gaze_model_config.gaze_decoder_config)
|
| 71 |
+
|
| 72 |
+
# Add logits processors to prevent the model from repeating the same token and generating eos token during gazing.
|
| 73 |
+
self.logits_processor = LogitsProcessorList()
|
| 74 |
+
self.logits_processor.append(NoRepeatTokensLogitsProcessor()) # don't allow repeated gazing
|
| 75 |
+
self.logits_processor.append(NoEosTokenLogitsProcessor()) # don't allow generating eos token duing gazing
|
| 76 |
+
|
| 77 |
+
def embed(self, video=None, gaze_pos_ids=None, use_cache=False, past_conv_values=None):
|
| 78 |
+
"""
|
| 79 |
+
inputs:
|
| 80 |
+
video: (B x T x C x H x W).
|
| 81 |
+
gaze_pos_ids: list of (B, N), N is the number of gazing positions in each frame. The length of the list is T // frame_sampling_rate.
|
| 82 |
+
returns:
|
| 83 |
+
embeds: a list of interleaved vision and gaze embeddings. list of (B, N, C)
|
| 84 |
+
gaze_token_mask: a list of masks that indicate if the current embedding is a gaze embedding. (1 is gaze embedding, 0 is vision embedding). list of (N, )
|
| 85 |
+
gaze_pred_source_relative: a list of (relative) source index of where the gaze prediction is coming from. For example, if the gaze prediction is coming from two tokens before it, the source index is -2. list of (N, ).
|
| 86 |
+
For vision embeddings, there's no source prediction, so the source index is -1.
|
| 87 |
+
attention_mask: a list of (B, N) that indicates if the current embedding should be masked out (for EOS token). 1 is not masked, 0 is masked.
|
| 88 |
+
"""
|
| 89 |
+
B, T = video.shape[:2]
|
| 90 |
+
assert (video is None or gaze_pos_ids is None) or video.shape[1] // self.frame_sampling_rate == len(gaze_pos_ids), \
|
| 91 |
+
"The number of frames in the video (after subsampling) and in gaze position IDs must be the same, but got {} and {}".format(video.shape[1] // self.frame_sampling_rate, len(gaze_pos_ids))
|
| 92 |
+
|
| 93 |
+
if video is not None:
|
| 94 |
+
vision_features, new_past_conv_values = self.vision_model(video, use_cache=use_cache, past_conv_values=past_conv_values)
|
| 95 |
+
vision_features = vision_features.transpose(1, 2)
|
| 96 |
+
vision_features = rearrange(vision_features, 'b t c h w -> b t (h w) c')
|
| 97 |
+
vision_features = self.connector(vision_features)
|
| 98 |
+
vision_attention_mask = [torch.ones(B, vision_features.shape[2], device=vision_features.device).long() for _ in range(vision_features.shape[1])]
|
| 99 |
+
|
| 100 |
+
if gaze_pos_ids is not None:
|
| 101 |
+
num_gazing_each_frame = [gaze_pos_ids[t].shape[1] for t in range(len(gaze_pos_ids))]
|
| 102 |
+
gaze_pos_ids = torch.cat(gaze_pos_ids, dim=1)
|
| 103 |
+
gaze_attention_mask = (gaze_pos_ids != self.gaze_decoder_config.eos_token_id).to(torch.long)
|
| 104 |
+
gaze_embeds = self.gaze_decoder.model.embed_tokens(gaze_pos_ids)
|
| 105 |
+
gaze_embeds = list(gaze_embeds.split(num_gazing_each_frame, dim=1))
|
| 106 |
+
gaze_attention_mask = list(gaze_attention_mask.split(num_gazing_each_frame, dim=1))
|
| 107 |
+
|
| 108 |
+
embeds = []
|
| 109 |
+
gaze_token_mask = []
|
| 110 |
+
gaze_pred_source_relative = []
|
| 111 |
+
attention_mask = []
|
| 112 |
+
for t in range(T // self.frame_sampling_rate):
|
| 113 |
+
if video is not None:
|
| 114 |
+
embeds.append(vision_features[:, t, :, :])
|
| 115 |
+
gaze_token_mask.append(torch.zeros(vision_features.shape[2], device=vision_features.device).long())
|
| 116 |
+
gaze_pred_source_relative.append(torch.zeros(vision_features.shape[2], device=vision_features.device).long() - 1)
|
| 117 |
+
attention_mask.append(vision_attention_mask[t])
|
| 118 |
+
if gaze_pos_ids is not None:
|
| 119 |
+
embeds.append(gaze_embeds[t])
|
| 120 |
+
gaze_token_mask.append(torch.ones(gaze_embeds[t].shape[1], device=gaze_embeds[t].device).long())
|
| 121 |
+
gaze_pred_source_relative.append(-(torch.arange(gaze_embeds[t].shape[1], device=gaze_embeds[t].device) % self.num_multi_token_pred + 1))
|
| 122 |
+
attention_mask.append(gaze_attention_mask[t])
|
| 123 |
+
return embeds, gaze_token_mask, gaze_pred_source_relative, attention_mask, new_past_conv_values if video is not None else None
|
| 124 |
+
|
| 125 |
+
@torch.no_grad()
|
| 126 |
+
def generate(
|
| 127 |
+
self,
|
| 128 |
+
video,
|
| 129 |
+
max_gaze_tokens_each_frame=100,
|
| 130 |
+
task_loss_requirement=None,
|
| 131 |
+
use_cache=False,
|
| 132 |
+
past_key_values=None,
|
| 133 |
+
past_inputs_embeds=None,
|
| 134 |
+
past_attention_mask=None,
|
| 135 |
+
past_conv_values=None,
|
| 136 |
+
**generation_kwargs,
|
| 137 |
+
):
|
| 138 |
+
"""
|
| 139 |
+
Inputs:
|
| 140 |
+
video: (B, T, C, H, W)
|
| 141 |
+
max_gaze_tokens_each_frame: int or (T, ). Indicating the max gazing length for each frame. If is int, then all frames have the same max gazing length.
|
| 142 |
+
task_loss_requirement (optional): (B, T). Indicating the task loss requirement for each frame.
|
| 143 |
+
past_key_values (optional): The past key values for the gaze model. Can be used for streaming generation.
|
| 144 |
+
past_inputs_embeds (optional): The past inputs embeds for the gaze model. Can be used for streaming generation.
|
| 145 |
+
past_attention_mask (optional): The past attention mask for the gaze model. Can be used for streaming generation.
|
| 146 |
+
"""
|
| 147 |
+
if past_key_values is not None or past_inputs_embeds is not None or past_attention_mask is not None or past_conv_values is not None:
|
| 148 |
+
assert past_key_values is not None and past_inputs_embeds is not None and past_attention_mask is not None and past_conv_values is not None, \
|
| 149 |
+
"If past_key_values, past_inputs_embeds, past_attention_mask, or past_conv_values is provided, then all four must be provided!"
|
| 150 |
+
|
| 151 |
+
# Subsample frames and resize
|
| 152 |
+
B, T = video.shape[:2]
|
| 153 |
+
video = rearrange(video, 'b t c h w -> (b t) c h w')
|
| 154 |
+
video = F.interpolate(video, size=(self.input_img_size, self.input_img_size), mode="bicubic", align_corners=False)
|
| 155 |
+
video = rearrange(video, '(b t) c h w -> b t c h w', b=B)
|
| 156 |
+
|
| 157 |
+
# Embed all the frames
|
| 158 |
+
video_embeds, _, __, ___, past_conv_values = self.embed(video=video, use_cache=use_cache, past_conv_values=past_conv_values)
|
| 159 |
+
|
| 160 |
+
# Generate gaze position IDs for each frame
|
| 161 |
+
gaze_pos_ids_list = []
|
| 162 |
+
inputs_embeds = [] if past_inputs_embeds is None else past_inputs_embeds
|
| 163 |
+
attention_mask = [] if past_attention_mask is None else past_attention_mask
|
| 164 |
+
past_key_values = None if past_key_values is None else past_key_values
|
| 165 |
+
num_gazing_each_frame = []
|
| 166 |
+
if_padded_gazing = []
|
| 167 |
+
for t in range(len(video_embeds)):
|
| 168 |
+
|
| 169 |
+
# Update inputs_embeds and attention mask for the new frame
|
| 170 |
+
inputs_embeds.append(video_embeds[t])
|
| 171 |
+
attention_mask.append(torch.ones(video_embeds[t].shape[0], video_embeds[t].shape[1], device=video_embeds[t].device).long())
|
| 172 |
+
|
| 173 |
+
# Put task loss requirement into generation config
|
| 174 |
+
generation_config = self.gaze_decoder.generation_config
|
| 175 |
+
generation_config.task_loss_requirement = task_loss_requirement[:, t] if task_loss_requirement is not None else None
|
| 176 |
+
|
| 177 |
+
# Get the max gazing length for the current frame
|
| 178 |
+
assert isinstance(max_gaze_tokens_each_frame, int) or len(max_gaze_tokens_each_frame) == len(video_embeds), \
|
| 179 |
+
"max_gaze_tokens_each_frame must be an int or a tensor of the same length as the video embeddings, but got {} and {}".format(max_gaze_tokens_each_frame, len(video_embeds))
|
| 180 |
+
max_gaze_tokens = max_gaze_tokens_each_frame if isinstance(max_gaze_tokens_each_frame, int) else max_gaze_tokens_each_frame[t]
|
| 181 |
+
|
| 182 |
+
# Generate gaze position IDs for the current frame
|
| 183 |
+
is_gradient_checkpointing = self.gaze_decoder.is_gradient_checkpointing
|
| 184 |
+
if is_gradient_checkpointing:
|
| 185 |
+
self.gaze_decoder.gradient_checkpointing_disable()
|
| 186 |
+
gaze_outputs = self.gaze_decoder.generate(
|
| 187 |
+
inputs_embeds=torch.cat(inputs_embeds, dim=1), # We need to pass the whole sequence of inputs_embeds (both current and past) to the model even when we use use_cache=True!!!
|
| 188 |
+
attention_mask=torch.cat(attention_mask, dim=1),
|
| 189 |
+
position_ids=torch.cat(attention_mask, dim=1).cumsum(dim=-1) - 1,
|
| 190 |
+
max_new_tokens=max_gaze_tokens,
|
| 191 |
+
logits_processor=self.logits_processor,
|
| 192 |
+
pad_token_id=self.gaze_decoder_config.eos_token_id,
|
| 193 |
+
eos_token_id=self.gaze_decoder_config.eos_token_id,
|
| 194 |
+
past_key_values=past_key_values,
|
| 195 |
+
use_cache=True,
|
| 196 |
+
return_dict_in_generate=True,
|
| 197 |
+
generation_config=generation_config,
|
| 198 |
+
**generation_kwargs,
|
| 199 |
+
)
|
| 200 |
+
if is_gradient_checkpointing:
|
| 201 |
+
self.gaze_decoder.gradient_checkpointing_enable()
|
| 202 |
+
|
| 203 |
+
# Get the predicted gaze ids
|
| 204 |
+
gaze_pos_ids = gaze_outputs.sequences # B * N
|
| 205 |
+
gaze_pos_ids_list.append(gaze_pos_ids + self.num_vision_tokens_each_frame * t)
|
| 206 |
+
|
| 207 |
+
# Update inputs_embeds for the next frame
|
| 208 |
+
inputs_embeds.append(self.gaze_decoder.model.embed_tokens(gaze_pos_ids))
|
| 209 |
+
|
| 210 |
+
# Update past_key_values for the next frame
|
| 211 |
+
past_key_values = gaze_outputs.past_key_values
|
| 212 |
+
|
| 213 |
+
# Update auxiliary information
|
| 214 |
+
num_gazing_each_frame.append(gaze_pos_ids.shape[1])
|
| 215 |
+
if_padded_gazing.append(gaze_pos_ids == self.gaze_decoder_config.eos_token_id)
|
| 216 |
+
|
| 217 |
+
# Update attention mask
|
| 218 |
+
attention_mask.append((gaze_pos_ids != self.gaze_decoder_config.eos_token_id).to(torch.long))
|
| 219 |
+
|
| 220 |
+
# Concatenate gaze position IDs from all frames
|
| 221 |
+
gaze_pos_ids = torch.cat(gaze_pos_ids_list, dim=1)
|
| 222 |
+
|
| 223 |
+
# Get auxiliary information
|
| 224 |
+
num_gazing_each_frame = torch.tensor(num_gazing_each_frame, device=gaze_pos_ids.device).to(torch.long)
|
| 225 |
+
if_padded_gazing = torch.cat(if_padded_gazing, dim=1)
|
| 226 |
+
|
| 227 |
+
to_return = {
|
| 228 |
+
"gazing_pos": gaze_pos_ids, # In gaze_pos_ids, the padded gazing positions are not necessarily eos_token_id, so one needs to use if_padded_gazing to determine if the gazing position is padded!!!
|
| 229 |
+
"num_gazing_each_frame": num_gazing_each_frame,
|
| 230 |
+
"if_padded_gazing": if_padded_gazing,
|
| 231 |
+
"task_loss_requirement": task_loss_requirement,
|
| 232 |
+
"past_inputs_embeds": inputs_embeds if use_cache else None,
|
| 233 |
+
"past_attention_mask": attention_mask if use_cache else None,
|
| 234 |
+
"past_key_values": past_key_values if use_cache else None,
|
| 235 |
+
"past_conv_values": past_conv_values if use_cache else None,
|
| 236 |
+
}
|
| 237 |
+
return to_return
|
| 238 |
+
|
| 239 |
+
def forward(self, video, gazing_info, **kwargs):
|
| 240 |
+
# Unpack gazing_info
|
| 241 |
+
gaze_pos_ids = gazing_info["gazing_pos"]
|
| 242 |
+
num_gazing_each_frame = gazing_info["num_gazing_each_frame"]
|
| 243 |
+
if_padded_gazing = gazing_info["if_padded_gazing"]
|
| 244 |
+
|
| 245 |
+
# Subsample frames and resize
|
| 246 |
+
B, T = video.shape[:2]
|
| 247 |
+
video = rearrange(video, 'b t c h w -> (b t) c h w')
|
| 248 |
+
video = F.interpolate(video, size=(self.input_img_size, self.input_img_size), mode="bicubic", align_corners=False)
|
| 249 |
+
video = rearrange(video, '(b t) c h w -> b t c h w', b=B)
|
| 250 |
+
|
| 251 |
+
# Split the gaze frame-wise
|
| 252 |
+
gaze_pos_ids_split = list(gaze_pos_ids.split(num_gazing_each_frame.tolist(), dim=1))
|
| 253 |
+
gaze_pos_ids_split = [gaze_pos_ids_split[t] - self.num_vision_tokens_each_frame * t for t in range(len(gaze_pos_ids_split))]
|
| 254 |
+
if_padded_gazing_split = list(if_padded_gazing.split(num_gazing_each_frame.tolist(), dim=1))
|
| 255 |
+
|
| 256 |
+
# Fill the padded gazing positions with eos_token_id
|
| 257 |
+
gaze_pos_ids_split = [gaze_pos * (~padded) + self.gaze_decoder_config.eos_token_id * padded for gaze_pos, padded in zip(gaze_pos_ids_split, if_padded_gazing_split)]
|
| 258 |
+
|
| 259 |
+
# Embed the video and gaze position IDs
|
| 260 |
+
inputs_embeds, gaze_token_mask, gaze_pred_source_relative, attention_mask, _ = self.embed(video=video, gaze_pos_ids=gaze_pos_ids_split)
|
| 261 |
+
inputs_embeds = torch.cat(inputs_embeds, dim=1) # B * N * C
|
| 262 |
+
gaze_token_mask = torch.cat(gaze_token_mask, dim=0) # N
|
| 263 |
+
gaze_pred_source_relative = torch.cat(gaze_pred_source_relative, dim=0) # N
|
| 264 |
+
attention_mask = torch.cat(attention_mask, dim=1) # B * N
|
| 265 |
+
|
| 266 |
+
# Run model forward
|
| 267 |
+
outputs = self.gaze_decoder(
|
| 268 |
+
inputs_embeds=inputs_embeds,
|
| 269 |
+
attention_mask=attention_mask,
|
| 270 |
+
position_ids=attention_mask.cumsum(dim=-1) - 1,
|
| 271 |
+
**kwargs,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Get gaze logits and probs
|
| 275 |
+
logits_multi_token_pred = outputs.logits
|
| 276 |
+
task_loss_prediction_multi_token_pred = outputs.task_loss_prediction # B * N * num_multi_token_pred
|
| 277 |
+
logits_multi_token_pred = rearrange(logits_multi_token_pred, 'b n (k c) -> b n k c', k=self.num_multi_token_pred)
|
| 278 |
+
gaze_probs_all_multi_token_pred = F.softmax(logits_multi_token_pred, dim=-1)
|
| 279 |
+
|
| 280 |
+
shifted_probs = []
|
| 281 |
+
shifted_task_loss_prediction = []
|
| 282 |
+
for i in range(self.num_multi_token_pred):
|
| 283 |
+
shifted_probs.append(F.pad(gaze_probs_all_multi_token_pred[:, :-(i + 1), i, :], (0, 0, i + 1, 0), value=0))
|
| 284 |
+
shifted_task_loss_prediction.append(F.pad(task_loss_prediction_multi_token_pred[:, :task_loss_prediction_multi_token_pred.shape[1] - i, i], (i, 0), value=0))
|
| 285 |
+
shifted_probs = torch.stack(shifted_probs, dim=2) # B, N, K, C
|
| 286 |
+
shifted_task_loss_prediction = torch.stack(shifted_task_loss_prediction, dim=2) # B, N, K
|
| 287 |
+
|
| 288 |
+
gaze_probs_all = shifted_probs[:, torch.arange(logits_multi_token_pred.shape[1]), -gaze_pred_source_relative - 1]
|
| 289 |
+
task_loss_prediction = shifted_task_loss_prediction[:, torch.arange(logits_multi_token_pred.shape[1]), (-gaze_pred_source_relative) % self.num_multi_token_pred] # B, N
|
| 290 |
+
|
| 291 |
+
gaze_input_token_pos = torch.nonzero(gaze_token_mask, as_tuple=True)[0]
|
| 292 |
+
gaze_probs_all = gaze_probs_all[:, gaze_input_token_pos, :]
|
| 293 |
+
task_loss_prediction = task_loss_prediction[:, gaze_input_token_pos]
|
| 294 |
+
B, N = gaze_probs_all.shape[:2]
|
| 295 |
+
gaze_probs = gaze_probs_all.reshape(B * N, -1)[torch.arange(B * N), torch.cat(gaze_pos_ids_split, dim=1).flatten()].reshape(B, N) # [B, T]
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
outputs = AutoGazeOutput(
|
| 299 |
+
gaze_probs=gaze_probs,
|
| 300 |
+
loss=outputs.loss,
|
| 301 |
+
logits=outputs.logits,
|
| 302 |
+
past_key_values=outputs.past_key_values,
|
| 303 |
+
hidden_states=outputs.hidden_states,
|
| 304 |
+
attentions=outputs.attentions,
|
| 305 |
+
task_loss_prediction=task_loss_prediction,
|
| 306 |
+
)
|
| 307 |
+
return outputs
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
################ Shallow Vision Encoder #################
|
| 311 |
+
|
| 312 |
+
class Conv3dBlockForStreaming(nn.Module):
|
| 313 |
+
def __init__(self, hidden_dim, temporal_patch_size, spatial_kernel_size):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.hidden_dim = hidden_dim
|
| 316 |
+
self.temporal_patch_size = temporal_patch_size
|
| 317 |
+
self.spatial_kernel_size = spatial_kernel_size
|
| 318 |
+
|
| 319 |
+
self.conv3d = nn.Conv3d(
|
| 320 |
+
hidden_dim, hidden_dim,
|
| 321 |
+
kernel_size=(temporal_patch_size, spatial_kernel_size, spatial_kernel_size),
|
| 322 |
+
padding=(0, (spatial_kernel_size - 1) // 2, (spatial_kernel_size - 1) // 2), # We manually pad the temporal dimension in forward, to support streaming
|
| 323 |
+
bias=True,
|
| 324 |
+
)
|
| 325 |
+
self.relu = nn.ReLU(inplace=True)
|
| 326 |
+
|
| 327 |
+
def forward(self, x, use_cache=False, past_conv_values=None):
|
| 328 |
+
if not (use_cache and past_conv_values is not None):
|
| 329 |
+
x = F.pad(x, (0, 0, 0, 0, self.temporal_patch_size - 1, 0), value=0)
|
| 330 |
+
else:
|
| 331 |
+
x = torch.cat([past_conv_values, x], dim=2)
|
| 332 |
+
new_past_conv_values = x[:, :, -(self.temporal_patch_size - 1):]
|
| 333 |
+
|
| 334 |
+
x = self.conv3d(x)
|
| 335 |
+
|
| 336 |
+
x = self.relu(x)
|
| 337 |
+
|
| 338 |
+
return x, new_past_conv_values
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class ShallowVideoConvNet(nn.Module):
|
| 342 |
+
"""
|
| 343 |
+
A shallow video convolutional network for video gaze modeling, inspired by ViViT's patch embedding approach.
|
| 344 |
+
Expects input of shape (B, T, C, H, W) or (B*T, C, H, W).
|
| 345 |
+
"""
|
| 346 |
+
def __init__(self, config: VisionModelConfig):
|
| 347 |
+
super().__init__()
|
| 348 |
+
hidden_dim = config.hidden_dim
|
| 349 |
+
out_dim = config.out_dim
|
| 350 |
+
depth = config.depth
|
| 351 |
+
kernel_size = config.kernel_size
|
| 352 |
+
self.temporal_patch_size = getattr(config, "temporal_patch_size", 1)
|
| 353 |
+
|
| 354 |
+
# For video, first merge temporal and batch if needed, then apply 3D conv for temporal patching
|
| 355 |
+
self.temporal_conv = nn.Conv3d(
|
| 356 |
+
in_channels=3, # RGB
|
| 357 |
+
out_channels=hidden_dim,
|
| 358 |
+
kernel_size=(self.temporal_patch_size, kernel_size, kernel_size),
|
| 359 |
+
stride=(self.temporal_patch_size, kernel_size, kernel_size),
|
| 360 |
+
bias=True,
|
| 361 |
+
)
|
| 362 |
+
self.norm = nn.LayerNorm(hidden_dim)
|
| 363 |
+
|
| 364 |
+
self.trunk_temporal_kernel_size = config.trunk_temporal_kernel_size
|
| 365 |
+
self.trunk_spatial_kernel_size = config.trunk_spatial_kernel_size
|
| 366 |
+
blocks = []
|
| 367 |
+
for i in range(depth):
|
| 368 |
+
blocks.append(
|
| 369 |
+
Conv3dBlockForStreaming(
|
| 370 |
+
hidden_dim=hidden_dim,
|
| 371 |
+
temporal_patch_size=self.trunk_temporal_kernel_size,
|
| 372 |
+
spatial_kernel_size=self.trunk_spatial_kernel_size,
|
| 373 |
+
)
|
| 374 |
+
)
|
| 375 |
+
self.blocks = nn.ModuleList(blocks)
|
| 376 |
+
|
| 377 |
+
self.out_proj = nn.Conv3d(
|
| 378 |
+
hidden_dim, out_dim, kernel_size=1, stride=1, bias=True
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def forward(self, x, use_cache=False, past_conv_values=None):
|
| 382 |
+
# x: (B, T, C, H, W) or (B*T, C, H, W)
|
| 383 |
+
if x.dim() == 5:
|
| 384 |
+
# (B, T, C, H, W) -> (B, C, T, H, W)
|
| 385 |
+
x = x.permute(0, 2, 1, 3, 4)
|
| 386 |
+
elif x.dim() == 4:
|
| 387 |
+
# (B*T, C, H, W) -> (B*T, C, 1, H, W)
|
| 388 |
+
x = x.unsqueeze(2)
|
| 389 |
+
else:
|
| 390 |
+
raise ValueError("Input must be 4D or 5D tensor")
|
| 391 |
+
x = self.temporal_conv(x) # (B, hidden_dim, T', H', W')
|
| 392 |
+
# Collapse temporal dimension into batch for normalization and blocks
|
| 393 |
+
B, C, T, H, W = x.shape
|
| 394 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W) # (B*T, C, H, W)
|
| 395 |
+
# Flatten spatial dims for norm: (B*T, C, H*W)
|
| 396 |
+
x = x.view(B * T, C, -1).permute(0, 2, 1) # (B*T, H*W, C)
|
| 397 |
+
x = self.norm(x)
|
| 398 |
+
x = x.permute(0, 2, 1).contiguous().view(B * T, C, H, W) # (B*T, C, H, W)
|
| 399 |
+
# Reshape back to (B, C, T, H, W)
|
| 400 |
+
x = x.view(B, T, C, H, W).permute(0, 2, 1, 3, 4)
|
| 401 |
+
# Main trunk
|
| 402 |
+
new_past_conv_values = []
|
| 403 |
+
for i, block in enumerate(self.blocks):
|
| 404 |
+
x, new_past_conv_values_i = block(
|
| 405 |
+
x,
|
| 406 |
+
use_cache=use_cache,
|
| 407 |
+
past_conv_values=past_conv_values[i] if use_cache and past_conv_values is not None else None
|
| 408 |
+
)
|
| 409 |
+
new_past_conv_values.append(new_past_conv_values_i)
|
| 410 |
+
x = self.out_proj(x)
|
| 411 |
+
# Output shape: (B, out_dim, T', H', W')
|
| 412 |
+
return x, new_past_conv_values
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
################ Connector Between Vision Encoder and Gaze Model #################
|
| 416 |
+
|
| 417 |
+
class Connector(nn.Module):
|
| 418 |
+
def __init__(self, config: ConnectorConfig):
|
| 419 |
+
super().__init__()
|
| 420 |
+
|
| 421 |
+
self.hidden_dim = config.hidden_dim
|
| 422 |
+
self.num_tokens = config.num_tokens
|
| 423 |
+
|
| 424 |
+
self.pos_embed = nn.Parameter(torch.randn(self.num_tokens, self.hidden_dim))
|
| 425 |
+
|
| 426 |
+
def forward(self, x):
|
| 427 |
+
"""
|
| 428 |
+
x: (B, T, N, C)
|
| 429 |
+
"""
|
| 430 |
+
x = x + self.pos_embed[None, None]
|
| 431 |
+
return x
|
autogaze/models/autogaze/modeling_llama_multi_token_pred.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
import os
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
| 23 |
+
from copy import deepcopy
|
| 24 |
+
from einops import rearrange
|
| 25 |
+
from importlib.metadata import version
|
| 26 |
+
from packaging.version import Version
|
| 27 |
+
import torch
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
from torch import nn
|
| 30 |
+
|
| 31 |
+
from transformers.cache_utils import Cache
|
| 32 |
+
from transformers.generation import GenerationMixin
|
| 33 |
+
from transformers.modeling_outputs import (
|
| 34 |
+
BaseModelOutputWithPast,
|
| 35 |
+
CausalLMOutputWithPast,
|
| 36 |
+
)
|
| 37 |
+
from transformers.utils import (
|
| 38 |
+
can_return_tuple,
|
| 39 |
+
)
|
| 40 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 41 |
+
from transformers.models.llama.modeling_llama import (
|
| 42 |
+
LlamaModel,
|
| 43 |
+
LlamaPreTrainedModel,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
from transformers.cache_utils import (
|
| 47 |
+
Cache,
|
| 48 |
+
)
|
| 49 |
+
from transformers.utils import (
|
| 50 |
+
ModelOutput,
|
| 51 |
+
)
|
| 52 |
+
from transformers.generation.configuration_utils import (
|
| 53 |
+
GenerationConfig,
|
| 54 |
+
)
|
| 55 |
+
from transformers.generation.logits_process import (
|
| 56 |
+
LogitsProcessorList,
|
| 57 |
+
)
|
| 58 |
+
from transformers.generation.stopping_criteria import (
|
| 59 |
+
StoppingCriteriaList,
|
| 60 |
+
)
|
| 61 |
+
from transformers.generation.utils import (
|
| 62 |
+
GenerateNonBeamOutput,
|
| 63 |
+
GenerateDecoderOnlyOutput,
|
| 64 |
+
GenerateEncoderDecoderOutput,
|
| 65 |
+
ALL_CACHE_NAMES,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
if TYPE_CHECKING:
|
| 69 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 70 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 71 |
+
from transformers.generation.streamers import BaseStreamer
|
| 72 |
+
|
| 73 |
+
LOW_TRANSFORMERS_VERSION = Version(version("transformers")) < Version("4.52.0")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
|
| 77 |
+
class CausalLMOutputWithPast(ModelOutput):
|
| 78 |
+
"""
|
| 79 |
+
Base class for causal language model (or autoregressive) outputs.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 83 |
+
Language modeling loss (for next-token prediction).
|
| 84 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 85 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 86 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 87 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 88 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 89 |
+
|
| 90 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 91 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 92 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 93 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 94 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 95 |
+
|
| 96 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 97 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 98 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 99 |
+
sequence_length)`.
|
| 100 |
+
|
| 101 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 102 |
+
heads.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
loss: Optional[torch.FloatTensor] = None
|
| 106 |
+
logits: Optional[torch.FloatTensor] = None
|
| 107 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
| 108 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 109 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 110 |
+
task_loss_prediction: Optional[torch.FloatTensor] = None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class LlamaForCausalLM_MultiTokenPred(LlamaPreTrainedModel, GenerationMixin):
|
| 114 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 115 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 116 |
+
|
| 117 |
+
def __init__(self, config):
|
| 118 |
+
super().__init__(config)
|
| 119 |
+
self.model = LlamaModel(config)
|
| 120 |
+
self.vocab_size = config.vocab_size
|
| 121 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * config.num_multi_token_pred, bias=False)
|
| 122 |
+
self.task_loss_prediction_head = nn.Linear(config.hidden_size, config.num_multi_token_pred, bias=False)
|
| 123 |
+
|
| 124 |
+
# Initialize weights and apply final processing
|
| 125 |
+
self.post_init()
|
| 126 |
+
|
| 127 |
+
def get_input_embeddings(self):
|
| 128 |
+
return self.model.embed_tokens
|
| 129 |
+
|
| 130 |
+
def set_input_embeddings(self, value):
|
| 131 |
+
self.model.embed_tokens = value
|
| 132 |
+
|
| 133 |
+
def get_output_embeddings(self):
|
| 134 |
+
return self.lm_head
|
| 135 |
+
|
| 136 |
+
def set_output_embeddings(self, new_embeddings):
|
| 137 |
+
self.lm_head = new_embeddings
|
| 138 |
+
|
| 139 |
+
def set_decoder(self, decoder):
|
| 140 |
+
self.model = decoder
|
| 141 |
+
|
| 142 |
+
def get_decoder(self):
|
| 143 |
+
return self.model
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@can_return_tuple
|
| 147 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
| 148 |
+
def forward(
|
| 149 |
+
self,
|
| 150 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 151 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 152 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 153 |
+
past_key_values: Optional[Cache] = None,
|
| 154 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 155 |
+
labels: Optional[torch.LongTensor] = None,
|
| 156 |
+
use_cache: Optional[bool] = None,
|
| 157 |
+
output_attentions: Optional[bool] = None,
|
| 158 |
+
output_hidden_states: Optional[bool] = None,
|
| 159 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 160 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 161 |
+
**kwargs,
|
| 162 |
+
) -> CausalLMOutputWithPast:
|
| 163 |
+
"""
|
| 164 |
+
This function is copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.forward.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 168 |
+
output_hidden_states = (
|
| 169 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 173 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 174 |
+
input_ids=input_ids,
|
| 175 |
+
attention_mask=attention_mask,
|
| 176 |
+
position_ids=position_ids,
|
| 177 |
+
past_key_values=past_key_values,
|
| 178 |
+
inputs_embeds=inputs_embeds,
|
| 179 |
+
use_cache=use_cache,
|
| 180 |
+
output_attentions=output_attentions,
|
| 181 |
+
output_hidden_states=output_hidden_states,
|
| 182 |
+
cache_position=cache_position,
|
| 183 |
+
**kwargs,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
hidden_states = outputs.last_hidden_state
|
| 187 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 188 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 189 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 190 |
+
task_loss_prediction = self.task_loss_prediction_head(hidden_states[:, slice_indices, :])
|
| 191 |
+
|
| 192 |
+
loss = None
|
| 193 |
+
if labels is not None:
|
| 194 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 195 |
+
|
| 196 |
+
return CausalLMOutputWithPast(
|
| 197 |
+
loss=loss,
|
| 198 |
+
logits=logits,
|
| 199 |
+
past_key_values=outputs.past_key_values,
|
| 200 |
+
hidden_states=outputs.hidden_states,
|
| 201 |
+
attentions=outputs.attentions,
|
| 202 |
+
task_loss_prediction=task_loss_prediction,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
def _update_model_kwargs_for_generation(
|
| 206 |
+
self,
|
| 207 |
+
outputs: ModelOutput,
|
| 208 |
+
model_kwargs: Dict[str, Any],
|
| 209 |
+
is_encoder_decoder: bool = False,
|
| 210 |
+
num_new_tokens: int = 1,
|
| 211 |
+
) -> Dict[str, Any]:
|
| 212 |
+
# update past_key_values keeping its naming used in model code
|
| 213 |
+
for possible_cache_name in ALL_CACHE_NAMES:
|
| 214 |
+
if possible_cache_name in outputs:
|
| 215 |
+
# TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
|
| 216 |
+
if possible_cache_name in ("past_buckets_states", "mems"):
|
| 217 |
+
cache_name = "past_key_values"
|
| 218 |
+
else:
|
| 219 |
+
cache_name = possible_cache_name
|
| 220 |
+
model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
|
| 221 |
+
break
|
| 222 |
+
|
| 223 |
+
# update token_type_ids with last value
|
| 224 |
+
if "token_type_ids" in model_kwargs:
|
| 225 |
+
token_type_ids = model_kwargs["token_type_ids"]
|
| 226 |
+
assert token_type_ids.dim() == 2
|
| 227 |
+
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1).repeat(1, num_new_tokens)], dim=-1)
|
| 228 |
+
|
| 229 |
+
if not is_encoder_decoder:
|
| 230 |
+
# update attention mask
|
| 231 |
+
if "attention_mask" in model_kwargs:
|
| 232 |
+
attention_mask = model_kwargs["attention_mask"]
|
| 233 |
+
model_kwargs["attention_mask"] = torch.cat(
|
| 234 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
|
| 235 |
+
)
|
| 236 |
+
else:
|
| 237 |
+
# update decoder attention mask
|
| 238 |
+
if "decoder_attention_mask" in model_kwargs:
|
| 239 |
+
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
| 240 |
+
model_kwargs["decoder_attention_mask"] = torch.cat(
|
| 241 |
+
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], num_new_tokens))],
|
| 242 |
+
dim=-1,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# Since we generate multiple tokens at once, the number of new tokens > 1 and all those tokens need to
|
| 246 |
+
# be cached later.
|
| 247 |
+
if model_kwargs.get("use_cache", True):
|
| 248 |
+
model_kwargs["cache_position"] = torch.arange(
|
| 249 |
+
model_kwargs["cache_position"][-1] + 1, model_kwargs["cache_position"][-1] + num_new_tokens + 1, dtype=model_kwargs["cache_position"].dtype
|
| 250 |
+
).to(model_kwargs["cache_position"].device)
|
| 251 |
+
else:
|
| 252 |
+
past_positions = model_kwargs.pop("cache_position")
|
| 253 |
+
new_positions = torch.arange(
|
| 254 |
+
past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
|
| 255 |
+
).to(past_positions.device)
|
| 256 |
+
model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
|
| 257 |
+
return model_kwargs
|
| 258 |
+
|
| 259 |
+
def _sample(
|
| 260 |
+
self,
|
| 261 |
+
input_ids: torch.LongTensor,
|
| 262 |
+
logits_processor: LogitsProcessorList,
|
| 263 |
+
stopping_criteria: StoppingCriteriaList,
|
| 264 |
+
generation_config: GenerationConfig,
|
| 265 |
+
synced_gpus: bool,
|
| 266 |
+
streamer: Optional["BaseStreamer"] = None,
|
| 267 |
+
**model_kwargs,
|
| 268 |
+
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 269 |
+
"""
|
| 270 |
+
This function is copied from transformers.generation.utils.GenerationMixin._sample.
|
| 271 |
+
"""
|
| 272 |
+
# init values
|
| 273 |
+
pad_token_id = generation_config._pad_token_tensor
|
| 274 |
+
output_attentions = generation_config.output_attentions
|
| 275 |
+
output_hidden_states = generation_config.output_hidden_states
|
| 276 |
+
output_scores = generation_config.output_scores
|
| 277 |
+
output_logits = generation_config.output_logits
|
| 278 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 279 |
+
do_sample = generation_config.do_sample
|
| 280 |
+
task_loss_requirement = generation_config.task_loss_requirement
|
| 281 |
+
|
| 282 |
+
# init attention / hidden states / scores tuples
|
| 283 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
| 284 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
| 285 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 286 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 287 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
| 288 |
+
|
| 289 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
| 290 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
| 291 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
| 292 |
+
encoder_hidden_states = (
|
| 293 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# keep track of which sequences are already finished
|
| 297 |
+
batch_size, cur_len = input_ids.shape
|
| 298 |
+
this_peer_finished = False
|
| 299 |
+
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 300 |
+
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) if LOW_TRANSFORMERS_VERSION else self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
| 301 |
+
|
| 302 |
+
model_forward = self.__call__
|
| 303 |
+
if isinstance(model_kwargs.get("past_key_values"), Cache):
|
| 304 |
+
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
|
| 305 |
+
if getattr(self, "hf_quantizer", None) is not None:
|
| 306 |
+
is_compileable &= self.hf_quantizer.is_compileable
|
| 307 |
+
is_compileable = is_compileable and not generation_config.disable_compile
|
| 308 |
+
if is_compileable and (
|
| 309 |
+
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
|
| 310 |
+
):
|
| 311 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
| 312 |
+
model_forward = self.get_compiled_call(generation_config.compile_config)
|
| 313 |
+
|
| 314 |
+
if generation_config.prefill_chunk_size is not None:
|
| 315 |
+
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
|
| 316 |
+
is_prefill = False
|
| 317 |
+
else:
|
| 318 |
+
is_prefill = True
|
| 319 |
+
|
| 320 |
+
is_first_token = True
|
| 321 |
+
final_past_key_values = None
|
| 322 |
+
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
| 323 |
+
# prepare model inputs
|
| 324 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 325 |
+
|
| 326 |
+
# prepare variable output controls (note: some models won't accept all output controls)
|
| 327 |
+
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
| 328 |
+
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
| 329 |
+
|
| 330 |
+
if is_prefill:
|
| 331 |
+
outputs = self(**model_inputs, return_dict=True)
|
| 332 |
+
is_prefill = False
|
| 333 |
+
else:
|
| 334 |
+
outputs = model_forward(**model_inputs, return_dict=True)
|
| 335 |
+
|
| 336 |
+
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
| 337 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
| 338 |
+
outputs,
|
| 339 |
+
model_kwargs,
|
| 340 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
| 341 |
+
num_new_tokens=self.config.num_multi_token_pred,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
if synced_gpus and this_peer_finished:
|
| 345 |
+
continue
|
| 346 |
+
|
| 347 |
+
# Deepcopy the to-return past_key_values, such that it won't change during the dummy model forwards when this peer is finished but peers from other devices are still generating.
|
| 348 |
+
final_past_key_values = deepcopy(model_kwargs.get("past_key_values"))
|
| 349 |
+
|
| 350 |
+
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
| 351 |
+
# (the clone itself is always small)
|
| 352 |
+
next_token_logits_all = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
| 353 |
+
task_loss_prediction_all = outputs.task_loss_prediction[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
| 354 |
+
|
| 355 |
+
# Process all new tokens at once.
|
| 356 |
+
next_token_logits_all = rearrange(next_token_logits_all, 'b (n v) -> b n v', n=self.config.num_multi_token_pred, v=self.config.vocab_size)
|
| 357 |
+
|
| 358 |
+
# pre-process distribution
|
| 359 |
+
next_token_scores_all = logits_processor(input_ids, next_token_logits_all)
|
| 360 |
+
|
| 361 |
+
# token selection
|
| 362 |
+
next_tokens_all = []
|
| 363 |
+
early_stopped = False
|
| 364 |
+
for i in range(self.config.num_multi_token_pred):
|
| 365 |
+
next_token_scores_i = next_token_scores_all[:, i, :]
|
| 366 |
+
|
| 367 |
+
# exit early if meeting max number of token or not token left to choose
|
| 368 |
+
if len(next_tokens_all) + input_ids.shape[1] >= generation_config.max_new_tokens:
|
| 369 |
+
early_stopped = True
|
| 370 |
+
break
|
| 371 |
+
if torch.all(next_token_scores_i == -float("inf")):
|
| 372 |
+
early_stopped = True
|
| 373 |
+
break
|
| 374 |
+
|
| 375 |
+
if do_sample:
|
| 376 |
+
probs = nn.functional.softmax(next_token_scores_i, dim=-1)
|
| 377 |
+
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
| 378 |
+
next_tokens_i = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 379 |
+
else:
|
| 380 |
+
next_tokens_i = torch.argmax(next_token_scores_i, dim=-1)
|
| 381 |
+
|
| 382 |
+
next_tokens_all.append(next_tokens_i)
|
| 383 |
+
|
| 384 |
+
# avoid repeating gazing
|
| 385 |
+
next_token_scores_all[torch.arange(next_tokens_i.shape[0]), i + 1:, next_tokens_i] = -float("inf")
|
| 386 |
+
|
| 387 |
+
next_tokens_all = torch.stack(next_tokens_all, dim=1)
|
| 388 |
+
|
| 389 |
+
# already finished sentences should have their next token be a padding token
|
| 390 |
+
next_tokens_all = next_tokens_all * unfinished_sequences[..., None] + pad_token_id * (1 - unfinished_sequences[..., None])
|
| 391 |
+
|
| 392 |
+
# Mark finished if task loss requirement is met
|
| 393 |
+
meet_task_loss_requirement = torch.zeros_like(next_tokens_all, dtype=torch.bool)
|
| 394 |
+
if task_loss_requirement is not None:
|
| 395 |
+
meet_task_loss_requirement = task_loss_prediction_all[:, :next_tokens_all.shape[1]] <= task_loss_requirement[..., None]
|
| 396 |
+
if is_first_token:
|
| 397 |
+
meet_task_loss_requirement[:, 0] = False
|
| 398 |
+
next_tokens_all = next_tokens_all * (~meet_task_loss_requirement) + pad_token_id * meet_task_loss_requirement
|
| 399 |
+
|
| 400 |
+
# Truncate the next tokens to the max new tokens
|
| 401 |
+
meet_max_new_tokens = False
|
| 402 |
+
if next_tokens_all.shape[1] + input_ids.shape[1] >= generation_config.max_new_tokens:
|
| 403 |
+
next_tokens_all = next_tokens_all[:, :generation_config.max_new_tokens - input_ids.shape[1]]
|
| 404 |
+
meet_max_new_tokens = True
|
| 405 |
+
|
| 406 |
+
# update generated ids, model inputs, and length for next step
|
| 407 |
+
input_ids = torch.cat([input_ids, next_tokens_all], dim=-1)
|
| 408 |
+
if streamer is not None:
|
| 409 |
+
for i in range(next_tokens_all.shape[1]):
|
| 410 |
+
streamer.put(next_tokens_all[:, i].cpu())
|
| 411 |
+
|
| 412 |
+
# Update the finishing flags
|
| 413 |
+
unfinished_sequences = unfinished_sequences & ~torch.any(meet_task_loss_requirement, dim=-1) & ~meet_max_new_tokens & ~early_stopped
|
| 414 |
+
this_peer_finished = unfinished_sequences.max() == 0
|
| 415 |
+
cur_len += 1
|
| 416 |
+
|
| 417 |
+
is_first_token = False
|
| 418 |
+
|
| 419 |
+
# Store scores, attentions and hidden_states when required
|
| 420 |
+
if return_dict_in_generate:
|
| 421 |
+
if output_scores:
|
| 422 |
+
scores += (next_token_scores_all,)
|
| 423 |
+
if output_logits:
|
| 424 |
+
raw_logits += (next_token_logits_all,)
|
| 425 |
+
if output_attentions:
|
| 426 |
+
decoder_attentions += (
|
| 427 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
| 428 |
+
)
|
| 429 |
+
if self.config.is_encoder_decoder:
|
| 430 |
+
cross_attentions += (outputs.cross_attentions,)
|
| 431 |
+
|
| 432 |
+
if output_hidden_states:
|
| 433 |
+
decoder_hidden_states += (
|
| 434 |
+
(outputs.decoder_hidden_states,)
|
| 435 |
+
if self.config.is_encoder_decoder
|
| 436 |
+
else (outputs.hidden_states,)
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
| 442 |
+
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
| 443 |
+
del outputs
|
| 444 |
+
|
| 445 |
+
if streamer is not None:
|
| 446 |
+
streamer.end()
|
| 447 |
+
|
| 448 |
+
if return_dict_in_generate:
|
| 449 |
+
if self.config.is_encoder_decoder:
|
| 450 |
+
return GenerateEncoderDecoderOutput(
|
| 451 |
+
sequences=input_ids,
|
| 452 |
+
scores=scores,
|
| 453 |
+
logits=raw_logits,
|
| 454 |
+
encoder_attentions=encoder_attentions,
|
| 455 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 456 |
+
decoder_attentions=decoder_attentions,
|
| 457 |
+
cross_attentions=cross_attentions,
|
| 458 |
+
decoder_hidden_states=decoder_hidden_states,
|
| 459 |
+
past_key_values=final_past_key_values,
|
| 460 |
+
)
|
| 461 |
+
else:
|
| 462 |
+
return GenerateDecoderOnlyOutput(
|
| 463 |
+
sequences=input_ids,
|
| 464 |
+
scores=scores,
|
| 465 |
+
logits=raw_logits,
|
| 466 |
+
attentions=decoder_attentions,
|
| 467 |
+
hidden_states=decoder_hidden_states,
|
| 468 |
+
past_key_values=final_past_key_values,
|
| 469 |
+
)
|
| 470 |
+
else:
|
| 471 |
+
return input_ids
|
autogaze/tasks/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""AutoGaze tasks."""
|
autogaze/tasks/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (176 Bytes). View file
|
|
|
autogaze/tasks/video_mae_reconstruction/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .task_video_mae_reconstruction import VideoMAEReconstruction
|
autogaze/tasks/video_mae_reconstruction/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (250 Bytes). View file
|
|
|
autogaze/tasks/video_mae_reconstruction/__pycache__/configuration_video_mae.cpython-310.pyc
ADDED
|
Binary file (5.86 kB). View file
|
|
|
autogaze/tasks/video_mae_reconstruction/__pycache__/modeling_video_mae.cpython-310.pyc
ADDED
|
Binary file (43.9 kB). View file
|
|
|
autogaze/tasks/video_mae_reconstruction/__pycache__/task_video_mae_reconstruction.cpython-310.pyc
ADDED
|
Binary file (6.91 kB). View file
|
|
|
autogaze/tasks/video_mae_reconstruction/__pycache__/visualize_video_mae_reconstruction.cpython-310.pyc
ADDED
|
Binary file (3.44 kB). View file
|
|
|
autogaze/tasks/video_mae_reconstruction/configuration_video_mae.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""ViT MAE model configuration"""
|
| 16 |
+
|
| 17 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ViTMAEConfig(PretrainedConfig):
|
| 25 |
+
r"""
|
| 26 |
+
This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate an ViT
|
| 27 |
+
MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration with
|
| 28 |
+
the defaults will yield a similar configuration to that of the ViT
|
| 29 |
+
[facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture.
|
| 30 |
+
|
| 31 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 32 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 37 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 38 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 39 |
+
Number of hidden layers in the Transformer encoder.
|
| 40 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 41 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 42 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 43 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 46 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 47 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 48 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 49 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 50 |
+
The dropout ratio for the attention probabilities.
|
| 51 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 52 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 53 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
| 54 |
+
The epsilon used by the layer normalization layers.
|
| 55 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 56 |
+
The size (resolution) of each image.
|
| 57 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 58 |
+
The size (resolution) of each patch.
|
| 59 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 60 |
+
The number of input channels.
|
| 61 |
+
qkv_bias (`bool`, *optional*, defaults to `True`):
|
| 62 |
+
Whether to add a bias to the queries, keys and values.
|
| 63 |
+
decoder_num_attention_heads (`int`, *optional*, defaults to 16):
|
| 64 |
+
Number of attention heads for each attention layer in the decoder.
|
| 65 |
+
decoder_hidden_size (`int`, *optional*, defaults to 512):
|
| 66 |
+
Dimensionality of the decoder.
|
| 67 |
+
decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
|
| 68 |
+
Number of hidden layers in the decoder.
|
| 69 |
+
decoder_intermediate_size (`int`, *optional*, defaults to 2048):
|
| 70 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
|
| 71 |
+
mask_ratio (`float`, *optional*, defaults to 0.75):
|
| 72 |
+
The ratio of the number of masked tokens in the input sequence.
|
| 73 |
+
norm_pix_loss (`bool`, *optional*, defaults to `False`):
|
| 74 |
+
Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
|
| 75 |
+
representation quality in the experiments of the authors.
|
| 76 |
+
|
| 77 |
+
Example:
|
| 78 |
+
|
| 79 |
+
```python
|
| 80 |
+
>>> from transformers import ViTMAEConfig, ViTMAEModel
|
| 81 |
+
|
| 82 |
+
>>> # Initializing a ViT MAE vit-mae-base style configuration
|
| 83 |
+
>>> configuration = ViTMAEConfig()
|
| 84 |
+
|
| 85 |
+
>>> # Initializing a model (with random weights) from the vit-mae-base style configuration
|
| 86 |
+
>>> model = ViTMAEModel(configuration)
|
| 87 |
+
|
| 88 |
+
>>> # Accessing the model configuration
|
| 89 |
+
>>> configuration = model.config
|
| 90 |
+
```"""
|
| 91 |
+
|
| 92 |
+
model_type = "vit_mae"
|
| 93 |
+
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
hidden_size=768,
|
| 97 |
+
num_hidden_layers=12,
|
| 98 |
+
num_attention_heads=12,
|
| 99 |
+
intermediate_size=3072,
|
| 100 |
+
hidden_act="gelu",
|
| 101 |
+
hidden_dropout_prob=0.0,
|
| 102 |
+
attention_probs_dropout_prob=0.0,
|
| 103 |
+
initializer_range=0.02,
|
| 104 |
+
layer_norm_eps=1e-12,
|
| 105 |
+
image_size=224,
|
| 106 |
+
patch_size=16,
|
| 107 |
+
num_channels=3,
|
| 108 |
+
qkv_bias=True,
|
| 109 |
+
decoder_num_attention_heads=16,
|
| 110 |
+
decoder_hidden_size=512,
|
| 111 |
+
decoder_num_hidden_layers=8,
|
| 112 |
+
decoder_intermediate_size=2048,
|
| 113 |
+
mask_ratio=0.75,
|
| 114 |
+
norm_pix_loss=False,
|
| 115 |
+
scales='224',
|
| 116 |
+
loss_type='l1',
|
| 117 |
+
loss_weights='1',
|
| 118 |
+
l1_loss_config=None,
|
| 119 |
+
dinov2_reg_loss_config=None,
|
| 120 |
+
siglip2_loss_config=None,
|
| 121 |
+
scale_embed=True,
|
| 122 |
+
max_num_frames=256,
|
| 123 |
+
time_embed=True,
|
| 124 |
+
causal=True,
|
| 125 |
+
**kwargs,
|
| 126 |
+
):
|
| 127 |
+
super().__init__(**kwargs)
|
| 128 |
+
|
| 129 |
+
self.hidden_size = hidden_size
|
| 130 |
+
self.num_hidden_layers = num_hidden_layers
|
| 131 |
+
self.num_attention_heads = num_attention_heads
|
| 132 |
+
self.intermediate_size = intermediate_size
|
| 133 |
+
self.hidden_act = hidden_act
|
| 134 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 135 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 136 |
+
self.initializer_range = initializer_range
|
| 137 |
+
self.layer_norm_eps = layer_norm_eps
|
| 138 |
+
self.image_size = image_size
|
| 139 |
+
self.patch_size = patch_size
|
| 140 |
+
self.num_channels = num_channels
|
| 141 |
+
self.qkv_bias = qkv_bias
|
| 142 |
+
self.decoder_num_attention_heads = decoder_num_attention_heads
|
| 143 |
+
self.decoder_hidden_size = decoder_hidden_size
|
| 144 |
+
self.decoder_num_hidden_layers = decoder_num_hidden_layers
|
| 145 |
+
self.decoder_intermediate_size = decoder_intermediate_size
|
| 146 |
+
self.mask_ratio = mask_ratio
|
| 147 |
+
self.norm_pix_loss = norm_pix_loss
|
| 148 |
+
self.scales = scales
|
| 149 |
+
self.loss_type = loss_type
|
| 150 |
+
self.loss_weights = loss_weights
|
| 151 |
+
self.l1_loss_config = l1_loss_config
|
| 152 |
+
self.dinov2_reg_loss_config = dinov2_reg_loss_config
|
| 153 |
+
self.siglip2_loss_config = siglip2_loss_config
|
| 154 |
+
self.scale_embed = scale_embed
|
| 155 |
+
self.max_num_frames = max_num_frames
|
| 156 |
+
self.time_embed = time_embed
|
| 157 |
+
self.causal = causal
|
| 158 |
+
|
| 159 |
+
__all__ = ["ViTMAEConfig"]
|
autogaze/tasks/video_mae_reconstruction/modeling_video_mae.py
ADDED
|
@@ -0,0 +1,1412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ViT MAE (masked autoencoder) model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Callable, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from einops import rearrange
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 30 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 31 |
+
from transformers.pytorch_utils import prune_linear_layer
|
| 32 |
+
from transformers.utils import (
|
| 33 |
+
ModelOutput,
|
| 34 |
+
add_start_docstrings,
|
| 35 |
+
add_start_docstrings_to_model_forward,
|
| 36 |
+
logging,
|
| 37 |
+
replace_return_docstrings,
|
| 38 |
+
torch_int,
|
| 39 |
+
)
|
| 40 |
+
from .configuration_video_mae import ViTMAEConfig
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
|
| 47 |
+
"""
|
| 48 |
+
Finds the heads and their indices taking `already_pruned_heads` into account.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
heads (`Set[int]`): A set of head indices we want to prune.
|
| 52 |
+
n_heads (`int`): The number of heads in the model.
|
| 53 |
+
head_size (`int`): The size of each head.
|
| 54 |
+
already_pruned_heads (`Set[int]`): A set of already pruned heads.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
|
| 58 |
+
"""
|
| 59 |
+
mask = torch.ones(n_heads, head_size)
|
| 60 |
+
heads = set(heads) - already_pruned_heads
|
| 61 |
+
for head in heads:
|
| 62 |
+
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
| 63 |
+
mask[head] = 0
|
| 64 |
+
mask = mask.view(-1).contiguous().eq(1)
|
| 65 |
+
index = torch.arange(len(mask))[mask].long()
|
| 66 |
+
return heads, index
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclass
|
| 70 |
+
class ViTMAEModelOutput(ModelOutput):
|
| 71 |
+
"""
|
| 72 |
+
Class for ViTMAEModel's outputs, with potential hidden states and attentions.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 76 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 77 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 78 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 79 |
+
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 80 |
+
Tensor containing the original index of the (shuffled) masked patches.
|
| 81 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 82 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 83 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 84 |
+
plus the initial embedding outputs.
|
| 85 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 86 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 87 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 88 |
+
the self-attention heads.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 92 |
+
mask: Optional[torch.LongTensor] = None
|
| 93 |
+
ids_restore: Optional[torch.LongTensor] = None
|
| 94 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 95 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class ViTMAEDecoderOutput(ModelOutput):
|
| 100 |
+
"""
|
| 101 |
+
Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
|
| 105 |
+
Pixel reconstruction logits.
|
| 106 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 107 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 108 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 109 |
+
plus the initial embedding outputs.
|
| 110 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 111 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 112 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 113 |
+
the self-attention heads.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
logits: Optional[torch.FloatTensor] = None
|
| 117 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 118 |
+
num_decoded_tokens_each_frame: Optional[torch.LongTensor] = None
|
| 119 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@dataclass
|
| 123 |
+
class ViTMAEForPreTrainingOutput(ModelOutput):
|
| 124 |
+
"""
|
| 125 |
+
Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
loss_each_reconstruction_frame (`torch.FloatTensor` of shape `(batch_size, num_selected_frames)`):
|
| 129 |
+
Pixel reconstruction loss for each reconstruction frame.
|
| 130 |
+
loss_mean (`torch.FloatTensor` of shape `(1,)`):
|
| 131 |
+
Mean of the pixel reconstruction loss for each reconstruction frame.
|
| 132 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
|
| 133 |
+
Pixel reconstruction logits.
|
| 134 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 135 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 136 |
+
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 137 |
+
Tensor containing the original index of the (shuffled) masked patches.
|
| 138 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 139 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 140 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 141 |
+
plus the initial embedding outputs.
|
| 142 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 143 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 144 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 145 |
+
the self-attention heads.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
loss_each_reconstruction_frame: Optional[torch.FloatTensor] = None
|
| 149 |
+
loss_mean: Optional[torch.FloatTensor] = None
|
| 150 |
+
reconstruction: Optional[torch.FloatTensor] = None
|
| 151 |
+
logits: Optional[torch.FloatTensor] = None
|
| 152 |
+
mask: Optional[torch.LongTensor] = None
|
| 153 |
+
ids_restore: Optional[torch.LongTensor] = None
|
| 154 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 155 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
| 159 |
+
"""
|
| 160 |
+
Create 2D sin/cos positional embeddings.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
embed_dim (`int`):
|
| 164 |
+
Embedding dimension.
|
| 165 |
+
grid_size (`int`):
|
| 166 |
+
The grid height and width.
|
| 167 |
+
add_cls_token (`bool`, *optional*, defaults to `False`):
|
| 168 |
+
Whether or not to add a classification (CLS) token.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
(`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
|
| 172 |
+
position embeddings (with or without classification token)
|
| 173 |
+
"""
|
| 174 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 175 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 176 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 177 |
+
grid = np.stack(grid, axis=0)
|
| 178 |
+
|
| 179 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 180 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 181 |
+
if add_cls_token:
|
| 182 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 183 |
+
return pos_embed
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 187 |
+
if embed_dim % 2 != 0:
|
| 188 |
+
raise ValueError("embed_dim must be even")
|
| 189 |
+
|
| 190 |
+
# use half of dimensions to encode grid_h
|
| 191 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 192 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 193 |
+
|
| 194 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 195 |
+
return emb
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 199 |
+
"""
|
| 200 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
| 201 |
+
"""
|
| 202 |
+
if embed_dim % 2 != 0:
|
| 203 |
+
raise ValueError("embed_dim must be even")
|
| 204 |
+
|
| 205 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 206 |
+
omega /= embed_dim / 2.0
|
| 207 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 208 |
+
|
| 209 |
+
pos = pos.reshape(-1) # (M,)
|
| 210 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 211 |
+
|
| 212 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 213 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 214 |
+
|
| 215 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 216 |
+
return emb
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class ViTMAEEmbeddings(nn.Module):
|
| 220 |
+
"""
|
| 221 |
+
Construct the CLS token, position and patch embeddings.
|
| 222 |
+
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
def __init__(self, config):
|
| 226 |
+
super().__init__()
|
| 227 |
+
|
| 228 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 229 |
+
self.patch_embeddings = ViTMAEPatchEmbeddings(config)
|
| 230 |
+
self.num_patches = self.patch_embeddings.num_patches
|
| 231 |
+
# fixed sin-cos embedding
|
| 232 |
+
self.position_embeddings = nn.Parameter(
|
| 233 |
+
torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
|
| 234 |
+
)
|
| 235 |
+
self.patch_size = config.patch_size
|
| 236 |
+
self.config = config
|
| 237 |
+
|
| 238 |
+
# multi-scale setting
|
| 239 |
+
self.scales = sorted([int(scale) for scale in config.scales.split('+')])
|
| 240 |
+
self.num_patch_each_scale = [(scale // config.patch_size)**2 for scale in self.scales]
|
| 241 |
+
if config.scale_embed:
|
| 242 |
+
self.scale_embed = nn.Parameter(torch.randn(len(self.scales), config.hidden_size) * 0)
|
| 243 |
+
|
| 244 |
+
# time embedding
|
| 245 |
+
if config.time_embed:
|
| 246 |
+
self.time_embed = nn.Parameter(torch.randn(config.max_num_frames, config.hidden_size) * 0)
|
| 247 |
+
|
| 248 |
+
def initialize_weights(self):
|
| 249 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
| 250 |
+
pos_embed = get_2d_sincos_pos_embed(
|
| 251 |
+
self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
|
| 252 |
+
)
|
| 253 |
+
self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 254 |
+
|
| 255 |
+
# initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
|
| 256 |
+
w = self.patch_embeddings.projection.weight.data
|
| 257 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 258 |
+
|
| 259 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
| 260 |
+
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
|
| 261 |
+
|
| 262 |
+
# initialize scale embed
|
| 263 |
+
if self.config.scale_embed:
|
| 264 |
+
torch.nn.init.normal_(self.scale_embed, std=self.config.initializer_range)
|
| 265 |
+
|
| 266 |
+
# initialize time embed
|
| 267 |
+
if self.config.time_embed:
|
| 268 |
+
torch.nn.init.normal_(self.time_embed, std=self.config.initializer_range)
|
| 269 |
+
|
| 270 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
| 271 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 272 |
+
"""
|
| 273 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
| 274 |
+
images. This method is also adapted to support torch.jit tracing.
|
| 275 |
+
|
| 276 |
+
Adapted from:
|
| 277 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 278 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
num_patches = embeddings.shape[1] - 1
|
| 282 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 283 |
+
|
| 284 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 285 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 286 |
+
return self.position_embeddings
|
| 287 |
+
|
| 288 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 289 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 290 |
+
|
| 291 |
+
dim = embeddings.shape[-1]
|
| 292 |
+
|
| 293 |
+
new_height = height // self.patch_size
|
| 294 |
+
new_width = width // self.patch_size
|
| 295 |
+
|
| 296 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 297 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 298 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 299 |
+
|
| 300 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 301 |
+
patch_pos_embed,
|
| 302 |
+
size=(new_height, new_width),
|
| 303 |
+
mode="bicubic",
|
| 304 |
+
align_corners=False,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 308 |
+
|
| 309 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 310 |
+
|
| 311 |
+
def mask_with_gazing(self, sequence, gazing_info):
|
| 312 |
+
"""
|
| 313 |
+
Mask the sequence with the gazing information.
|
| 314 |
+
For the padded gazing, we select a dummy token to fill in the positions (the dummy token is currently the first token in each sequence).
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
sequence: The sequence to mask.
|
| 318 |
+
gazing_info:
|
| 319 |
+
gazing_pos: The gazing positions of each whole sequence. (B, N)
|
| 320 |
+
num_gazing_each_frame: The number of gazing positions for each frame, including the padded gazing. (T, )
|
| 321 |
+
if_padded_gazing: Whether the gazing is padded. (B, N)
|
| 322 |
+
"""
|
| 323 |
+
gazing_pos = gazing_info['gazing_pos'].clone()
|
| 324 |
+
num_gazing_each_frame = gazing_info['num_gazing_each_frame'].clone()
|
| 325 |
+
if_padded_gazing = gazing_info['if_padded_gazing'].clone()
|
| 326 |
+
|
| 327 |
+
B, seq_length, dim = sequence.shape
|
| 328 |
+
gaze_length = gazing_pos.shape[1]
|
| 329 |
+
assert gaze_length == num_gazing_each_frame.sum()
|
| 330 |
+
|
| 331 |
+
# Record the original sequence length into gazing_info
|
| 332 |
+
gazing_info['original_seq_length'] = seq_length
|
| 333 |
+
|
| 334 |
+
# Pad the sequence with an additional token for padded gazing to select
|
| 335 |
+
sequence = torch.cat([sequence, sequence[:, :1]], dim=1)
|
| 336 |
+
|
| 337 |
+
# Change all the padded gazing id to the last token id
|
| 338 |
+
gazing_pos = gazing_pos.flatten()
|
| 339 |
+
gazing_pos[if_padded_gazing.flatten()] = seq_length
|
| 340 |
+
gazing_pos = gazing_pos.view(B, -1)
|
| 341 |
+
|
| 342 |
+
# Get the unmasked part of the sequence for MAE encoding
|
| 343 |
+
sequence_unmasked = sequence[torch.arange(B)[..., None], gazing_pos]
|
| 344 |
+
|
| 345 |
+
return sequence_unmasked
|
| 346 |
+
|
| 347 |
+
def forward(self, pixel_values, gazing_info=None, noise=None, interpolate_pos_encoding: bool = False):
|
| 348 |
+
"""
|
| 349 |
+
pixel_values: (B, T, C, H, W)
|
| 350 |
+
"""
|
| 351 |
+
B, T = pixel_values.shape[:2]
|
| 352 |
+
pixel_values = rearrange(pixel_values, 'b t c h w -> (b t) c h w')
|
| 353 |
+
|
| 354 |
+
embeddings = []
|
| 355 |
+
for i, scale in enumerate(self.scales):
|
| 356 |
+
pixel_values_cur_scale = F.interpolate(pixel_values, size=(scale, scale), mode="bicubic", align_corners=False)
|
| 357 |
+
embeddings_cur_scale = self.patch_embeddings(pixel_values_cur_scale, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 358 |
+
if interpolate_pos_encoding:
|
| 359 |
+
position_embeddings_cur_scale = self.interpolate_pos_encoding(embeddings_cur_scale, scale, scale)
|
| 360 |
+
else:
|
| 361 |
+
position_embeddings_cur_scale = self.position_embeddings
|
| 362 |
+
|
| 363 |
+
# add position embeddings w/o cls token
|
| 364 |
+
embeddings_cur_scale = embeddings_cur_scale + position_embeddings_cur_scale[:, 1:, :]
|
| 365 |
+
|
| 366 |
+
# add scale embedding
|
| 367 |
+
if self.config.scale_embed:
|
| 368 |
+
scale_embeddings_cur_scale = self.scale_embed[i][None, None]
|
| 369 |
+
embeddings_cur_scale = embeddings_cur_scale + scale_embeddings_cur_scale
|
| 370 |
+
|
| 371 |
+
embeddings.append(embeddings_cur_scale)
|
| 372 |
+
embeddings = torch.cat(embeddings, dim=1) # (B * T) * N * C
|
| 373 |
+
|
| 374 |
+
# add time embedding
|
| 375 |
+
embeddings = rearrange(embeddings, '(b t) n c -> b t n c', b=B, t=T) # B * T * N * C
|
| 376 |
+
if self.config.time_embed:
|
| 377 |
+
time_embeddings = self.time_embed[None, :T, None, :] # 1 * T * 1 * C
|
| 378 |
+
embeddings = embeddings + time_embeddings
|
| 379 |
+
|
| 380 |
+
embeddings = rearrange(embeddings, 'b t n c -> b (t n) c') # B * (T * N) * C
|
| 381 |
+
|
| 382 |
+
# masking: length -> length * config.mask_ratio
|
| 383 |
+
embeddings = self.mask_with_gazing(embeddings, gazing_info)
|
| 384 |
+
|
| 385 |
+
# append cls token
|
| 386 |
+
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
|
| 387 |
+
cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
|
| 388 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 389 |
+
|
| 390 |
+
return embeddings
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class ViTMAEPatchEmbeddings(nn.Module):
|
| 394 |
+
"""
|
| 395 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 396 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 397 |
+
Transformer.
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
def __init__(self, config):
|
| 401 |
+
super().__init__()
|
| 402 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 403 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 404 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 405 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 406 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 407 |
+
self.image_size = image_size
|
| 408 |
+
self.patch_size = patch_size
|
| 409 |
+
self.num_channels = num_channels
|
| 410 |
+
self.num_patches = num_patches
|
| 411 |
+
|
| 412 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
| 413 |
+
|
| 414 |
+
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
|
| 415 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 416 |
+
if num_channels != self.num_channels:
|
| 417 |
+
raise ValueError(
|
| 418 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
| 424 |
+
)
|
| 425 |
+
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
| 426 |
+
return x
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
| 430 |
+
def eager_attention_forward(
|
| 431 |
+
module: nn.Module,
|
| 432 |
+
query: torch.Tensor,
|
| 433 |
+
key: torch.Tensor,
|
| 434 |
+
value: torch.Tensor,
|
| 435 |
+
attention_mask: Optional[torch.Tensor],
|
| 436 |
+
scaling: float,
|
| 437 |
+
dropout: float = 0.0,
|
| 438 |
+
**kwargs,
|
| 439 |
+
):
|
| 440 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 441 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
| 442 |
+
|
| 443 |
+
# Normalize the attention scores to probabilities.
|
| 444 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 445 |
+
|
| 446 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 447 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 448 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 449 |
+
|
| 450 |
+
# Mask heads if we want to
|
| 451 |
+
if attention_mask is not None:
|
| 452 |
+
attn_weights = attn_weights * attention_mask
|
| 453 |
+
|
| 454 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 455 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 456 |
+
|
| 457 |
+
return attn_output, attn_weights
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
|
| 461 |
+
class ViTMAESelfAttention(nn.Module):
|
| 462 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 463 |
+
super().__init__()
|
| 464 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
| 467 |
+
f"heads {config.num_attention_heads}."
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
self.config = config
|
| 471 |
+
self.num_attention_heads = config.num_attention_heads
|
| 472 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 473 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 474 |
+
self.dropout_prob = config.attention_probs_dropout_prob
|
| 475 |
+
self.scaling = self.attention_head_size**-0.5
|
| 476 |
+
self.is_causal = False
|
| 477 |
+
|
| 478 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 479 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 480 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 481 |
+
|
| 482 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 483 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 484 |
+
x = x.view(new_x_shape)
|
| 485 |
+
return x.permute(0, 2, 1, 3)
|
| 486 |
+
|
| 487 |
+
def forward(
|
| 488 |
+
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
| 489 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 490 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 491 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 492 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
| 493 |
+
|
| 494 |
+
attention_interface: Callable = eager_attention_forward
|
| 495 |
+
if self.config._attn_implementation != "eager":
|
| 496 |
+
if self.config._attn_implementation == "sdpa" and output_attentions:
|
| 497 |
+
logger.warning_once(
|
| 498 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 499 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 500 |
+
)
|
| 501 |
+
assert False, "SDPA doesn't support output_attentions=True. If falling back to eager, please change the attention mask implementation."
|
| 502 |
+
else:
|
| 503 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 504 |
+
|
| 505 |
+
context_layer, attention_probs = attention_interface(
|
| 506 |
+
self,
|
| 507 |
+
query_layer,
|
| 508 |
+
key_layer,
|
| 509 |
+
value_layer,
|
| 510 |
+
head_mask,
|
| 511 |
+
is_causal=self.is_causal,
|
| 512 |
+
scaling=self.scaling,
|
| 513 |
+
dropout=0.0 if not self.training else self.dropout_prob,
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 517 |
+
context_layer = context_layer.reshape(new_context_layer_shape)
|
| 518 |
+
|
| 519 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 520 |
+
|
| 521 |
+
return outputs
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
|
| 525 |
+
class ViTMAESelfOutput(nn.Module):
|
| 526 |
+
"""
|
| 527 |
+
The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
|
| 528 |
+
layernorm applied before each block.
|
| 529 |
+
"""
|
| 530 |
+
|
| 531 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 532 |
+
super().__init__()
|
| 533 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 534 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 535 |
+
|
| 536 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 537 |
+
hidden_states = self.dense(hidden_states)
|
| 538 |
+
hidden_states = self.dropout(hidden_states)
|
| 539 |
+
|
| 540 |
+
return hidden_states
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
|
| 544 |
+
class ViTMAEAttention(nn.Module):
|
| 545 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 546 |
+
super().__init__()
|
| 547 |
+
self.attention = ViTMAESelfAttention(config)
|
| 548 |
+
self.output = ViTMAESelfOutput(config)
|
| 549 |
+
self.pruned_heads = set()
|
| 550 |
+
|
| 551 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
| 552 |
+
if len(heads) == 0:
|
| 553 |
+
return
|
| 554 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 555 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Prune linear layers
|
| 559 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 560 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 561 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 562 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 563 |
+
|
| 564 |
+
# Update hyper params and store pruned heads
|
| 565 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
| 566 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
| 567 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 568 |
+
|
| 569 |
+
def forward(
|
| 570 |
+
self,
|
| 571 |
+
hidden_states: torch.Tensor,
|
| 572 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 573 |
+
output_attentions: bool = False,
|
| 574 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 575 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
| 576 |
+
|
| 577 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 578 |
+
|
| 579 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 580 |
+
return outputs
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
|
| 584 |
+
class ViTMAEIntermediate(nn.Module):
|
| 585 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 586 |
+
super().__init__()
|
| 587 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 588 |
+
if isinstance(config.hidden_act, str):
|
| 589 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 590 |
+
else:
|
| 591 |
+
self.intermediate_act_fn = config.hidden_act
|
| 592 |
+
|
| 593 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 594 |
+
hidden_states = self.dense(hidden_states)
|
| 595 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 596 |
+
|
| 597 |
+
return hidden_states
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
|
| 601 |
+
class ViTMAEOutput(nn.Module):
|
| 602 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 603 |
+
super().__init__()
|
| 604 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 605 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 606 |
+
|
| 607 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 608 |
+
hidden_states = self.dense(hidden_states)
|
| 609 |
+
hidden_states = self.dropout(hidden_states)
|
| 610 |
+
|
| 611 |
+
hidden_states = hidden_states + input_tensor
|
| 612 |
+
|
| 613 |
+
return hidden_states
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
|
| 617 |
+
class ViTMAELayer(nn.Module):
|
| 618 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 619 |
+
|
| 620 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 621 |
+
super().__init__()
|
| 622 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 623 |
+
self.seq_len_dim = 1
|
| 624 |
+
self.attention = ViTMAEAttention(config)
|
| 625 |
+
self.intermediate = ViTMAEIntermediate(config)
|
| 626 |
+
self.output = ViTMAEOutput(config)
|
| 627 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 628 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 629 |
+
|
| 630 |
+
def forward(
|
| 631 |
+
self,
|
| 632 |
+
hidden_states: torch.Tensor,
|
| 633 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 634 |
+
output_attentions: bool = False,
|
| 635 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 636 |
+
self_attention_outputs = self.attention(
|
| 637 |
+
self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention
|
| 638 |
+
head_mask,
|
| 639 |
+
output_attentions=output_attentions,
|
| 640 |
+
)
|
| 641 |
+
attention_output = self_attention_outputs[0]
|
| 642 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 643 |
+
|
| 644 |
+
# first residual connection
|
| 645 |
+
hidden_states = attention_output + hidden_states
|
| 646 |
+
|
| 647 |
+
# in ViTMAE, layernorm is also applied after self-attention
|
| 648 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 649 |
+
layer_output = self.intermediate(layer_output)
|
| 650 |
+
|
| 651 |
+
# second residual connection is done here
|
| 652 |
+
layer_output = self.output(layer_output, hidden_states)
|
| 653 |
+
|
| 654 |
+
outputs = (layer_output,) + outputs
|
| 655 |
+
|
| 656 |
+
return outputs
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
|
| 660 |
+
class ViTMAEEncoder(nn.Module):
|
| 661 |
+
def __init__(self, config: ViTMAEConfig) -> None:
|
| 662 |
+
super().__init__()
|
| 663 |
+
self.config = config
|
| 664 |
+
self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
|
| 665 |
+
self.gradient_checkpointing = False
|
| 666 |
+
|
| 667 |
+
def forward(
|
| 668 |
+
self,
|
| 669 |
+
hidden_states: torch.Tensor,
|
| 670 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 671 |
+
output_attentions: bool = False,
|
| 672 |
+
output_hidden_states: bool = False,
|
| 673 |
+
return_dict: bool = True,
|
| 674 |
+
) -> Union[tuple, BaseModelOutput]:
|
| 675 |
+
all_hidden_states = () if output_hidden_states else None
|
| 676 |
+
all_self_attentions = () if output_attentions else None
|
| 677 |
+
|
| 678 |
+
for i, layer_module in enumerate(self.layer):
|
| 679 |
+
if output_hidden_states:
|
| 680 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 681 |
+
|
| 682 |
+
if self.gradient_checkpointing and self.training:
|
| 683 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 684 |
+
layer_module.__call__,
|
| 685 |
+
hidden_states,
|
| 686 |
+
head_mask,
|
| 687 |
+
output_attentions,
|
| 688 |
+
)
|
| 689 |
+
else:
|
| 690 |
+
layer_outputs = layer_module(hidden_states, head_mask, output_attentions)
|
| 691 |
+
|
| 692 |
+
hidden_states = layer_outputs[0]
|
| 693 |
+
|
| 694 |
+
if output_attentions:
|
| 695 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 696 |
+
|
| 697 |
+
if output_hidden_states:
|
| 698 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 699 |
+
|
| 700 |
+
if not return_dict:
|
| 701 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 702 |
+
return BaseModelOutput(
|
| 703 |
+
last_hidden_state=hidden_states,
|
| 704 |
+
hidden_states=all_hidden_states,
|
| 705 |
+
attentions=all_self_attentions,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
class ViTMAEPreTrainedModel(PreTrainedModel):
|
| 710 |
+
"""
|
| 711 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 712 |
+
models.
|
| 713 |
+
"""
|
| 714 |
+
|
| 715 |
+
config_class = ViTMAEConfig
|
| 716 |
+
base_model_prefix = "vit"
|
| 717 |
+
main_input_name = "pixel_values"
|
| 718 |
+
supports_gradient_checkpointing = True
|
| 719 |
+
_supports_sdpa = True
|
| 720 |
+
_supports_flash_attn_2 = True
|
| 721 |
+
_supports_flex_attn = True
|
| 722 |
+
|
| 723 |
+
def _init_weights(self, module):
|
| 724 |
+
"""Initialize the weights"""
|
| 725 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 726 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 727 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 728 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 729 |
+
if module.bias is not None:
|
| 730 |
+
module.bias.data.zero_()
|
| 731 |
+
elif isinstance(module, nn.LayerNorm):
|
| 732 |
+
module.bias.data.zero_()
|
| 733 |
+
module.weight.data.fill_(1.0)
|
| 734 |
+
elif isinstance(module, ViTMAEEmbeddings):
|
| 735 |
+
module.initialize_weights()
|
| 736 |
+
elif isinstance(module, ViTMAEDecoder):
|
| 737 |
+
module.mask_token.data.zero_()
|
| 738 |
+
module.decoder_pos_embed.data.zero_()
|
| 739 |
+
if self.config.scale_embed:
|
| 740 |
+
torch.nn.init.normal_(module.decoder_scale_embed, std=self.config.initializer_range)
|
| 741 |
+
if self.config.time_embed:
|
| 742 |
+
torch.nn.init.normal_(module.time_embed, std=self.config.initializer_range)
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
class ViTMAEModel(ViTMAEPreTrainedModel):
|
| 746 |
+
def __init__(self, config):
|
| 747 |
+
super().__init__(config)
|
| 748 |
+
self.config = config
|
| 749 |
+
|
| 750 |
+
self.embeddings = ViTMAEEmbeddings(config)
|
| 751 |
+
self.encoder = ViTMAEEncoder(config)
|
| 752 |
+
|
| 753 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 754 |
+
|
| 755 |
+
# Initialize weights and apply final processing
|
| 756 |
+
self.post_init()
|
| 757 |
+
|
| 758 |
+
def get_input_embeddings(self):
|
| 759 |
+
return self.embeddings.patch_embeddings
|
| 760 |
+
|
| 761 |
+
def _prune_heads(self, heads_to_prune):
|
| 762 |
+
"""
|
| 763 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 764 |
+
class PreTrainedModel
|
| 765 |
+
"""
|
| 766 |
+
for layer, heads in heads_to_prune.items():
|
| 767 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 768 |
+
|
| 769 |
+
def forward(
|
| 770 |
+
self,
|
| 771 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 772 |
+
gazing_info: Optional[dict] = None,
|
| 773 |
+
noise: Optional[torch.FloatTensor] = None,
|
| 774 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 775 |
+
output_attentions: Optional[bool] = None,
|
| 776 |
+
output_hidden_states: Optional[bool] = None,
|
| 777 |
+
return_dict: Optional[bool] = None,
|
| 778 |
+
interpolate_pos_encoding: bool = False,
|
| 779 |
+
) -> Union[Tuple, ViTMAEModelOutput]:
|
| 780 |
+
|
| 781 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 782 |
+
output_hidden_states = (
|
| 783 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 784 |
+
)
|
| 785 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 786 |
+
|
| 787 |
+
if pixel_values is None:
|
| 788 |
+
raise ValueError("You have to specify pixel_values")
|
| 789 |
+
|
| 790 |
+
# Prepare head mask if needed
|
| 791 |
+
# 1.0 in head_mask indicate we keep the head
|
| 792 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 793 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 794 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 795 |
+
head_mask = head_mask.to(self.dtype)
|
| 796 |
+
|
| 797 |
+
embedding_output = self.embeddings(
|
| 798 |
+
pixel_values, gazing_info=gazing_info, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
encoder_outputs = self.encoder(
|
| 802 |
+
embedding_output,
|
| 803 |
+
head_mask=head_mask,
|
| 804 |
+
output_attentions=output_attentions,
|
| 805 |
+
output_hidden_states=output_hidden_states,
|
| 806 |
+
return_dict=return_dict,
|
| 807 |
+
)
|
| 808 |
+
sequence_output = encoder_outputs[0]
|
| 809 |
+
sequence_output = self.layernorm(sequence_output)
|
| 810 |
+
|
| 811 |
+
if not return_dict:
|
| 812 |
+
return sequence_output + encoder_outputs[1:]
|
| 813 |
+
|
| 814 |
+
return ViTMAEModelOutput(
|
| 815 |
+
last_hidden_state=sequence_output,
|
| 816 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 817 |
+
attentions=encoder_outputs.attentions,
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
class ViTMAEDecoder(ViTMAEPreTrainedModel):
|
| 822 |
+
def __init__(self, config, num_patches):
|
| 823 |
+
super().__init__(config)
|
| 824 |
+
self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
|
| 825 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
|
| 826 |
+
self.decoder_pos_embed = nn.Parameter(
|
| 827 |
+
torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
|
| 828 |
+
) # fixed sin-cos embedding
|
| 829 |
+
|
| 830 |
+
decoder_config = deepcopy(config)
|
| 831 |
+
decoder_config.hidden_size = config.decoder_hidden_size
|
| 832 |
+
decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
|
| 833 |
+
decoder_config.num_attention_heads = config.decoder_num_attention_heads
|
| 834 |
+
decoder_config.intermediate_size = config.decoder_intermediate_size
|
| 835 |
+
self.decoder_layers = nn.ModuleList(
|
| 836 |
+
[ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
|
| 840 |
+
self.decoder_pred = nn.Linear(
|
| 841 |
+
config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
|
| 842 |
+
) # encoder to decoder
|
| 843 |
+
self.gradient_checkpointing = False
|
| 844 |
+
self.config = config
|
| 845 |
+
|
| 846 |
+
# multi-scale setting
|
| 847 |
+
self.scales = sorted([int(scale) for scale in config.scales.split('+')])
|
| 848 |
+
self.num_patch_each_frame_each_scale = [(scale // config.patch_size)**2 for scale in self.scales]
|
| 849 |
+
if self.config.scale_embed:
|
| 850 |
+
self.decoder_scale_embed = nn.Parameter(torch.randn(len(self.scales), config.decoder_hidden_size) * 0)
|
| 851 |
+
|
| 852 |
+
# time embed
|
| 853 |
+
if self.config.time_embed:
|
| 854 |
+
self.time_embed = nn.Parameter(torch.randn(config.max_num_frames, config.decoder_hidden_size) * 0)
|
| 855 |
+
|
| 856 |
+
self.num_token_each_frame = sum(self.num_patch_each_frame_each_scale)
|
| 857 |
+
|
| 858 |
+
self.initialize_weights(num_patches)
|
| 859 |
+
|
| 860 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
|
| 861 |
+
"""
|
| 862 |
+
This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
|
| 863 |
+
allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
|
| 864 |
+
resolution images.
|
| 865 |
+
|
| 866 |
+
Adapted from:
|
| 867 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
| 868 |
+
"""
|
| 869 |
+
|
| 870 |
+
# -1 removes the class dimension since we later append it without interpolation
|
| 871 |
+
embeddings_positions = embeddings.shape[1] - 1
|
| 872 |
+
|
| 873 |
+
# Separation of class token and patch tokens
|
| 874 |
+
class_pos_embed = self.decoder_pos_embed[:, :1]
|
| 875 |
+
patch_pos_embed = self.decoder_pos_embed[:, 1:]
|
| 876 |
+
|
| 877 |
+
# To retain the final 3d tensor with the required dimensions
|
| 878 |
+
dim = self.decoder_pos_embed.shape[-1]
|
| 879 |
+
|
| 880 |
+
# Increasing a dimension to enable bicubic interpolation
|
| 881 |
+
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
|
| 882 |
+
|
| 883 |
+
# permute to bring the dimension to be interpolated, to the last
|
| 884 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 885 |
+
|
| 886 |
+
# Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
|
| 887 |
+
# we keep the second last dimension constant
|
| 888 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 889 |
+
patch_pos_embed,
|
| 890 |
+
size=(patch_pos_embed.shape[-2], embeddings_positions),
|
| 891 |
+
mode="bicubic",
|
| 892 |
+
align_corners=False,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
# Converting back to the original shape
|
| 896 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 897 |
+
# Adding the class token back
|
| 898 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 899 |
+
|
| 900 |
+
def initialize_weights(self, num_patches):
|
| 901 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
| 902 |
+
decoder_pos_embed = get_2d_sincos_pos_embed(
|
| 903 |
+
self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
|
| 904 |
+
)
|
| 905 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
| 906 |
+
|
| 907 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
| 908 |
+
torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)
|
| 909 |
+
|
| 910 |
+
def forward(
|
| 911 |
+
self,
|
| 912 |
+
hidden_states,
|
| 913 |
+
gazing_info=None,
|
| 914 |
+
frame_idx_to_reconstruct=None,
|
| 915 |
+
head_mask=None,
|
| 916 |
+
output_attentions=False,
|
| 917 |
+
output_hidden_states=False,
|
| 918 |
+
return_dict=True,
|
| 919 |
+
interpolate_pos_encoding: bool = False,
|
| 920 |
+
):
|
| 921 |
+
gazing_pos = gazing_info['gazing_pos']
|
| 922 |
+
num_gazing_each_frame = gazing_info['num_gazing_each_frame']
|
| 923 |
+
if_padded_gazing = gazing_info['if_padded_gazing']
|
| 924 |
+
original_seq_length = gazing_info['original_seq_length']
|
| 925 |
+
|
| 926 |
+
B = hidden_states.shape[0]
|
| 927 |
+
gaze_length = gazing_pos.shape[1]
|
| 928 |
+
assert gaze_length == num_gazing_each_frame.sum()
|
| 929 |
+
T = len(num_gazing_each_frame)
|
| 930 |
+
original_seq_length_each_frame = original_seq_length // T
|
| 931 |
+
|
| 932 |
+
# embed tokens
|
| 933 |
+
x = self.decoder_embed(hidden_states)
|
| 934 |
+
|
| 935 |
+
# Take out cls token
|
| 936 |
+
x_ = x[:, 1:, :]
|
| 937 |
+
cls_token = x[:, :1, :]
|
| 938 |
+
|
| 939 |
+
# Change all the padded gazing id to the last token id
|
| 940 |
+
gazing_pos = gazing_pos.flatten()
|
| 941 |
+
gazing_pos[if_padded_gazing.flatten()] = original_seq_length
|
| 942 |
+
gazing_pos = gazing_pos.view(B, -1)
|
| 943 |
+
|
| 944 |
+
# add mask tokens back to the sequence (temporarily append an additional token for padded gazing to select)
|
| 945 |
+
full_seq = self.mask_token.repeat(x.shape[0], original_seq_length + 1, 1).to(x.dtype)
|
| 946 |
+
full_seq[torch.arange(B)[..., None], gazing_pos] = x_
|
| 947 |
+
full_seq = full_seq[:, :-1, :]
|
| 948 |
+
|
| 949 |
+
# add pos embed and scale embed
|
| 950 |
+
full_seq = rearrange(full_seq, 'b (t n) c -> (b t) n c', t=T)
|
| 951 |
+
decoder_pos_embed = []
|
| 952 |
+
decoder_scale_embed = []
|
| 953 |
+
for i, scale in enumerate(self.scales):
|
| 954 |
+
x_cur_scale = full_seq[:, sum(self.num_patch_each_frame_each_scale[:i]):sum(self.num_patch_each_frame_each_scale[:i+1])]
|
| 955 |
+
if interpolate_pos_encoding:
|
| 956 |
+
decoder_pos_embed_cur_scale = self.interpolate_pos_encoding(F.pad(x_cur_scale, (0, 0, 1, 0)))[:, 1:]
|
| 957 |
+
else:
|
| 958 |
+
decoder_pos_embed_cur_scale = self.decoder_pos_embed
|
| 959 |
+
decoder_pos_embed.append(decoder_pos_embed_cur_scale)
|
| 960 |
+
if self.config.scale_embed:
|
| 961 |
+
decoder_scale_embed.append(self.decoder_scale_embed[i][None, None].repeat(1, decoder_pos_embed_cur_scale.shape[1], 1))
|
| 962 |
+
decoder_pos_embed = torch.cat(decoder_pos_embed, dim=1)
|
| 963 |
+
decoder_scale_embed = torch.cat(decoder_scale_embed, dim=1) if self.config.scale_embed else 0
|
| 964 |
+
full_seq = full_seq + decoder_pos_embed + decoder_scale_embed
|
| 965 |
+
full_seq = rearrange(full_seq, '(b t) n c -> b (t n) c', t=T)
|
| 966 |
+
|
| 967 |
+
# add time embed
|
| 968 |
+
if self.config.time_embed:
|
| 969 |
+
time_embed = self.time_embed[None, :T, None, :]
|
| 970 |
+
full_seq = rearrange(full_seq, 'b (t n) c -> b t n c', t=T)
|
| 971 |
+
full_seq = full_seq + time_embed
|
| 972 |
+
full_seq = rearrange(full_seq, 'b t n c -> b (t n) c', t=T)
|
| 973 |
+
|
| 974 |
+
# Get the index of tokens to feed into decoder (encoded tokens + mask tokens for selected frames)
|
| 975 |
+
idx_to_decode = gazing_pos.clone()
|
| 976 |
+
idx_to_decode = list(idx_to_decode.split(num_gazing_each_frame.tolist(), dim=-1))
|
| 977 |
+
for frame_idx in frame_idx_to_reconstruct:
|
| 978 |
+
idx_to_decode[frame_idx] = torch.arange(original_seq_length_each_frame, device=gazing_pos.device)[None].repeat(B, 1) + original_seq_length_each_frame * frame_idx
|
| 979 |
+
idx_to_decode = torch.cat(idx_to_decode, dim=-1)
|
| 980 |
+
|
| 981 |
+
# Get the tokens to decode
|
| 982 |
+
full_seq = torch.cat([full_seq, full_seq[:, :1]], dim=1)
|
| 983 |
+
hidden_states = full_seq[torch.arange(B)[..., None], idx_to_decode]
|
| 984 |
+
|
| 985 |
+
# add cls token
|
| 986 |
+
cls_token = cls_token + self.decoder_pos_embed[:, :1]
|
| 987 |
+
hidden_states = torch.cat([cls_token, hidden_states], dim=1)
|
| 988 |
+
|
| 989 |
+
# apply Transformer layers (blocks)
|
| 990 |
+
head_mask = head_mask.to(self.dtype)
|
| 991 |
+
all_hidden_states = () if output_hidden_states else None
|
| 992 |
+
all_self_attentions = () if output_attentions else None
|
| 993 |
+
for i, layer_module in enumerate(self.decoder_layers):
|
| 994 |
+
if output_hidden_states:
|
| 995 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 996 |
+
|
| 997 |
+
if self.gradient_checkpointing and self.training:
|
| 998 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 999 |
+
layer_module.__call__,
|
| 1000 |
+
hidden_states,
|
| 1001 |
+
head_mask,
|
| 1002 |
+
output_attentions,
|
| 1003 |
+
)
|
| 1004 |
+
else:
|
| 1005 |
+
layer_outputs = layer_module(hidden_states, head_mask=head_mask, output_attentions=output_attentions)
|
| 1006 |
+
|
| 1007 |
+
hidden_states = layer_outputs[0]
|
| 1008 |
+
|
| 1009 |
+
if output_attentions:
|
| 1010 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 1011 |
+
|
| 1012 |
+
if output_hidden_states:
|
| 1013 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1014 |
+
|
| 1015 |
+
hidden_states = self.decoder_norm(hidden_states)
|
| 1016 |
+
|
| 1017 |
+
# predictor projection
|
| 1018 |
+
logits = self.decoder_pred(hidden_states)
|
| 1019 |
+
|
| 1020 |
+
# remove cls token
|
| 1021 |
+
logits = logits[:, 1:, :]
|
| 1022 |
+
|
| 1023 |
+
if not return_dict:
|
| 1024 |
+
return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
|
| 1025 |
+
return ViTMAEDecoderOutput(
|
| 1026 |
+
logits=logits,
|
| 1027 |
+
hidden_states=all_hidden_states,
|
| 1028 |
+
attentions=all_self_attentions,
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
|
| 1033 |
+
def __init__(self, config):
|
| 1034 |
+
super().__init__(config)
|
| 1035 |
+
self.config = config
|
| 1036 |
+
|
| 1037 |
+
self.vit = ViTMAEModel(config)
|
| 1038 |
+
self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)
|
| 1039 |
+
|
| 1040 |
+
# multi-scale setting
|
| 1041 |
+
self.scales = sorted([int(scale) for scale in config.scales.split('+')])
|
| 1042 |
+
self.num_patch_each_scale = [(scale // config.patch_size)**2 for scale in self.scales]
|
| 1043 |
+
self.num_token_each_frame = sum(self.num_patch_each_scale)
|
| 1044 |
+
|
| 1045 |
+
# loss setting
|
| 1046 |
+
self.loss_type = [str(loss) for loss in config.loss_type.split('+')]
|
| 1047 |
+
self.loss_weights = [float(weight) for weight in config.loss_weights.split('+')]
|
| 1048 |
+
self.transform = None # will be initialized in the outer
|
| 1049 |
+
self.loss_fns = []
|
| 1050 |
+
for loss in self.loss_type:
|
| 1051 |
+
if loss == 'l1':
|
| 1052 |
+
self.loss_fns.append(self.l1_loss)
|
| 1053 |
+
elif loss == 'dinov2_reg':
|
| 1054 |
+
self.dinov2_reg = None # will be initialized in the outer
|
| 1055 |
+
self.dinov2_reg_transform = None # will be initialized in the outer
|
| 1056 |
+
self.loss_fns.append(self.dinov2_reg_loss)
|
| 1057 |
+
elif loss == 'siglip2':
|
| 1058 |
+
self.siglip2 = None # will be initialized in the outer
|
| 1059 |
+
self.siglip2_transform = None # will be initialized in the outer
|
| 1060 |
+
self.loss_fns.append(self.siglip2_loss)
|
| 1061 |
+
else:
|
| 1062 |
+
raise ValueError(f"Loss type {loss} not supported")
|
| 1063 |
+
|
| 1064 |
+
# Initialize weights and apply final processing
|
| 1065 |
+
self.post_init()
|
| 1066 |
+
|
| 1067 |
+
def get_input_embeddings(self):
|
| 1068 |
+
return self.vit.embeddings.patch_embeddings
|
| 1069 |
+
|
| 1070 |
+
def _prune_heads(self, heads_to_prune):
|
| 1071 |
+
"""
|
| 1072 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 1073 |
+
class PreTrainedModel
|
| 1074 |
+
"""
|
| 1075 |
+
for layer, heads in heads_to_prune.items():
|
| 1076 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 1077 |
+
|
| 1078 |
+
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
|
| 1079 |
+
"""
|
| 1080 |
+
Args:
|
| 1081 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 1082 |
+
Pixel values.
|
| 1083 |
+
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
| 1084 |
+
interpolation flag passed during the forward pass.
|
| 1085 |
+
|
| 1086 |
+
Returns:
|
| 1087 |
+
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
| 1088 |
+
Patchified pixel values.
|
| 1089 |
+
"""
|
| 1090 |
+
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
| 1091 |
+
# sanity checks
|
| 1092 |
+
if not interpolate_pos_encoding and (
|
| 1093 |
+
pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
|
| 1094 |
+
):
|
| 1095 |
+
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
|
| 1096 |
+
if pixel_values.shape[1] != num_channels:
|
| 1097 |
+
raise ValueError(
|
| 1098 |
+
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
# patchify
|
| 1102 |
+
batch_size = pixel_values.shape[0]
|
| 1103 |
+
num_patches_h = pixel_values.shape[2] // patch_size
|
| 1104 |
+
num_patches_w = pixel_values.shape[3] // patch_size
|
| 1105 |
+
patchified_pixel_values = pixel_values.reshape(
|
| 1106 |
+
batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
|
| 1107 |
+
)
|
| 1108 |
+
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
|
| 1109 |
+
patchified_pixel_values = patchified_pixel_values.reshape(
|
| 1110 |
+
batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
|
| 1111 |
+
)
|
| 1112 |
+
return patchified_pixel_values
|
| 1113 |
+
|
| 1114 |
+
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
|
| 1115 |
+
"""
|
| 1116 |
+
Args:
|
| 1117 |
+
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
| 1118 |
+
Patchified pixel values.
|
| 1119 |
+
original_image_size (`Tuple[int, int]`, *optional*):
|
| 1120 |
+
Original image size.
|
| 1121 |
+
|
| 1122 |
+
Returns:
|
| 1123 |
+
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
| 1124 |
+
Pixel values.
|
| 1125 |
+
"""
|
| 1126 |
+
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
| 1127 |
+
original_image_size = (
|
| 1128 |
+
original_image_size
|
| 1129 |
+
if original_image_size is not None
|
| 1130 |
+
else (self.config.image_size, self.config.image_size)
|
| 1131 |
+
)
|
| 1132 |
+
original_height, original_width = original_image_size
|
| 1133 |
+
num_patches_h = original_height // patch_size
|
| 1134 |
+
num_patches_w = original_width // patch_size
|
| 1135 |
+
# sanity check
|
| 1136 |
+
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
|
| 1137 |
+
raise ValueError(
|
| 1138 |
+
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
# unpatchify
|
| 1142 |
+
batch_size = patchified_pixel_values.shape[0]
|
| 1143 |
+
patchified_pixel_values = patchified_pixel_values.reshape(
|
| 1144 |
+
batch_size,
|
| 1145 |
+
num_patches_h,
|
| 1146 |
+
num_patches_w,
|
| 1147 |
+
patch_size,
|
| 1148 |
+
patch_size,
|
| 1149 |
+
num_channels,
|
| 1150 |
+
)
|
| 1151 |
+
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
|
| 1152 |
+
pixel_values = patchified_pixel_values.reshape(
|
| 1153 |
+
batch_size,
|
| 1154 |
+
num_channels,
|
| 1155 |
+
num_patches_h * patch_size,
|
| 1156 |
+
num_patches_w * patch_size,
|
| 1157 |
+
)
|
| 1158 |
+
return pixel_values
|
| 1159 |
+
|
| 1160 |
+
def retransform(self, image, source_transform, target_transform):
|
| 1161 |
+
# Revert the source transform
|
| 1162 |
+
image = rearrange(image, 'b c h w -> b h w c')
|
| 1163 |
+
if source_transform.do_normalize:
|
| 1164 |
+
image = image * torch.tensor(source_transform.image_std, device=image.device, dtype=image.dtype) + torch.tensor(source_transform.image_mean, device=image.device, dtype=image.dtype)
|
| 1165 |
+
if source_transform.do_rescale:
|
| 1166 |
+
if hasattr(source_transform, 'offset') and source_transform.offset:
|
| 1167 |
+
image = image + 1
|
| 1168 |
+
image = image / source_transform.rescale_factor
|
| 1169 |
+
image = rearrange(image, 'b h w c -> b c h w')
|
| 1170 |
+
|
| 1171 |
+
# Apply the target transform
|
| 1172 |
+
image = rearrange(image, 'b c h w -> b h w c')
|
| 1173 |
+
if target_transform.do_rescale:
|
| 1174 |
+
image = image * target_transform.rescale_factor
|
| 1175 |
+
if hasattr(target_transform, 'offset') and target_transform.offset:
|
| 1176 |
+
image = image - 1
|
| 1177 |
+
if target_transform.do_normalize:
|
| 1178 |
+
image = (image - torch.tensor(target_transform.image_mean, device=image.device, dtype=image.dtype)) / torch.tensor(target_transform.image_std, device=image.device, dtype=image.dtype)
|
| 1179 |
+
image = rearrange(image, 'b h w c -> b c h w')
|
| 1180 |
+
|
| 1181 |
+
return image
|
| 1182 |
+
|
| 1183 |
+
def l1_loss(self, pred, target):
|
| 1184 |
+
"""
|
| 1185 |
+
pred, target: (B, C, H, W)
|
| 1186 |
+
"""
|
| 1187 |
+
return (pred - target).abs().mean(dim=(-1, -2, -3))
|
| 1188 |
+
|
| 1189 |
+
def dinov2_reg_loss(self, pred, target):
|
| 1190 |
+
"""
|
| 1191 |
+
pred, target: (B, C, H, W)
|
| 1192 |
+
"""
|
| 1193 |
+
def get_dinov2_reg_features(image):
|
| 1194 |
+
image = self.retransform(image, self.transform, self.dinov2_reg_transform)
|
| 1195 |
+
features = self.dinov2_reg(image, output_hidden_states=True).hidden_states
|
| 1196 |
+
features = torch.cat([feature[:, self.dinov2_reg.config.num_register_tokens + 1:] for feature in features[-4:]], dim=-1)
|
| 1197 |
+
return features
|
| 1198 |
+
|
| 1199 |
+
pred_features = get_dinov2_reg_features(pred)
|
| 1200 |
+
target_features = get_dinov2_reg_features(target)
|
| 1201 |
+
|
| 1202 |
+
# Get average l2 loss over last k layers' features
|
| 1203 |
+
loss = (pred_features - target_features).pow(2).mean(dim=(-1, -2))
|
| 1204 |
+
|
| 1205 |
+
return loss
|
| 1206 |
+
|
| 1207 |
+
def siglip2_loss(self, pred, target):
|
| 1208 |
+
"""
|
| 1209 |
+
pred, target: (B, C, H, W)
|
| 1210 |
+
"""
|
| 1211 |
+
def get_siglip2_features(image):
|
| 1212 |
+
image = self.retransform(image, self.transform, self.siglip2_transform)
|
| 1213 |
+
features = self.siglip2(image, output_hidden_states=True).hidden_states
|
| 1214 |
+
features = torch.cat(features[-4:], dim=-1)
|
| 1215 |
+
return features
|
| 1216 |
+
|
| 1217 |
+
pred_features = get_siglip2_features(pred)
|
| 1218 |
+
target_features = get_siglip2_features(target)
|
| 1219 |
+
|
| 1220 |
+
# Get average l2 loss over last k layers' features
|
| 1221 |
+
loss = (pred_features - target_features).pow(2).mean(dim=(-1, -2))
|
| 1222 |
+
|
| 1223 |
+
return loss
|
| 1224 |
+
|
| 1225 |
+
def forward_loss(self, pixel_values, pred, interpolate_pos_encoding: bool = False):
|
| 1226 |
+
"""
|
| 1227 |
+
Args:
|
| 1228 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, T, num_channels, height, width)`):
|
| 1229 |
+
Pixel values.
|
| 1230 |
+
pred (`torch.FloatTensor` of shape `(batch_size, T, num_patches, patch_size**2 * num_channels)`:
|
| 1231 |
+
Predicted pixel values.
|
| 1232 |
+
interpolate_pos_encoding (`bool`, *optional*, default `False`):
|
| 1233 |
+
interpolation flag passed during the forward pass.
|
| 1234 |
+
|
| 1235 |
+
Returns:
|
| 1236 |
+
`torch.FloatTensor`: Pixel reconstruction loss.
|
| 1237 |
+
"""
|
| 1238 |
+
B, T = pixel_values.shape[:2]
|
| 1239 |
+
pixel_values = pixel_values.flatten(0, 1) # (B * T), C, H, W
|
| 1240 |
+
pred = pred.flatten(0, 1) # (B * T), N, C
|
| 1241 |
+
|
| 1242 |
+
pred = self.unpatchify(pred, original_image_size=(pixel_values.shape[2], pixel_values.shape[3]))
|
| 1243 |
+
|
| 1244 |
+
loss = 0
|
| 1245 |
+
for loss_fn, loss_weight in zip(self.loss_fns, self.loss_weights):
|
| 1246 |
+
loss += loss_weight * loss_fn(pred, pixel_values)
|
| 1247 |
+
|
| 1248 |
+
loss = rearrange(loss, '(b t) -> b t', b=B, t=T)
|
| 1249 |
+
mean_loss = loss.mean(dim=-1)
|
| 1250 |
+
|
| 1251 |
+
return loss, mean_loss
|
| 1252 |
+
|
| 1253 |
+
def get_reconstructed_image(self, pixel_values, pred, interpolate_pos_encoding: bool = False):
|
| 1254 |
+
"""
|
| 1255 |
+
pixel_values: (B, T, C, H, W)
|
| 1256 |
+
pred: (B, T, N, C)
|
| 1257 |
+
"""
|
| 1258 |
+
B, T = pixel_values.shape[:2]
|
| 1259 |
+
pixel_values = pixel_values.flatten(0, 1) # (B * T), C, H, W
|
| 1260 |
+
pred = pred.flatten(0, 1) # (B * T), N, C
|
| 1261 |
+
|
| 1262 |
+
pred = self.unpatchify(pred, original_image_size=(pixel_values.shape[2], pixel_values.shape[3]))
|
| 1263 |
+
|
| 1264 |
+
pred = rearrange(pred, '(b t) c h w -> b t c h w', b=B, t=T)
|
| 1265 |
+
|
| 1266 |
+
return pred
|
| 1267 |
+
|
| 1268 |
+
def get_causal_mask(self, num_tokens_each_frame, num_layers, batch_size, num_heads, token_mask=None, cls_token=True):
|
| 1269 |
+
"""
|
| 1270 |
+
Assume a input of shape B * N * C, where N contains tokens from several frames.
|
| 1271 |
+
Each frame has num_tokens_each_frame[t] tokens.
|
| 1272 |
+
Create a block-causal attention mask such that each token can only attend to tokens from either previous frames or the same frame.
|
| 1273 |
+
Additionally, mask any tokens indicated by token_mask (e.g., the tokens at padded gazing positions)
|
| 1274 |
+
|
| 1275 |
+
Inputs:
|
| 1276 |
+
num_tokens_each_frame: (T)
|
| 1277 |
+
token_mask: (B, N)
|
| 1278 |
+
cls_token: whether to include the cls token in the mask
|
| 1279 |
+
Return:
|
| 1280 |
+
mask: batch x num_heads x seq_length x seq_length
|
| 1281 |
+
"""
|
| 1282 |
+
T = len(num_tokens_each_frame)
|
| 1283 |
+
N = num_tokens_each_frame.sum()
|
| 1284 |
+
device = num_tokens_each_frame.device
|
| 1285 |
+
|
| 1286 |
+
# Create a causal mask
|
| 1287 |
+
mask = torch.tril(torch.ones(batch_size, N, N, device=device))
|
| 1288 |
+
|
| 1289 |
+
# Make the tokens inside each frame attend to each other
|
| 1290 |
+
for t in range(T):
|
| 1291 |
+
mask[:, sum(num_tokens_each_frame[:t]):sum(num_tokens_each_frame[:t+1]), sum(num_tokens_each_frame[:t]):sum(num_tokens_each_frame[:t+1])] = 1
|
| 1292 |
+
|
| 1293 |
+
# Mask out tokens indicated by token_mask
|
| 1294 |
+
if token_mask is not None:
|
| 1295 |
+
token_mask = token_mask.unsqueeze(1).repeat(1, N, 1)
|
| 1296 |
+
mask = mask * (~token_mask).float()
|
| 1297 |
+
|
| 1298 |
+
# Add mask for cls token
|
| 1299 |
+
if cls_token:
|
| 1300 |
+
mask_ = mask.clone()
|
| 1301 |
+
mask = torch.tril(torch.ones(batch_size, N + 1, N + 1, device=device))
|
| 1302 |
+
mask[:, 1:, 1:] = mask_
|
| 1303 |
+
|
| 1304 |
+
# Each token must be able to attend to itself
|
| 1305 |
+
mask[:, torch.arange(N), torch.arange(N)] = 1
|
| 1306 |
+
|
| 1307 |
+
# According to different attention implementations, the mask values are different.
|
| 1308 |
+
if self.config._attn_implementation == "flex_attention" or self.config._attn_implementation == "sdpa":
|
| 1309 |
+
# mask is a float tensor that will be added to the attention scores. This means the tokens to be attended should have mask value of 0, and the rest should have mask value of -inf.
|
| 1310 |
+
mask = torch.where(mask == 1, 0, -torch.inf)
|
| 1311 |
+
elif self.config._attn_implementation == "flash_attention_2":
|
| 1312 |
+
raise NotImplementedError("Flash attention 2 doesn't support custom attention mask. Please use attention_implementation='flex_attention'.")
|
| 1313 |
+
elif self.config._attn_implementation == "eager":
|
| 1314 |
+
# mask is a float tensor that will be multiplied to the attn prob after softmax. This means the tokens to be attended should have mask value of 1, and the rest should have mask value of 0.
|
| 1315 |
+
pass
|
| 1316 |
+
|
| 1317 |
+
mask = mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
|
| 1318 |
+
|
| 1319 |
+
return mask.to(num_tokens_each_frame.device)
|
| 1320 |
+
|
| 1321 |
+
def forward(
|
| 1322 |
+
self,
|
| 1323 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1324 |
+
gazing_info: Optional[dict] = None,
|
| 1325 |
+
frame_idx_to_reconstruct: Optional[torch.LongTensor] = None,
|
| 1326 |
+
noise: Optional[torch.FloatTensor] = None,
|
| 1327 |
+
output_attentions: Optional[bool] = None,
|
| 1328 |
+
output_hidden_states: Optional[bool] = None,
|
| 1329 |
+
return_dict: Optional[bool] = None,
|
| 1330 |
+
interpolate_pos_encoding: bool = False,
|
| 1331 |
+
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
|
| 1332 |
+
"""
|
| 1333 |
+
pixel_values: (B, T, C, H, W)
|
| 1334 |
+
gazing_info:
|
| 1335 |
+
gazing_pos: The gazing positions of each whole sequence. (B, N)
|
| 1336 |
+
num_gazing_each_frame: The number of gazing positions for each frame, including the padded gazing. (T, )
|
| 1337 |
+
if_padded_gazing: Whether the gazing is padded. (B, N)
|
| 1338 |
+
frame_idx_to_reconstruct: (num_selected_frames, )
|
| 1339 |
+
"""
|
| 1340 |
+
B, T = pixel_values.shape[:2]
|
| 1341 |
+
|
| 1342 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1343 |
+
|
| 1344 |
+
# Get the encoder attention mask
|
| 1345 |
+
encoder_attn_mask = self.get_causal_mask(gazing_info['num_gazing_each_frame'], self.config.num_hidden_layers, B, self.config.num_attention_heads, token_mask=gazing_info['if_padded_gazing'], cls_token=True) if self.config.causal else None
|
| 1346 |
+
|
| 1347 |
+
# Get the encoder outputs
|
| 1348 |
+
outputs = self.vit(
|
| 1349 |
+
pixel_values,
|
| 1350 |
+
gazing_info=gazing_info,
|
| 1351 |
+
noise=noise,
|
| 1352 |
+
head_mask=encoder_attn_mask,
|
| 1353 |
+
output_attentions=output_attentions,
|
| 1354 |
+
output_hidden_states=output_hidden_states,
|
| 1355 |
+
return_dict=return_dict,
|
| 1356 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1357 |
+
)
|
| 1358 |
+
latent = outputs.last_hidden_state # B * N * C
|
| 1359 |
+
|
| 1360 |
+
# Get the number of tokens to decode for each frame
|
| 1361 |
+
num_decoded_tokens_each_frame = gazing_info['num_gazing_each_frame'].clone()
|
| 1362 |
+
num_decoded_tokens_each_frame[frame_idx_to_reconstruct] = self.num_token_each_frame
|
| 1363 |
+
|
| 1364 |
+
# Get the gazing padding mask for decoder
|
| 1365 |
+
if_padded_gazing_decoder = gazing_info['if_padded_gazing'].clone()
|
| 1366 |
+
if_padded_gazing_decoder = list(if_padded_gazing_decoder.split(gazing_info['num_gazing_each_frame'].tolist(), dim=-1))
|
| 1367 |
+
for frame_idx in frame_idx_to_reconstruct:
|
| 1368 |
+
if_padded_gazing_decoder[frame_idx] = torch.zeros(B, self.num_token_each_frame).to(gazing_info['if_padded_gazing'].device).to(torch.bool)
|
| 1369 |
+
if_padded_gazing_decoder = torch.cat(if_padded_gazing_decoder, dim=-1)
|
| 1370 |
+
|
| 1371 |
+
# Get the decoder attention mask
|
| 1372 |
+
decoder_attn_mask = self.get_causal_mask(num_decoded_tokens_each_frame, self.config.decoder_num_hidden_layers, B, self.config.decoder_num_attention_heads, token_mask=if_padded_gazing_decoder, cls_token=True) if self.config.causal else None
|
| 1373 |
+
|
| 1374 |
+
# Get the decoder outputs
|
| 1375 |
+
decoder_outputs = self.decoder(
|
| 1376 |
+
latent,
|
| 1377 |
+
gazing_info=gazing_info,
|
| 1378 |
+
frame_idx_to_reconstruct=frame_idx_to_reconstruct,
|
| 1379 |
+
head_mask=decoder_attn_mask,
|
| 1380 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1381 |
+
)
|
| 1382 |
+
logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
|
| 1383 |
+
|
| 1384 |
+
# Only keep the predictions for the selected frames
|
| 1385 |
+
decoded_token_idx_to_keep = []
|
| 1386 |
+
for frame_idx in frame_idx_to_reconstruct:
|
| 1387 |
+
decoded_token_idx_to_keep.append(torch.arange(sum(num_decoded_tokens_each_frame[:frame_idx]), sum(num_decoded_tokens_each_frame[:frame_idx+1])))
|
| 1388 |
+
decoded_token_idx_to_keep = torch.cat(decoded_token_idx_to_keep, dim=0)
|
| 1389 |
+
logits = logits[:, decoded_token_idx_to_keep]
|
| 1390 |
+
logits = rearrange(logits, 'b (t n) c -> b t n c', t=len(frame_idx_to_reconstruct)) # B * num_selected_frames * N * C
|
| 1391 |
+
|
| 1392 |
+
# throw away the reconstruction and masks for smaller scales
|
| 1393 |
+
logits = logits[:, :, sum(self.num_patch_each_scale[:-1]):, :]
|
| 1394 |
+
|
| 1395 |
+
loss_each_reconstruction_frame, loss_mean = self.forward_loss(pixel_values[:, frame_idx_to_reconstruct], logits, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 1396 |
+
reconstruction = self.get_reconstructed_image(pixel_values[:, frame_idx_to_reconstruct], logits, interpolate_pos_encoding=interpolate_pos_encoding) # B * num_selected_frames * C * H * W
|
| 1397 |
+
|
| 1398 |
+
if not return_dict:
|
| 1399 |
+
output = (logits, reconstruction) + outputs[2:]
|
| 1400 |
+
return ((loss_each_reconstruction_frame, loss_mean) + output) if loss_each_reconstruction_frame is not None else output
|
| 1401 |
+
|
| 1402 |
+
return ViTMAEForPreTrainingOutput(
|
| 1403 |
+
loss_each_reconstruction_frame=loss_each_reconstruction_frame,
|
| 1404 |
+
loss_mean=loss_mean,
|
| 1405 |
+
reconstruction=reconstruction,
|
| 1406 |
+
logits=logits,
|
| 1407 |
+
hidden_states=outputs.hidden_states,
|
| 1408 |
+
attentions=outputs.attentions,
|
| 1409 |
+
)
|
| 1410 |
+
|
| 1411 |
+
|
| 1412 |
+
__all__ = ["ViTMAEForPreTraining", "ViTMAELayer", "ViTMAEModel", "ViTMAEPreTrainedModel"]
|
autogaze/tasks/video_mae_reconstruction/task_video_mae_reconstruction.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omegaconf import OmegaConf
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
from transformers import AutoModel, AutoImageProcessor, VivitImageProcessor
|
| 6 |
+
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
| 7 |
+
from transformers.models.siglip2.modeling_siglip2 import Siglip2VisionModel
|
| 8 |
+
|
| 9 |
+
from .modeling_video_mae import ViTMAEForPreTraining
|
| 10 |
+
from .visualize_video_mae_reconstruction import VisualizeReconstruction
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class VideoMAEReconstruction(nn.Module):
|
| 14 |
+
def __init__(self, recon_model, recon_model_config, scales, recon_sample_rate, attn_mode):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
# Create model
|
| 18 |
+
self.scales = sorted([int(scale) for scale in str(scales).split("+")])
|
| 19 |
+
self.transform = VivitImageProcessor.from_pretrained(recon_model, size=self.scales[-1]) # use mae image preprocessor config to intialize video preprocessor
|
| 20 |
+
self.mae = ViTMAEForPreTraining.from_pretrained(recon_model, attn_implementation="sdpa", scales=str(scales), **OmegaConf.to_container(recon_model_config))
|
| 21 |
+
self.mae.transform = self.transform
|
| 22 |
+
if "dinov2_reg" in self.mae.loss_type:
|
| 23 |
+
self.mae.dinov2_reg = AutoModel.from_pretrained(recon_model_config.dinov2_reg_loss_config.model, attn_implementation=attn_mode)
|
| 24 |
+
self.mae.dinov2_reg_transform = AutoImageProcessor.from_pretrained(recon_model_config.dinov2_reg_loss_config.model)
|
| 25 |
+
for param in self.mae.dinov2_reg.parameters():
|
| 26 |
+
param.requires_grad = False
|
| 27 |
+
self.mae.dinov2_reg.eval()
|
| 28 |
+
if "siglip2" in self.mae.loss_type:
|
| 29 |
+
if "naflex" in recon_model_config.siglip2_loss_config.model:
|
| 30 |
+
self.mae.siglip2 = Siglip2VisionModel.from_pretrained(recon_model_config.siglip2_loss_config.model, attn_implementation=attn_mode)
|
| 31 |
+
else:
|
| 32 |
+
self.mae.siglip2 = SiglipVisionModel.from_pretrained(recon_model_config.siglip2_loss_config.model, attn_implementation=attn_mode)
|
| 33 |
+
self.mae.siglip2_transform = AutoImageProcessor.from_pretrained(recon_model_config.siglip2_loss_config.model)
|
| 34 |
+
for param in self.mae.siglip2.parameters():
|
| 35 |
+
param.requires_grad = False
|
| 36 |
+
self.mae.siglip2.eval()
|
| 37 |
+
|
| 38 |
+
# Sampling strategy for reconstruction
|
| 39 |
+
self.recon_sample_rate = recon_sample_rate
|
| 40 |
+
|
| 41 |
+
# Create visualization methods
|
| 42 |
+
self.visualize_methods = [VisualizeReconstruction()]
|
| 43 |
+
|
| 44 |
+
# kwargs for the gaze model input. Will be passed to the gaze model during training.
|
| 45 |
+
self.gaze_model_kwargs = {
|
| 46 |
+
"target_scales": self.scales,
|
| 47 |
+
"target_patch_size": self.mae.config.patch_size,
|
| 48 |
+
"target_image_mean": self.transform.image_mean,
|
| 49 |
+
"target_image_std": self.transform.image_std,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
@torch.autocast("cuda", dtype=torch.bfloat16)
|
| 53 |
+
def forward_output(self, inputs, gaze_outputs, frame_idx_to_reconstruct=None):
|
| 54 |
+
"""
|
| 55 |
+
Get all the outputs from the inputs
|
| 56 |
+
"""
|
| 57 |
+
video = inputs['video']
|
| 58 |
+
gazing_pos = gaze_outputs['gazing_pos']
|
| 59 |
+
num_gazing_each_frame = gaze_outputs['num_gazing_each_frame']
|
| 60 |
+
if_padded_gazing = gaze_outputs['if_padded_gazing']
|
| 61 |
+
frame_sampling_rate = gaze_outputs['frame_sampling_rate']
|
| 62 |
+
num_vision_tokens_each_frame = gaze_outputs['num_vision_tokens_each_frame']
|
| 63 |
+
|
| 64 |
+
assert frame_sampling_rate == 1, "If frame_sampling_rate > 1, we can downsample the video here but ideally we don't want to do that"
|
| 65 |
+
assert num_vision_tokens_each_frame == sum([(scale // self.mae.config.patch_size) ** 2 for scale in self.scales]), "The number of vision tokens in each frame is not consistent between gaze model and MAE model"
|
| 66 |
+
|
| 67 |
+
# Frame sampling strategy for reconstruction
|
| 68 |
+
B, T = video.shape[:2]
|
| 69 |
+
if frame_idx_to_reconstruct is None:
|
| 70 |
+
frame_idx_to_reconstruct = torch.randperm(T)[:int(T * self.recon_sample_rate)].to(video.device)
|
| 71 |
+
|
| 72 |
+
# Reconstruct the video
|
| 73 |
+
gazing_info = {
|
| 74 |
+
'gazing_pos': gazing_pos,
|
| 75 |
+
'num_gazing_each_frame': num_gazing_each_frame,
|
| 76 |
+
'if_padded_gazing': if_padded_gazing,
|
| 77 |
+
}
|
| 78 |
+
recon_output = self.mae(video, gazing_info=gazing_info, frame_idx_to_reconstruct=frame_idx_to_reconstruct, interpolate_pos_encoding=True)
|
| 79 |
+
recon_loss_mean = recon_output.loss_mean
|
| 80 |
+
recon_loss_each_reconstruction_frame = recon_output.loss_each_reconstruction_frame
|
| 81 |
+
num_gazing_before_each_reconstruction_frame = torch.stack([num_gazing_each_frame[:frame_idx+1].sum(dim=-1) for frame_idx in frame_idx_to_reconstruct], dim=0)
|
| 82 |
+
num_non_padded_gazing_at_each_reconstruction_frame = [(~if_padded_gazing)[:, num_gazing_each_frame[:frame_idx].sum():num_gazing_each_frame[:frame_idx+1].sum()].sum(dim=-1) for frame_idx in frame_idx_to_reconstruct]
|
| 83 |
+
num_non_padded_gazing_at_each_reconstruction_frame = torch.stack(num_non_padded_gazing_at_each_reconstruction_frame, dim=-1) # B * num_reconstruction_frames
|
| 84 |
+
|
| 85 |
+
# Organize the recon loss at each gazing token
|
| 86 |
+
if_padded_gazing_each_frame = list(if_padded_gazing.split(num_gazing_each_frame.tolist(), dim=-1))
|
| 87 |
+
reconstruction_loss_each_gazing_token = [torch.zeros(*if_padded_gazing_each_frame[t].shape, dtype=gazing_pos.dtype, device=gazing_pos.device) for t in range(len(num_gazing_each_frame))]
|
| 88 |
+
reconstruction_loss_each_gazing_token_mask = [torch.zeros(*if_padded_gazing_each_frame[t].shape, dtype=gazing_pos.dtype, device=gazing_pos.device) for t in range(len(num_gazing_each_frame))]
|
| 89 |
+
for i, frame_idx in enumerate(frame_idx_to_reconstruct):
|
| 90 |
+
cur_mask = F.pad(if_padded_gazing_each_frame[frame_idx][:, 1:], (0, 1), value=True).to(torch.float)
|
| 91 |
+
reconstruction_loss_each_gazing_token[frame_idx] = recon_loss_each_reconstruction_frame[:, i:i+1] * cur_mask
|
| 92 |
+
reconstruction_loss_each_gazing_token_mask[frame_idx] = cur_mask
|
| 93 |
+
reconstruction_loss_each_gazing_token = torch.cat(reconstruction_loss_each_gazing_token, dim=-1) # B * N
|
| 94 |
+
reconstruction_loss_each_gazing_token_mask = torch.cat(reconstruction_loss_each_gazing_token_mask, dim=-1) # B * N
|
| 95 |
+
|
| 96 |
+
outputs = {
|
| 97 |
+
"reconstruction": recon_output.reconstruction,
|
| 98 |
+
"reconstruction_loss": recon_loss_mean,
|
| 99 |
+
"reconstruction_loss_each_reconstruction_frame": recon_loss_each_reconstruction_frame,
|
| 100 |
+
"reconstruction_loss_each_gazing_token": reconstruction_loss_each_gazing_token,
|
| 101 |
+
"reconstruction_loss_each_gazing_token_mask": reconstruction_loss_each_gazing_token_mask,
|
| 102 |
+
"num_gazing_before_each_reconstruction_frame": num_gazing_before_each_reconstruction_frame,
|
| 103 |
+
"num_non_padded_gazing_at_each_reconstruction_frame": num_non_padded_gazing_at_each_reconstruction_frame,
|
| 104 |
+
"frame_idx_to_reconstruct": frame_idx_to_reconstruct,
|
| 105 |
+
"image_mean": self.transform.image_mean,
|
| 106 |
+
"image_std": self.transform.image_std,
|
| 107 |
+
"rescale_factor": self.transform.rescale_factor,
|
| 108 |
+
"scales": self.scales,
|
| 109 |
+
}
|
| 110 |
+
return outputs
|
| 111 |
+
|
| 112 |
+
def loss(self, inputs, gaze_outputs, outputs):
|
| 113 |
+
"""
|
| 114 |
+
Compute the loss of the outputs. Used for training the task itself.
|
| 115 |
+
"""
|
| 116 |
+
reconstruction_loss = outputs['reconstruction_loss']
|
| 117 |
+
reconstruction_loss_each_gazing_token = outputs['reconstruction_loss_each_gazing_token']
|
| 118 |
+
reconstruction_loss_each_gazing_token_mask = outputs['reconstruction_loss_each_gazing_token_mask']
|
| 119 |
+
return reconstruction_loss, reconstruction_loss_each_gazing_token, reconstruction_loss_each_gazing_token_mask
|
| 120 |
+
|
| 121 |
+
def reward(self, inputs, gaze_outputs, outputs):
|
| 122 |
+
"""
|
| 123 |
+
Compute the reward of the outputs. Used for training the gazing model.
|
| 124 |
+
"""
|
| 125 |
+
reconstruction_loss_each_reconstruction_frame = outputs['reconstruction_loss_each_reconstruction_frame']
|
| 126 |
+
rewards = -reconstruction_loss_each_reconstruction_frame.detach()
|
| 127 |
+
|
| 128 |
+
# Gazing length before each reward
|
| 129 |
+
traj_len_each_reward = outputs['num_gazing_before_each_reconstruction_frame']
|
| 130 |
+
|
| 131 |
+
return rewards, traj_len_each_reward
|
| 132 |
+
|
| 133 |
+
def metric(self, inputs, gaze_outputs, outputs):
|
| 134 |
+
"""
|
| 135 |
+
Compute the metric used for recording during validation.
|
| 136 |
+
"""
|
| 137 |
+
# Reconstruction loss
|
| 138 |
+
reconstruction_loss, _, __ = self.loss(inputs, gaze_outputs, outputs)
|
| 139 |
+
reconstruction_loss = reconstruction_loss.mean()
|
| 140 |
+
|
| 141 |
+
# Average gazing ratio per frame
|
| 142 |
+
bs, num_frames = inputs['video'].shape[:2]
|
| 143 |
+
num_vision_tokens_each_frame = gaze_outputs['num_vision_tokens_each_frame']
|
| 144 |
+
num_gazing_total = (~gaze_outputs['if_padded_gazing']).sum()
|
| 145 |
+
avg_gazing_ratio = num_gazing_total / (bs *num_frames * num_vision_tokens_each_frame)
|
| 146 |
+
|
| 147 |
+
metrics = {
|
| 148 |
+
'reconstruction_loss': reconstruction_loss,
|
| 149 |
+
'avg_gazing_ratio_per_frame': avg_gazing_ratio,
|
| 150 |
+
}
|
| 151 |
+
return metrics
|
| 152 |
+
|
| 153 |
+
def visualize(self, inputs, gaze_outputs, task_outputs, rl_outputs=None):
|
| 154 |
+
"""
|
| 155 |
+
Visualize the outputs.
|
| 156 |
+
"""
|
| 157 |
+
for method in self.visualize_methods:
|
| 158 |
+
method(inputs, gaze_outputs, task_outputs, rl_outputs)
|
| 159 |
+
|
| 160 |
+
def forward(self, inputs, gaze_outputs):
|
| 161 |
+
"""
|
| 162 |
+
Compute the outputs and the loss, reward, and metric of the outputs.
|
| 163 |
+
inputs:
|
| 164 |
+
image: B, C, H, W
|
| 165 |
+
gaze_outputs:
|
| 166 |
+
gazing_pos: B, N
|
| 167 |
+
"""
|
| 168 |
+
outputs = self.forward_output(inputs, gaze_outputs)
|
| 169 |
+
loss, reconstruction_loss_each_gazing_token, reconstruction_loss_each_gazing_token_mask = self.loss(inputs, gaze_outputs, outputs)
|
| 170 |
+
reward, traj_len_each_reward = self.reward(inputs, gaze_outputs, outputs)
|
| 171 |
+
metric = self.metric(inputs, gaze_outputs, outputs)
|
| 172 |
+
|
| 173 |
+
to_return = {
|
| 174 |
+
'outputs': outputs,
|
| 175 |
+
'loss': loss,
|
| 176 |
+
'reward': reward,
|
| 177 |
+
'traj_len_each_reward': traj_len_each_reward,
|
| 178 |
+
'task_losses': reconstruction_loss_each_gazing_token,
|
| 179 |
+
'task_losses_mask': reconstruction_loss_each_gazing_token_mask,
|
| 180 |
+
'metrics': metric,
|
| 181 |
+
}
|
| 182 |
+
return to_return
|
autogaze/tasks/video_mae_reconstruction/visualize_video_mae_reconstruction.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import wandb
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from autogaze.utils import UnNormalize
|
| 7 |
+
|
| 8 |
+
class VisualizeReconstruction:
|
| 9 |
+
def __init__(self, **kwargs):
|
| 10 |
+
|
| 11 |
+
self.visualize_step = 0
|
| 12 |
+
if wandb.run is not None:
|
| 13 |
+
# define our custom x axis metric
|
| 14 |
+
wandb.define_metric("visualize_gaze/visualize_step")
|
| 15 |
+
# define which metrics will be plotted against it
|
| 16 |
+
wandb.define_metric("visualize_gaze/*", step_metric="visualize_gaze/visualize_step")
|
| 17 |
+
|
| 18 |
+
@torch.no_grad()
|
| 19 |
+
def __call__(self, inputs, gaze_outputs, task_outputs, rl_outputs):
|
| 20 |
+
# Get all information for visualization
|
| 21 |
+
videos = inputs['video']
|
| 22 |
+
gazing_mask = gaze_outputs['gazing_mask'] # containing multi-scale masks; list of B * T * N_each_scale
|
| 23 |
+
frame_sampling_rate = gaze_outputs['frame_sampling_rate']
|
| 24 |
+
scales = task_outputs['outputs']['scales']
|
| 25 |
+
reconstruction = task_outputs['outputs']['reconstruction']
|
| 26 |
+
frame_idx_to_reconstruct = task_outputs['outputs']['frame_idx_to_reconstruct']
|
| 27 |
+
image_mean = task_outputs['outputs']['image_mean']
|
| 28 |
+
image_std = task_outputs['outputs']['image_std']
|
| 29 |
+
rescale_factor = task_outputs['outputs']['rescale_factor']
|
| 30 |
+
num_scales = len(scales)
|
| 31 |
+
|
| 32 |
+
# sample the frames to visualize
|
| 33 |
+
videos = videos[:, ::frame_sampling_rate]
|
| 34 |
+
assert videos.shape[1] == gazing_mask[0].shape[1]
|
| 35 |
+
|
| 36 |
+
# only visualize the first instance
|
| 37 |
+
video = videos[0]
|
| 38 |
+
gazing_mask = [m[0] for m in gazing_mask]
|
| 39 |
+
reconstruction = reconstruction[0]
|
| 40 |
+
|
| 41 |
+
unnormalize = UnNormalize(image_mean, image_std, rescale_factor)
|
| 42 |
+
|
| 43 |
+
video = unnormalize(video)
|
| 44 |
+
reconstruction = unnormalize(reconstruction)
|
| 45 |
+
video = video.cpu().float().numpy()
|
| 46 |
+
reconstruction = reconstruction.cpu().float().numpy()
|
| 47 |
+
|
| 48 |
+
# complete the reconstruction by filling the unselected frames
|
| 49 |
+
reconstruction_full = np.zeros_like(video)
|
| 50 |
+
reconstruction_full[frame_idx_to_reconstruct.cpu().numpy()] = reconstruction
|
| 51 |
+
reconstruction = reconstruction_full
|
| 52 |
+
|
| 53 |
+
# Create a figure with subplots: original video frames and one row for each scale's masked video frames
|
| 54 |
+
T = video.shape[0] # Number of frames
|
| 55 |
+
fig, axes = plt.subplots(num_scales + 2, T, figsize=(3 * T, 3 * (num_scales + 2)))
|
| 56 |
+
|
| 57 |
+
# Plot original video frames
|
| 58 |
+
for t in range(T):
|
| 59 |
+
frame = video[t].transpose(1, 2, 0) # C * H * W -> H * W * C
|
| 60 |
+
axes[0, t].imshow(frame)
|
| 61 |
+
axes[0, t].set_title(f'Original Frame {t+1}')
|
| 62 |
+
axes[0, t].axis('off')
|
| 63 |
+
|
| 64 |
+
# Visualize masked video for each scale
|
| 65 |
+
for scale_idx in range(num_scales):
|
| 66 |
+
scale_mask = gazing_mask[scale_idx] # T * N
|
| 67 |
+
|
| 68 |
+
for t in range(T):
|
| 69 |
+
frame_mask = scale_mask[t] # N
|
| 70 |
+
|
| 71 |
+
# Reshape if it's flattened
|
| 72 |
+
if frame_mask.dim() == 1:
|
| 73 |
+
h = w = int(frame_mask.shape[0] ** 0.5)
|
| 74 |
+
frame_mask = frame_mask.reshape(h, w)
|
| 75 |
+
|
| 76 |
+
# Resize mask to match current scale
|
| 77 |
+
frame_mask = F.interpolate(
|
| 78 |
+
frame_mask.unsqueeze(0).unsqueeze(0),
|
| 79 |
+
size=(scales[scale_idx], scales[scale_idx]),
|
| 80 |
+
mode='nearest'
|
| 81 |
+
).squeeze()
|
| 82 |
+
|
| 83 |
+
frame_mask = frame_mask.cpu().float().numpy()
|
| 84 |
+
|
| 85 |
+
# Resize frame to match mask dimensions
|
| 86 |
+
frame = video[t] # C * H * W
|
| 87 |
+
scale_frame = F.interpolate(
|
| 88 |
+
torch.from_numpy(frame).unsqueeze(0),
|
| 89 |
+
size=(scales[scale_idx], scales[scale_idx]),
|
| 90 |
+
mode='bicubic',
|
| 91 |
+
align_corners=False
|
| 92 |
+
).squeeze().clamp(0, 1).numpy()
|
| 93 |
+
|
| 94 |
+
masked_frame = scale_frame * (0.8 * frame_mask[None, :, :] + 0.2)
|
| 95 |
+
|
| 96 |
+
# Plot this frame's masked image
|
| 97 |
+
axes[scale_idx + 1, t].imshow(masked_frame.transpose(1, 2, 0))
|
| 98 |
+
|
| 99 |
+
# Add red borders around gazed patches
|
| 100 |
+
original_mask = gazing_mask[scale_idx][t]
|
| 101 |
+
if original_mask.dim() == 1:
|
| 102 |
+
patch_grid_size = int(original_mask.shape[0] ** 0.5)
|
| 103 |
+
original_mask = original_mask.reshape(patch_grid_size, patch_grid_size)
|
| 104 |
+
|
| 105 |
+
patch_size = scales[scale_idx] // patch_grid_size
|
| 106 |
+
for i in range(patch_grid_size):
|
| 107 |
+
for j in range(patch_grid_size):
|
| 108 |
+
if original_mask[i, j] > 0.5: # If this patch is gazed at
|
| 109 |
+
rect = plt.Rectangle((j * patch_size - 0.5, i * patch_size - 0.5),
|
| 110 |
+
patch_size, patch_size,
|
| 111 |
+
linewidth=1, edgecolor='red', facecolor='none')
|
| 112 |
+
axes[scale_idx + 1, t].add_patch(rect)
|
| 113 |
+
|
| 114 |
+
axes[scale_idx + 1, t].set_title(f'Scale {scales[scale_idx]} Frame {t+1}')
|
| 115 |
+
axes[scale_idx + 1, t].axis('off')
|
| 116 |
+
|
| 117 |
+
# plot the reconstruction
|
| 118 |
+
for t in range(T):
|
| 119 |
+
frame = reconstruction[t].transpose(1, 2, 0) # C * H * W -> H * W * C
|
| 120 |
+
axes[num_scales + 1, t].imshow(frame)
|
| 121 |
+
axes[num_scales + 1, t].set_title(f'Reconstructed Frame {t+1}')
|
| 122 |
+
axes[num_scales + 1, t].axis('off')
|
| 123 |
+
|
| 124 |
+
# Adjust layout and log to wandb
|
| 125 |
+
plt.tight_layout()
|
| 126 |
+
wandb.log({
|
| 127 |
+
"visualize_gaze/visualize_step": self.visualize_step,
|
| 128 |
+
"visualize_gaze/visualize_gaze": wandb.Image(plt)
|
| 129 |
+
})
|
| 130 |
+
|
| 131 |
+
# Close the figure to free memory
|
| 132 |
+
plt.close(fig)
|
| 133 |
+
|
| 134 |
+
self.visualize_step += 1
|
autogaze/utils.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import builtins
|
| 2 |
+
|
| 3 |
+
from omegaconf import OmegaConf
|
| 4 |
+
from loguru import logger
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import wandb
|
| 10 |
+
import random
|
| 11 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class UnNormalize(object):
|
| 15 |
+
def __init__(self, mean, std, rescale_factor=None):
|
| 16 |
+
self.mean = mean
|
| 17 |
+
self.std = std
|
| 18 |
+
self.rescale_factor = rescale_factor
|
| 19 |
+
|
| 20 |
+
def __call__(self, image):
|
| 21 |
+
image2 = torch.clone(image)
|
| 22 |
+
dims = len(image2.shape)
|
| 23 |
+
if dims == 3:
|
| 24 |
+
image2 = image2.unsqueeze(0)
|
| 25 |
+
image2 = image2.permute(1, 0, 2, 3)
|
| 26 |
+
for t, m, s in zip(image2, self.mean, self.std):
|
| 27 |
+
t.mul_(s).add_(m)
|
| 28 |
+
image2 = image2.permute(1, 0, 2, 3)
|
| 29 |
+
if dims == 3:
|
| 30 |
+
image2 = image2.squeeze(0)
|
| 31 |
+
|
| 32 |
+
if self.rescale_factor is not None:
|
| 33 |
+
standard_rescale = 1.0 / 255.0
|
| 34 |
+
if abs(self.rescale_factor - standard_rescale) > 1e-6:
|
| 35 |
+
# if the processor uses 1/127.5, needs /2.0 + 0.5 correction
|
| 36 |
+
image2 = image2 / 2.0 + 0.5
|
| 37 |
+
|
| 38 |
+
return torch.clamp(image2, 0, 1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AverageScalarMeter(object):
|
| 42 |
+
def __init__(self, window_size):
|
| 43 |
+
self.window_size = window_size
|
| 44 |
+
self.current_size = 0
|
| 45 |
+
self.mean = 0
|
| 46 |
+
|
| 47 |
+
def update(self, values):
|
| 48 |
+
size = values.size()[0]
|
| 49 |
+
if size == 0:
|
| 50 |
+
return
|
| 51 |
+
new_mean = torch.mean(values.float(), dim=0).cpu().numpy().item()
|
| 52 |
+
size = np.clip(size, 0, self.window_size)
|
| 53 |
+
old_size = min(self.window_size - size, self.current_size)
|
| 54 |
+
size_sum = old_size + size
|
| 55 |
+
self.current_size = size_sum
|
| 56 |
+
self.mean = (self.mean * old_size + new_mean * size) / size_sum
|
| 57 |
+
|
| 58 |
+
def clear(self):
|
| 59 |
+
self.current_size = 0
|
| 60 |
+
self.mean = 0
|
| 61 |
+
|
| 62 |
+
def __len__(self):
|
| 63 |
+
return self.current_size
|
| 64 |
+
|
| 65 |
+
def get_mean(self):
|
| 66 |
+
return self.mean
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def plot_grad_norms(named_parameters, name_prefix=''):
|
| 70 |
+
for name, param in named_parameters:
|
| 71 |
+
if param.grad is not None:
|
| 72 |
+
norm = torch.linalg.vector_norm(param.grad, 2.0).item()
|
| 73 |
+
wandb.log({f'{name_prefix}{name}': norm})
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def suppress_print():
|
| 77 |
+
"""Suppresses printing from the current process."""
|
| 78 |
+
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
|
| 79 |
+
pass
|
| 80 |
+
builtins.print = ignore
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def suppress_wandb():
|
| 84 |
+
"""Suppresses wandb logging from the current_process."""
|
| 85 |
+
# Store original functions
|
| 86 |
+
original_functions = {}
|
| 87 |
+
for attr_name in dir(wandb):
|
| 88 |
+
attr = getattr(wandb, attr_name)
|
| 89 |
+
if callable(attr) and not attr_name.startswith('__'):
|
| 90 |
+
original_functions[attr_name] = attr
|
| 91 |
+
|
| 92 |
+
# Replace with no-op function
|
| 93 |
+
def make_noop(name):
|
| 94 |
+
def noop(*args, **kwargs):
|
| 95 |
+
pass
|
| 96 |
+
return noop
|
| 97 |
+
|
| 98 |
+
setattr(wandb, attr_name, make_noop(attr_name))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def suppress_logging():
|
| 102 |
+
"""Suppresses loguru logging from the current process."""
|
| 103 |
+
logger.remove() # Remove all handlers
|
| 104 |
+
logger.add(lambda _: None) # Add a no-op handler
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def dump_cfg(cfg, logdir):
|
| 108 |
+
out_f = os.path.join(logdir, "config.yaml")
|
| 109 |
+
with open(out_f, "w") as f:
|
| 110 |
+
f.write(OmegaConf.to_yaml(cfg))
|
| 111 |
+
print("Wrote config to: {}".format(out_f))
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_scheduled_temperature(step, total_steps, temp_schedule_args):
|
| 115 |
+
if temp_schedule_args['mode'] == 'exp':
|
| 116 |
+
t_start = temp_schedule_args['exp']['temp_start']
|
| 117 |
+
t_end = temp_schedule_args['exp']['temp_end']
|
| 118 |
+
return t_start * (t_end / t_start) ** (step / total_steps)
|
| 119 |
+
else:
|
| 120 |
+
raise ValueError(f"Unknown temp_schedule_args: {temp_schedule_args}")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def seed_everything(seed: int):
|
| 124 |
+
random.seed(seed)
|
| 125 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 126 |
+
np.random.seed(seed)
|
| 127 |
+
torch.manual_seed(seed)
|
| 128 |
+
if torch.cuda.is_available():
|
| 129 |
+
torch.cuda.manual_seed(seed)
|
| 130 |
+
torch.backends.cudnn.deterministic = True
|
| 131 |
+
torch.backends.cudnn.benchmark = False
|
| 132 |
+
if hasattr(torch, 'mps') and torch.backends.mps.is_available():
|
| 133 |
+
torch.mps.manual_seed(seed)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def seed_worker(worker_id):
|
| 137 |
+
worker_seed = torch.initial_seed() % 2**32
|
| 138 |
+
np.random.seed(worker_seed)
|
| 139 |
+
random.seed(worker_seed)
|
| 140 |
+
torch.manual_seed(worker_seed + worker_id) # Add worker_id to make it different
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def format_kwargs(cfg, optional_args):
|
| 144 |
+
return {
|
| 145 |
+
arg_name: getattr(getattr(cfg, section), attr)
|
| 146 |
+
for arg_name, section, attr in optional_args
|
| 147 |
+
if hasattr(cfg, section) and hasattr(getattr(cfg, section), attr)
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def move_inputs_to_cuda(inputs):
|
| 152 |
+
for k, v in inputs.items():
|
| 153 |
+
if isinstance(v, torch.Tensor):
|
| 154 |
+
inputs[k] = v.cuda()
|
| 155 |
+
elif isinstance(v, dict):
|
| 156 |
+
inputs[k] = move_inputs_to_cuda(v)
|
| 157 |
+
return inputs
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def unwrap_model(model):
|
| 161 |
+
"""Unwrap DDP model if needed."""
|
| 162 |
+
if isinstance(model, DDP):
|
| 163 |
+
return model.module
|
| 164 |
+
return model
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def get_gazing_pos_from_gazing_mask(gazing_mask: torch.Tensor) -> torch.Tensor:
|
| 168 |
+
"""
|
| 169 |
+
Get the gazing positions from the gazing mask.
|
| 170 |
+
inputs:
|
| 171 |
+
gazing_mask: (B, N). 1 means gazed, 0 means not gazed.
|
| 172 |
+
outputs:
|
| 173 |
+
gazing_pos: (B, K). K is the maximum number of gazed tokens per instance. If the instance has less than K gazed tokens, the remaining positions are padded with -1.
|
| 174 |
+
if_padded_gazing: (B, K). 1 means padded, 0 means not padded.
|
| 175 |
+
"""
|
| 176 |
+
# x: (B, N) with 0/1 values (float/bool/int all fine)
|
| 177 |
+
gazing_mask = gazing_mask.to(torch.long)
|
| 178 |
+
B, N = gazing_mask.shape
|
| 179 |
+
|
| 180 |
+
# Indices per row
|
| 181 |
+
idx = torch.arange(N, device=gazing_mask.device).expand(B, N)
|
| 182 |
+
|
| 183 |
+
# Sort key: put ones first, keep original order among ones/zeros
|
| 184 |
+
# - ones get key = idx (0..N-1)
|
| 185 |
+
# - zeros get key = N + idx (pushed after all ones)
|
| 186 |
+
key = (1 - gazing_mask) * N + idx
|
| 187 |
+
order = key.argsort(dim=1, stable=True) # (B, N)
|
| 188 |
+
sorted_idx = idx.gather(1, order) # ones first, then zeros
|
| 189 |
+
|
| 190 |
+
# Max number of ones (K) and per-row counts
|
| 191 |
+
counts = gazing_mask.sum(dim=1) # (B,)
|
| 192 |
+
K = int(counts.max().item())
|
| 193 |
+
|
| 194 |
+
if K == 0:
|
| 195 |
+
return sorted_idx[:, :0] # (B, 0) empty result
|
| 196 |
+
|
| 197 |
+
topk = sorted_idx[:, :K] # (B, K)
|
| 198 |
+
pos = torch.arange(K, device=gazing_mask.device).expand(B, K)
|
| 199 |
+
mask = pos < counts.unsqueeze(1) # True where a real "1" exists
|
| 200 |
+
|
| 201 |
+
# Pad with -1 where the row has fewer than K ones
|
| 202 |
+
gazing_pos = topk.masked_fill(~mask, -1)
|
| 203 |
+
if_padded_gazing = (gazing_pos == -1)
|
| 204 |
+
|
| 205 |
+
return gazing_pos, if_padded_gazing
|
demo_utils.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
import numpy as np
|
|
@@ -9,19 +14,11 @@ from transformers import VivitImageProcessor
|
|
| 9 |
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
from omegaconf import OmegaConf
|
| 11 |
from einops import rearrange
|
| 12 |
-
|
| 13 |
-
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'gengaze'))
|
| 14 |
from autogaze.models.autogaze import AutoGaze
|
| 15 |
from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch
|
| 16 |
from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction
|
| 17 |
from autogaze.utils import UnNormalize
|
| 18 |
-
from tqdm import trange
|
| 19 |
-
|
| 20 |
-
try:
|
| 21 |
-
import spaces
|
| 22 |
-
ZEROGPU_AVAILABLE = True
|
| 23 |
-
except ImportError:
|
| 24 |
-
ZEROGPU_AVAILABLE = False
|
| 25 |
|
| 26 |
|
| 27 |
def image_to_video(image_path, output_path, fps):
|
|
@@ -218,24 +215,29 @@ def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.
|
|
| 218 |
progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
|
| 219 |
yield None
|
| 220 |
|
| 221 |
-
# Extract mini-batch from CPU and move to GPU: (batch_size, nt, 16, C, H, W)
|
| 222 |
spatial_batch = video_chunks[start_idx:end_idx].to(device)
|
| 223 |
-
# Flatten to (batch_size * nt, 16, C, H, W) for model
|
| 224 |
spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
|
| 225 |
print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')
|
| 226 |
|
| 227 |
# Run AutoGaze on this mini-batch
|
| 228 |
batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)
|
| 229 |
|
|
|
|
|
|
|
|
|
|
| 230 |
# Free GPU memory after forward pass
|
| 231 |
del spatial_batch
|
| 232 |
|
| 233 |
# Count gazing tokens for this batch
|
| 234 |
if_padded = batch_gaze_output.get('if_padded_gazing')
|
| 235 |
if if_padded is not None:
|
| 236 |
-
|
|
|
|
|
|
|
| 237 |
else:
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
|
| 240 |
# Store the output
|
| 241 |
all_gaze_outputs.append(batch_gaze_output)
|
|
@@ -283,7 +285,7 @@ def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.
|
|
| 283 |
# Clean up mini-batch outputs
|
| 284 |
del all_gaze_outputs
|
| 285 |
|
| 286 |
-
total_possible_tokens = 196 * 16 * num_chunks
|
| 287 |
|
| 288 |
# Extract gazing masks for later visualization (already in batched form)
|
| 289 |
gazing_masks_batched = gaze_output['gazing_mask'] # List of 4 scales, each (num_chunks, 16, num_patches)
|
|
|
|
| 1 |
+
# IMPORTANT: Import spaces first, before any CUDA-related packages (torch, etc.)
|
| 2 |
+
try:
|
| 3 |
+
import spaces
|
| 4 |
+
ZEROGPU_AVAILABLE = True
|
| 5 |
+
except ImportError:
|
| 6 |
+
ZEROGPU_AVAILABLE = False
|
| 7 |
+
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
| 10 |
import numpy as np
|
|
|
|
| 14 |
from PIL import Image, ImageDraw, ImageFont
|
| 15 |
from omegaconf import OmegaConf
|
| 16 |
from einops import rearrange
|
| 17 |
+
from tqdm import trange
|
|
|
|
| 18 |
from autogaze.models.autogaze import AutoGaze
|
| 19 |
from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch
|
| 20 |
from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction
|
| 21 |
from autogaze.utils import UnNormalize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def image_to_video(image_path, output_path, fps):
|
|
|
|
| 215 |
progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
|
| 216 |
yield None
|
| 217 |
|
|
|
|
| 218 |
spatial_batch = video_chunks[start_idx:end_idx].to(device)
|
|
|
|
| 219 |
spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
|
| 220 |
print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')
|
| 221 |
|
| 222 |
# Run AutoGaze on this mini-batch
|
| 223 |
batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)
|
| 224 |
|
| 225 |
+
num_gazing_each_frame = batch_gaze_output['num_gazing_each_frame'][:T]
|
| 226 |
+
num_gazing_total = num_gazing_each_frame.sum().item()
|
| 227 |
+
|
| 228 |
# Free GPU memory after forward pass
|
| 229 |
del spatial_batch
|
| 230 |
|
| 231 |
# Count gazing tokens for this batch
|
| 232 |
if_padded = batch_gaze_output.get('if_padded_gazing')
|
| 233 |
if if_padded is not None:
|
| 234 |
+
print(f'shape of if_padded: {if_padded.shape}')
|
| 235 |
+
if_padded = if_padded[:, :min(num_gazing_total, if_padded.shape[1])]
|
| 236 |
+
new_gazing_tokens = (~if_padded).sum().item()
|
| 237 |
else:
|
| 238 |
+
new_gazing_tokens = (batch_gaze_output['gazing_pos'] < (196 * T)).sum().item()
|
| 239 |
+
total_gazing_tokens += new_gazing_tokens
|
| 240 |
+
print(f'Batch {batch_idx+1}: Gazing tokens = {new_gazing_tokens}, Total gazing tokens so far = {total_gazing_tokens}')
|
| 241 |
|
| 242 |
# Store the output
|
| 243 |
all_gaze_outputs.append(batch_gaze_output)
|
|
|
|
| 285 |
# Clean up mini-batch outputs
|
| 286 |
del all_gaze_outputs
|
| 287 |
|
| 288 |
+
total_possible_tokens = 196 * min(T, 16) * num_chunks
|
| 289 |
|
| 290 |
# Extract gazing masks for later visualization (already in batched form)
|
| 291 |
gazing_masks_batched = gaze_output['gazing_mask'] # List of 4 scales, each (num_chunks, 16, num_patches)
|
packages.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git
|
| 2 |
+
git-lfs
|
| 3 |
+
ffmpeg
|
| 4 |
+
pkg-config
|
| 5 |
+
libavcodec-dev
|
| 6 |
+
libavformat-dev
|
| 7 |
+
libavutil-dev
|
| 8 |
+
libswscale-dev
|
| 9 |
+
libswresample-dev
|
| 10 |
+
libavdevice-dev
|
| 11 |
+
libavfilter-dev
|
| 12 |
+
libsm6
|
| 13 |
+
libxext6
|
| 14 |
+
cmake
|
| 15 |
+
rsync
|
| 16 |
+
libgl1
|
requirements.txt
CHANGED
|
@@ -11,5 +11,5 @@ tqdm==4.67.1
|
|
| 11 |
transformers==4.53.0
|
| 12 |
omegaconf==2.3.0
|
| 13 |
einops==0.8.1
|
| 14 |
-
av
|
| 15 |
-
imageio==2.37.0
|
|
|
|
| 11 |
transformers==4.53.0
|
| 12 |
omegaconf==2.3.0
|
| 13 |
einops==0.8.1
|
| 14 |
+
av
|
| 15 |
+
imageio[ffmpeg]==2.37.0
|