magnusdtd commited on
Commit
10d1937
·
verified ·
1 Parent(s): 4ce7c22

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. .gitignore +215 -0
  2. README.md +19 -0
  3. main.py +198 -0
  4. requirements.txt +5 -0
  5. transnetv2-pytorch-weights.pth +3 -0
  6. transnetv2_pytorch.py +318 -0
.gitignore ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # Kaggle
210
+ .kaggle/
211
+ data/
212
+ volumes/
213
+ json/
214
+ data/
215
+ kaggle.json
README.md CHANGED
@@ -1,3 +1,22 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ # TransNetV2 (PyTorch Version)
6
+
7
+ This repository provides a PyTorch version of [TransNet V2](https://github.com/soCzech/TransNetV2), a state-of-the-art neural network for shot boundary detection in videos.
8
+
9
+ ## Installation
10
+
11
+ Clone the repository and install the required dependencies.
12
+
13
+ ```sh
14
+ sudo apt-get install ffmpeg
15
+ pip install -r requirements.txt
16
+ ```
17
+
18
+ ## Usage
19
+
20
+ ```sh
21
+ python -m main --files="path/to/your/file/or/folder" --weights="path/to/the/model/weights" --visualize
22
+ ```
main.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from TransNetV2.transnetv2_pytorch import TransNetV2
2
+ from typing import Optional
3
+ import torch
4
+ import os
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw
7
+ import argparse
8
+ from tqdm import tqdm
9
+
10
+ try:
11
+ import ffmpeg
12
+ except ModuleNotFoundError:
13
+ raise ModuleNotFoundError("For `predict_video` function `ffmpeg` needs to be installed in order to extract "
14
+ "individual frames from video file. Install `ffmpeg` command line tool and then "
15
+ "install python wrapper by `pip install ffmpeg-python`.")
16
+
17
+
18
class TransNetV2Torch:
    """Convenience wrapper around the TransNetV2 model: loads weights, runs
    sliding-window inference on whole videos, and post-processes predictions
    into scene boundaries / visualizations."""

    def __init__(self, model_path: Optional[str] = None):
        """Load the TransNetV2 model and move it to the best available device.

        Args:
            model_path: path to the ``.pth`` weights file; when omitted, the
                file ``transnetv2-pytorch-weights.pth`` next to this module
                is used.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            IOError: if the weights cannot be loaded into the model.
        """
        weights_path = model_path or os.path.join(os.path.dirname(__file__), "transnetv2-pytorch-weights.pth")
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f"[TransNetV2] ERROR: weights file not found at {weights_path}.")
        else:
            print(f"[TransNetV2] Using weights from {weights_path}.")

        # (height, width, channels) expected by the network.
        self._input_size = (27, 48, 3)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = TransNetV2()
        try:
            # BUGFIX: map_location lets GPU-saved checkpoints load on CPU-only hosts;
            # the original torch.load(weights_path) fails there with a CUDA error.
            self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
        except Exception as exc:
            raise IOError(f"[TransNetV2] Could not load weights from {weights_path}.") from exc
        self.model.eval()
        self.model.to(self.device)

    def predict_raw(self, frames: np.ndarray):
        """Run the network on a batch of frame windows.

        Args:
            frames: uint8 array of shape [batch, frames, 27, 48, 3].

        Returns:
            Tuple of sigmoid-activated numpy arrays
            (single_frame_pred, all_frames_pred), each [batch, frames, 1].
        """
        assert len(frames.shape) == 5 and frames.shape[2:] == self._input_size, \
            "[TransNetV2] Input shape must be [batch, frames, height, width, 3]."

        frames_tensor = torch.from_numpy(frames)
        with torch.no_grad():
            single_frame_pred, all_frames_pred = self.model(frames_tensor.to(self.device))

            single_frame_pred = torch.sigmoid(single_frame_pred).cpu().numpy()
            # The model returns the many-hot head in a dict; see TransNetV2.forward.
            all_frames_pred = torch.sigmoid(all_frames_pred["many_hot"]).cpu().numpy()

        return single_frame_pred, all_frames_pred

    def predict_frames(self, frames: np.ndarray):
        """Predict per-frame transition scores for a full video.

        Args:
            frames: uint8 array of shape [frames, 27, 48, 3].

        Returns:
            Tuple (single_frame_pred, all_frames_pred), each of length
            ``len(frames)``.
        """
        assert len(frames.shape) == 4 and frames.shape[1:] == self._input_size, \
            "[TransNetV2] Input shape must be [frames, height, width, 3]."

        total = len(frames)

        def input_iterator():
            # return windows of size 100 where the first/last 25 frames are from the previous/next batch
            # the first and last window must be padded by copies of the first and last frame of the video
            no_padded_frames_start = 25
            no_padded_frames_end = 25 + 50 - (total % 50 if total % 50 != 0 else 50)  # 25 - 74

            start_frame = np.expand_dims(frames[0], 0)
            end_frame = np.expand_dims(frames[-1], 0)
            padded_inputs = np.concatenate(
                [start_frame] * no_padded_frames_start + [frames] + [end_frame] * no_padded_frames_end, 0
            )

            ptr = 0
            while ptr + 100 <= len(padded_inputs):
                out = padded_inputs[ptr:ptr + 100]
                ptr += 50
                yield out[np.newaxis]

        predictions = []

        for inp in input_iterator():
            single_frame_pred, all_frames_pred = self.predict_raw(inp)
            # Keep only the central 50 frames of each 100-frame window; the
            # 25-frame margins are context overlapping neighbouring windows.
            predictions.append((single_frame_pred[0, 25:75, 0],
                                all_frames_pred[0, 25:75, 0]))

            print("\r[TransNetV2] Processing video frames {}/{}".format(
                min(len(predictions) * 50, total), total
            ), end="")
        print("")

        single_frame_pred = np.concatenate([single_ for single_, _ in predictions])
        all_frames_pred = np.concatenate([all_ for _, all_ in predictions])

        # Trim the end padding introduced by input_iterator().
        return single_frame_pred[:total], all_frames_pred[:total]

    def predict_video(self, video_fn: str):
        """Extract 48x27 RGB frames from a video with ffmpeg and predict on them.

        Returns:
            Tuple (video_frames, single_frame_pred, all_frames_pred).
        """
        print("[TransNetV2] Extracting frames from {}".format(video_fn))
        video_stream, _ = ffmpeg.input(video_fn).output(
            "pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27"
        ).run(capture_stdout=True, capture_stderr=True)

        video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3])
        return (video, *self.predict_frames(video))

    @staticmethod
    def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5):
        """Convert per-frame transition scores into [start, end] scene spans.

        Args:
            predictions: 1-D array of per-frame scores.
            threshold: scores above this mark a transition frame.

        Returns:
            int32 array of shape [num_scenes, 2] with inclusive frame indices.
        """
        predictions = (predictions > threshold).astype(np.uint8)

        scenes = []
        t_prev, start = 0, 0
        for i, t in enumerate(predictions):
            if t_prev == 1 and t == 0:
                start = i
            if t_prev == 0 and t == 1 and i != 0:
                scenes.append([start, i])
            t_prev = t
        if t == 0:
            scenes.append([start, i])
        if len(scenes) == 0:  # just fix if all predictions are 1
            return np.array([[0, len(predictions) - 1]], dtype=np.int32)

        return np.array(scenes, dtype=np.int32)

    @staticmethod
    def visualize_predictions(frames: np.ndarray, predictions):
        """Render the frames in a 25-wide grid with prediction bars drawn in
        the extra pixel columns to the right of each frame.

        Args:
            frames: uint8 array [frames, height, width, 3].
            predictions: one array or a tuple of arrays of per-frame scores.

        Returns:
            PIL.Image with the visualization.
        """
        if isinstance(predictions, np.ndarray):
            predictions = [predictions]

        ih, iw, ic = frames.shape[1:]
        width = 25  # frames per grid row

        # pad frames so that length of the video is divisible by width
        # pad frames also by len(predictions) pixels in width in order to show predictions
        pad_with = width - len(frames) % width if len(frames) % width != 0 else 0
        frames = np.pad(frames, [(0, pad_with), (0, 1), (0, len(predictions)), (0, 0)])

        predictions = [np.pad(x, (0, pad_with)) for x in predictions]
        height = len(frames) // width

        img = frames.reshape([height, width, ih + 1, iw + len(predictions), ic])
        img = np.concatenate(np.split(
            np.concatenate(np.split(img, height), axis=2)[0], width
        ), axis=2)[0, :-1]

        img = Image.fromarray(img)
        draw = ImageDraw.Draw(img)

        for i, pred in enumerate(zip(*predictions)):
            x, y = i % width, i // width
            x, y = x * (iw + len(predictions)) + iw, y * (ih + 1) + ih - 1

            # we can visualize multiple predictions per single frame
            for j, p in enumerate(pred):
                color = [0, 0, 0]
                color[(j + 1) % 3] = 255

                value = round(p * (ih - 1))
                if value != 0:
                    draw.line((x + j, y, x + j, y - value), fill=tuple(color), width=1)
        return img
158
+
159
def parse_args(argv=None):
    """Parse command-line arguments for the shot-boundary detection CLI.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``
            (backward compatible with the previous zero-argument call).

    Returns:
        argparse.Namespace with ``files``, ``weights`` and ``visualize``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--files", type=str, help="path to video files to process")
    parser.add_argument("--weights", type=str, default=None,
                        help="path to TransNet V2 weights, tries to infer the location if not specified")
    parser.add_argument('--visualize', action="store_true",
                        help="save a png file with prediction visualization for each extracted video")
    return parser.parse_args(argv)
169
+
170
def main(args):
    """Run shot-boundary detection on a single video file or on every ``.mp4``
    inside a folder, writing predictions, scene spans and an optional
    visualization next to each input file."""
    model = TransNetV2Torch(args.weights)

    if os.path.isdir(args.files):
        targets = [
            os.path.join(args.files, name)
            for name in os.listdir(args.files)
            if name.lower().endswith(".mp4")
        ]
    else:
        targets = [args.files]

    for path in targets:
        video_frames, single_frame_predictions, all_frames_predictions = \
            model.predict_video(path)

        # One row per frame: [single-frame score, many-hot score].
        predictions = np.stack([single_frame_predictions, all_frames_predictions], 1)
        np.savetxt(path + ".predictions.txt", predictions, fmt="%.6f")

        scenes = model.predictions_to_scenes(single_frame_predictions)
        np.savetxt(path + ".scenes.txt", scenes, fmt="%d")

        if args.visualize:
            image = model.visualize_predictions(
                video_frames, predictions=(single_frame_predictions, all_frames_predictions))
            image.save(path + ".vis.png")

if __name__ == "__main__":
    main(parse_args())
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==2.3.2
2
+ pillow==11.3.0
3
+ tqdm==4.67.1
4
+ torch==2.8.0
5
+ ffmpeg-python==0.2.0
transnetv2-pytorch-weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eed5336d5d6a013c67f5863505a26e7e835053e64a9ce413d6b089ccba07bb53
3
+ size 30509621
transnetv2_pytorch.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ import random
6
+
7
+
8
class TransNetV2(nn.Module):
    """TransNet V2 shot-boundary detection network (PyTorch port).

    Processes uint8 video clips of shape [B, T, 27, 48, 3] and returns a
    per-frame transition logit (plus an optional "many hot" head)."""

    def __init__(self,
                 F=16, L=3, S=2, D=1024,
                 use_many_hot_targets=True,
                 use_frame_similarity=True,
                 use_color_histograms=True,
                 use_mean_pooling=False,
                 dropout_rate=0.5,
                 use_convex_comb_reg=False,  # not supported
                 use_resnet_features=False,  # not supported
                 use_resnet_like_top=False,  # not supported
                 frame_similarity_on_last_layer=False):  # not supported
        super(TransNetV2, self).__init__()

        if use_resnet_features or use_resnet_like_top or use_convex_comb_reg or frame_similarity_on_last_layer:
            # BUGFIX: `NotImplemented` is a builtin constant (not an exception
            # class); raising it produced a confusing TypeError instead.
            raise NotImplementedError("Some options not implemented in Pytorch version of Transnet!")

        # Stack of L dilated-CNN blocks; each block quadruples the channel
        # count (4 dilation branches) and halves the spatial resolution.
        self.SDDCNN = nn.ModuleList(
            [StackedDDCNNV2(in_filters=3, n_blocks=S, filters=F, stochastic_depth_drop_prob=0.)] +
            [StackedDDCNNV2(in_filters=(F * 2 ** (i - 1)) * 4, n_blocks=S, filters=F * 2 ** i) for i in range(1, L)]
        )

        self.frame_sim_layer = FrameSimilarity(
            sum([(F * 2 ** i) * 4 for i in range(L)]), lookup_window=101, output_dim=128, similarity_dim=128, use_bias=True
        ) if use_frame_similarity else None
        self.color_hist_layer = ColorHistograms(
            lookup_window=101, output_dim=128
        ) if use_color_histograms else None

        self.dropout = nn.Dropout(dropout_rate) if dropout_rate is not None else None

        output_dim = ((F * 2 ** (L - 1)) * 4) * 3 * 6  # 3x6 for spatial dimensions
        if use_frame_similarity: output_dim += 128
        if use_color_histograms: output_dim += 128

        self.fc1 = nn.Linear(output_dim, D)
        self.cls_layer1 = nn.Linear(D, 1)
        self.cls_layer2 = nn.Linear(D, 1) if use_many_hot_targets else None

        self.use_mean_pooling = use_mean_pooling
        self.eval()

    def forward(self, inputs):
        """Compute per-frame transition logits.

        Args:
            inputs: uint8 tensor of shape [B, T, 27, 48, 3].

        Returns:
            ``one_hot`` logits [B, T, 1], or a tuple
            ``(one_hot, {"many_hot": logits})`` when the second head is enabled.
        """
        assert isinstance(inputs, torch.Tensor) and list(inputs.shape[2:]) == [27, 48, 3] and inputs.dtype == torch.uint8, \
            "incorrect input type and/or shape"
        # uint8 of shape [B, T, H, W, 3] to float of shape [B, 3, T, H, W]
        x = inputs.permute([0, 4, 1, 2, 3]).float()
        x = x.div_(255.)

        # Keep each block's output for the frame-similarity feature below.
        block_features = []
        for block in self.SDDCNN:
            x = block(x)
            block_features.append(x)

        if self.use_mean_pooling:
            x = torch.mean(x, dim=[3, 4])
            x = x.permute(0, 2, 1)
        else:
            x = x.permute(0, 2, 3, 4, 1)
            x = x.reshape(x.shape[0], x.shape[1], -1)

        if self.frame_sim_layer is not None:
            x = torch.cat([self.frame_sim_layer(block_features), x], 2)

        if self.color_hist_layer is not None:
            x = torch.cat([self.color_hist_layer(inputs), x], 2)

        x = self.fc1(x)
        x = functional.relu(x)

        if self.dropout is not None:
            x = self.dropout(x)

        one_hot = self.cls_layer1(x)

        if self.cls_layer2 is not None:
            return one_hot, {"many_hot": self.cls_layer2(x)}

        return one_hot
88
+
89
+
90
class StackedDDCNNV2(nn.Module):
    """A stack of DilatedDCNNV2 blocks with a residual connection over the
    stack followed by spatial 2x2 pooling."""

    def __init__(self,
                 in_filters,
                 n_blocks,
                 filters,
                 shortcut=True,
                 use_octave_conv=False,  # not supported
                 pool_type="avg",
                 stochastic_depth_drop_prob=0.0):
        super(StackedDDCNNV2, self).__init__()

        if use_octave_conv:
            # BUGFIX: `NotImplemented` is a constant, not an exception class;
            # raising it produced a TypeError.
            raise NotImplementedError("Octave convolution not implemented in Pytorch version of Transnet!")

        assert pool_type == "max" or pool_type == "avg"
        if use_octave_conv and pool_type == "max":
            print("WARN: Octave convolution was designed with average pooling, not max pooling.")

        self.shortcut = shortcut
        # Last block in the stack has no activation: ReLU is applied after
        # the residual addition in forward().
        self.DDCNN = nn.ModuleList([
            DilatedDCNNV2(in_filters if i == 1 else filters * 4, filters, octave_conv=use_octave_conv,
                          activation=functional.relu if i != n_blocks else None) for i in range(1, n_blocks + 1)
        ])
        self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2)) if pool_type == "max" else nn.AvgPool3d(kernel_size=(1, 2, 2))
        self.stochastic_depth_drop_prob = stochastic_depth_drop_prob

    def forward(self, inputs):
        x = inputs
        shortcut = None

        for block in self.DDCNN:
            x = block(x)
            # The residual branch is the output of the first block.
            if shortcut is None:
                shortcut = x

        x = functional.relu(x)

        # BUGFIX: `self.shortcut` is a boolean flag; the original
        # `if self.shortcut is not None:` was always true, so `shortcut=False`
        # was silently ignored.
        if self.shortcut:
            if self.stochastic_depth_drop_prob != 0.:
                if self.training:
                    # Stochastic depth: randomly drop the transform branch.
                    if random.random() < self.stochastic_depth_drop_prob:
                        x = shortcut
                    else:
                        x = x + shortcut
                else:
                    x = (1 - self.stochastic_depth_drop_prob) * x + shortcut
            else:
                x += shortcut

        x = self.pool(x)
        return x
142
+
143
+
144
class DilatedDCNNV2(nn.Module):
    """Four parallel 3D convolutions with temporal dilations 1/2/4/8 whose
    outputs are concatenated on the channel axis, followed by optional
    batch norm and activation."""

    def __init__(self,
                 in_filters,
                 filters,
                 batch_norm=True,
                 activation=None,
                 octave_conv=False):  # not supported
        super(DilatedDCNNV2, self).__init__()

        if octave_conv:
            # BUGFIX: `NotImplemented` is a constant, not an exception class;
            # raising it produced a TypeError.
            raise NotImplementedError("Octave convolution not implemented in Pytorch version of Transnet!")

        assert not (octave_conv and batch_norm)

        # Bias is redundant when batch norm follows the convolution.
        self.Conv3D_1 = Conv3DConfigurable(in_filters, filters, 1, use_bias=not batch_norm)
        self.Conv3D_2 = Conv3DConfigurable(in_filters, filters, 2, use_bias=not batch_norm)
        self.Conv3D_4 = Conv3DConfigurable(in_filters, filters, 4, use_bias=not batch_norm)
        self.Conv3D_8 = Conv3DConfigurable(in_filters, filters, 8, use_bias=not batch_norm)

        self.bn = nn.BatchNorm3d(filters * 4, eps=1e-3) if batch_norm else None
        self.activation = activation

    def forward(self, inputs):
        conv1 = self.Conv3D_1(inputs)
        conv2 = self.Conv3D_2(inputs)
        conv3 = self.Conv3D_4(inputs)
        conv4 = self.Conv3D_8(inputs)

        x = torch.cat([conv1, conv2, conv3, conv4], dim=1)

        if self.bn is not None:
            x = self.bn(x)

        if self.activation is not None:
            x = self.activation(x)

        return x
182
+
183
+
184
class Conv3DConfigurable(nn.Module):
    """3D convolution with configurable temporal dilation, optionally
    factored into a spatial 1x3x3 conv followed by a temporal 3x1x1 conv
    ((2+1)D decomposition)."""

    def __init__(self,
                 in_filters,
                 filters,
                 dilation_rate,
                 separable=True,
                 octave=False,  # not supported
                 use_bias=True,
                 kernel_initializer=None):  # not supported
        super(Conv3DConfigurable, self).__init__()

        # BUGFIX (both raises): `NotImplemented` is a constant, not an
        # exception class; raising it produced a TypeError.
        if octave:
            raise NotImplementedError("Octave convolution not implemented in Pytorch version of Transnet!")
        if kernel_initializer is not None:
            raise NotImplementedError("Kernel initializers are not implemented in Pytorch version of Transnet!")

        assert not (separable and octave)

        if separable:
            # (2+1)D convolution https://arxiv.org/pdf/1711.11248.pdf
            conv1 = nn.Conv3d(in_filters, 2 * filters, kernel_size=(1, 3, 3),
                              dilation=(1, 1, 1), padding=(0, 1, 1), bias=False)
            conv2 = nn.Conv3d(2 * filters, filters, kernel_size=(3, 1, 1),
                              dilation=(dilation_rate, 1, 1), padding=(dilation_rate, 0, 0), bias=use_bias)
            self.layers = nn.ModuleList([conv1, conv2])
        else:
            conv = nn.Conv3d(in_filters, filters, kernel_size=3,
                             dilation=(dilation_rate, 1, 1), padding=(dilation_rate, 1, 1), bias=use_bias)
            self.layers = nn.ModuleList([conv])

    def forward(self, inputs):
        x = inputs
        for layer in self.layers:
            x = layer(x)
        return x
220
+
221
+
222
class FrameSimilarity(nn.Module):
    """Computes, for every frame, cosine similarities to frames inside a
    temporal lookup window and projects them to a fixed-size feature."""

    def __init__(self,
                 in_filters,
                 similarity_dim=128,
                 lookup_window=101,
                 output_dim=128,
                 stop_gradient=False,  # not supported
                 use_bias=False):
        super(FrameSimilarity, self).__init__()

        if stop_gradient:
            # BUGFIX: `NotImplemented` is a constant, not an exception class;
            # raising it produced a TypeError.
            raise NotImplementedError("Stop gradient not implemented in Pytorch version of Transnet!")

        self.projection = nn.Linear(in_filters, similarity_dim, bias=use_bias)
        self.fc = nn.Linear(lookup_window, output_dim)

        self.lookup_window = lookup_window
        assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    def forward(self, inputs):
        """``inputs`` is a list of [B, C, T, H, W] feature maps; their
        spatially-averaged channels are concatenated per frame."""
        x = torch.cat([torch.mean(x, dim=[3, 4]) for x in inputs], dim=1)
        x = torch.transpose(x, 1, 2)

        x = self.projection(x)
        # L2-normalize so that the dot products below are cosine similarities.
        x = functional.normalize(x, p=2, dim=2)

        batch_size, time_window = x.shape[0], x.shape[1]
        similarities = torch.bmm(x, x.transpose(1, 2))  # [batch_size, time_window, time_window]
        similarities_padded = functional.pad(similarities, [(self.lookup_window - 1) // 2, (self.lookup_window - 1) // 2])

        # Index tensors gathering, for each frame, its lookup_window-sized
        # neighbourhood from the padded similarity matrix.
        batch_indices = torch.arange(0, batch_size, device=x.device).view([batch_size, 1, 1]).repeat(
            [1, time_window, self.lookup_window])
        time_indices = torch.arange(0, time_window, device=x.device).view([1, time_window, 1]).repeat(
            [batch_size, 1, self.lookup_window])
        lookup_indices = torch.arange(0, self.lookup_window, device=x.device).view([1, 1, self.lookup_window]).repeat(
            [batch_size, time_window, 1]) + time_indices

        similarities = similarities_padded[batch_indices, time_indices, lookup_indices]
        return functional.relu(self.fc(similarities))
262
+
263
+
264
class ColorHistograms(nn.Module):
    """Computes 512-bin RGB color histograms per frame and, for every frame,
    similarities of its histogram to frames in a temporal lookup window."""

    def __init__(self,
                 lookup_window=101,
                 output_dim=None):
        super(ColorHistograms, self).__init__()

        self.fc = nn.Linear(lookup_window, output_dim) if output_dim is not None else None
        self.lookup_window = lookup_window
        assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    @staticmethod
    def compute_color_histograms(frames):
        """Return L2-normalized 512-bin histograms, shape [B, T, 512], for
        uint8 frames of shape [B, T, H, W, 3]."""
        frames = frames.int()

        def get_bin(frames):
            # returns 0 .. 511
            R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2]
            # 3 bits per channel -> 9-bit bin index.
            R, G, B = R >> 5, G >> 5, B >> 5
            return (R << 6) + (G << 3) + B

        batch_size, time_window, height, width, no_channels = frames.shape
        assert no_channels == 3
        frames_flatten = frames.view(batch_size * time_window, height * width, 3)

        binned_values = get_bin(frames_flatten)
        # Offset each frame's bins into its own 512-slot segment.
        frame_bin_prefix = (torch.arange(0, batch_size * time_window, device=frames.device) << 9).view(-1, 1)
        # BUGFIX: scatter_add_ requires int64 indices; the int32 tensor
        # produced by the arithmetic above fails at runtime.
        binned_values = (binned_values + frame_bin_prefix).view(-1).long()

        histograms = torch.zeros(batch_size * time_window * 512, dtype=torch.int32, device=frames.device)
        histograms.scatter_add_(0, binned_values, torch.ones(len(binned_values), dtype=torch.int32, device=frames.device))

        histograms = histograms.view(batch_size, time_window, 512).float()
        histograms_normalized = functional.normalize(histograms, p=2, dim=2)
        return histograms_normalized

    def forward(self, inputs):
        x = self.compute_color_histograms(inputs)

        batch_size, time_window = x.shape[0], x.shape[1]
        similarities = torch.bmm(x, x.transpose(1, 2))  # [batch_size, time_window, time_window]
        similarities_padded = functional.pad(similarities, [(self.lookup_window - 1) // 2, (self.lookup_window - 1) // 2])

        # Index tensors gathering, for each frame, its lookup_window-sized
        # neighbourhood from the padded similarity matrix.
        batch_indices = torch.arange(0, batch_size, device=x.device).view([batch_size, 1, 1]).repeat(
            [1, time_window, self.lookup_window])
        time_indices = torch.arange(0, time_window, device=x.device).view([1, time_window, 1]).repeat(
            [batch_size, 1, self.lookup_window])
        lookup_indices = torch.arange(0, self.lookup_window, device=x.device).view([1, 1, self.lookup_window]).repeat(
            [batch_size, time_window, 1]) + time_indices

        similarities = similarities_padded[batch_indices, time_indices, lookup_indices]

        if self.fc is not None:
            return functional.relu(self.fc(similarities))
        return similarities