Exgc commited on
Commit
9a6ee98
·
1 Parent(s): 285b2a6
Files changed (6) hide show
  1. app.py +172 -0
  2. exp/checkpoints/best_model.pt +1 -0
  3. exp/train-args.json +1 -0
  4. omnisep.py +752 -0
  5. requirements.txt +8 -0
  6. utils.py +348 -0
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import librosa
5
+ import pathlib
6
+ import scipy.io.wavfile
7
+ import os
8
+
9
+ from imagebind import data
10
+ from imagebind.models import imagebind_model
11
+ from imagebind.models.imagebind_model import ModalityType
12
+ import torch.nn.functional as F
13
+
14
+ import omnisep
15
+ import utils
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ # ========== Configuration & Model Loading ==========
19
+
20
def setup_models(checkpoint_path, train_args_path):
    """Construct the OmniSep separator and the ImageBind query encoder.

    Args:
        checkpoint_path: Path to the trained OmniSep state dict
            (saved from a ``DataParallel``-wrapped model).
        train_args_path: Path to the JSON file of training arguments.

    Returns:
        Tuple ``(separator, encoder, cfg)`` where both models are in eval
        mode on the global ``device`` and ``cfg`` is the loaded config dict.
    """
    cfg = utils.load_json(train_args_path)

    # Rebuild the separator with the exact hyper-parameters it was trained with.
    separator = omnisep.OmniSep(
        cfg['n_mix'],
        cfg['layers'],
        cfg['channels'],
        use_log_freq=cfg['log_freq'],
        use_weighted_loss=cfg['weighted_loss'],
        use_binary_mask=cfg['binary_mask'],
        emb_dim=cfg.get('emb_dim', 512),
    )
    # Checkpoint keys are prefixed with "module.", so wrap before loading.
    separator = torch.nn.DataParallel(separator)
    separator.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    separator.to(device)
    separator.eval()

    # ImageBind provides the text/image/audio query embeddings.
    encoder = imagebind_model.imagebind_huge(pretrained=True)
    encoder = torch.nn.DataParallel(encoder)
    encoder.to(device)
    encoder.eval()

    return separator, encoder, cfg
41
+
42
+ # ========== Audio Loading & Preprocessing ==========
43
+
44
def load_audio_and_spec(audio_file, audio_len, sample_rate, n_fft, hop_len, win_len):
    """Load a mono waveform, force it to ``audio_len`` samples, and compute its STFT.

    Short clips are looped until long enough; long clips are truncated.

    Returns:
        Tuple ``(mag, phase, n_samples)`` where ``mag`` and ``phase`` are
        ``(1, 1, F, T)`` float tensors and ``n_samples`` equals ``audio_len``.
    """
    wav, _ = librosa.load(audio_file, sr=sample_rate, mono=True)

    if len(wav) < audio_len:
        # Tile the clip enough times to cover the target length, then crop.
        repeats = audio_len // len(wav) + 1
        wav = np.tile(wav, (repeats,))[:audio_len]
    else:
        wav = wav[:audio_len]
    wav = np.clip(wav, -1, 1)

    stft = librosa.stft(wav, n_fft=n_fft, hop_length=hop_len, win_length=win_len)
    # Add batch and channel dimensions: (F, T) -> (1, 1, F, T).
    mag = torch.tensor(np.abs(stft))[None, None]
    phase = torch.tensor(np.angle(stft))[None, None]
    return mag, phase, wav.shape[0]
56
+
57
+ # ========== Embedding Construction ==========
58
+
59
def get_combined_embedding(imagebind_net, text=None, image=None, audio=None,
                           text_w=1.0, image_w=1.0, audio_w=1.0):
    """Fuse ImageBind embeddings from up to three query modalities.

    Each provided (truthy) query is embedded, the embeddings are combined as
    a weighted average, and the result is L2-normalized.

    Returns:
        The normalized combined embedding, or ``None`` if no query was given.
    """
    # (query, modality key, loader, weight) for each supported modality.
    queries = [
        (text, ModalityType.TEXT, data.load_and_transform_text, text_w),
        (image, ModalityType.VISION, data.load_and_transform_vision_data, image_w),
        (audio, ModalityType.AUDIO, data.load_and_transform_audio_data, audio_w),
    ]

    inputs = {mod: loader([q], device) for q, mod, loader, _ in queries if q}
    emb = imagebind_net(inputs)

    combined = None
    total_weight = 0
    for q, mod, _, w in queries:
        if not q:
            continue
        term = w * emb[mod]
        combined = term if combined is None else combined + term
        total_weight += w

    if total_weight > 0:
        combined = F.normalize(combined / total_weight)

    return combined
82
+
83
+ # ========== Waveform Recovery ==========
84
+
85
def recover_waveform(mag_mix, phase_mix, pred_mask, args):
    """Convert a predicted spectrogram mask back into a time-domain waveform.

    Args:
        mag_mix: Mixture magnitude spectrogram, shape (B, 1, F, T).
        phase_mix: Mixture phase spectrogram, shape (B, 1, F, T).
        pred_mask: Predicted separation mask from the model (B, 1, F', T);
            on the log-frequency grid when ``args['log_freq']`` is true.
        args: Config dict providing 'log_freq', 'n_fft', 'hop_len', 'win_len'.

    Returns:
        The separated waveform for the first batch item as a 1-D numpy array.
    """
    B = mag_mix.size(0)
    if args['log_freq']:
        # Unwarp the mask from the log-frequency grid back to linear bins.
        grid_unwarp = torch.from_numpy(
            utils.warpgrid(B, args['n_fft'] // 2 + 1, pred_mask.size(3), warp=False)
        ).to(pred_mask.device)
        pred_mask_linear = F.grid_sample(pred_mask, grid_unwarp, align_corners=True)
    else:
        # BUG FIX: the original used `pred_mask[0]`, which dropped the batch
        # dimension and made the `[0, 0]` indexing below pick the wrong axis.
        # Keep the tensor 4-D in both branches.
        pred_mask_linear = pred_mask

    # Move everything to NumPy on the CPU.
    mag_mix = mag_mix.detach().cpu().numpy()
    phase_mix = phase_mix.detach().cpu().numpy()
    pred_mask_linear = pred_mask_linear.detach().cpu().numpy()

    # Binarize the mask. (The original also thresholded the warped mask,
    # but that value was never used afterwards, so it is dropped here.)
    pred_mask_linear = (pred_mask_linear > 0.5).astype(np.float32)

    # Reconstruct the predicted source using the mixture's phase.
    pred_mag = mag_mix[0, 0] * pred_mask_linear[0, 0]
    pred_wav = utils.istft_reconstruction(
        pred_mag,
        phase_mix[0, 0],
        hop_len=args['hop_len'],
        win_len=args['win_len'],
    )

    return pred_wav
117
+
118
+ # ========== Gradio Interface ==========
119
+
120
def run_inference(input_audio, text_pos, audio_pos, image_pos, text_neg, audio_neg, image_neg,
                  text_w, image_w, audio_w, neg_w):
    """Run one OmniSep separation pass for the Gradio UI.

    Args:
        input_audio: Path to the mixed input audio file.
        text_pos / audio_pos / image_pos: Optional positive queries.
        text_neg / audio_neg / image_neg: Optional negative queries.
        text_w / image_w / audio_w: Positive-query modality weights.
        neg_w: Strength of the negative-query extrapolation.

    Returns:
        Path to the written separated WAV file.
    """
    # Cache the heavy models on the function object so repeated clicks do not
    # reload the checkpoint and ImageBind weights from disk every time.
    if not hasattr(run_inference, "_cache"):
        # NOTE(review): the commit adds the config at exp/train-args.json, but
        # this path points at exp/checkpoints/ — confirm which location is real.
        run_inference._cache = setup_models(
            "./exp/checkpoints/best_model.pt", "./exp/checkpoints/train-args.json"
        )
    model, imagebind_net, args = run_inference._cache

    # Use the training clip length from the config instead of a magic number.
    audio_len = args.get('audio_len', 65535)
    mag_mix, phase_mix, out_len = load_audio_and_spec(
        input_audio, audio_len,
        args['audio_rate'], args['n_fft'], args['hop_len'], args['win_len'])

    # Positive query embedding, optionally pushed away from a negative one.
    img_emb = get_combined_embedding(imagebind_net, text_pos, image_pos, audio_pos,
                                     text_w, image_w, audio_w)
    if any([text_neg, audio_neg, image_neg]):
        neg_emb = get_combined_embedding(imagebind_net, text_neg, image_neg, audio_neg,
                                         1.0, 1.0, 1.0)
        img_emb = (1 + neg_w) * img_emb - neg_w * neg_emb

    mag_mix = mag_mix.to(device)
    phase_mix = phase_mix.to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred_mask = model.module.infer(mag_mix, [img_emb])[0]
    pred_wav = recover_waveform(mag_mix, phase_mix, pred_mask, args)

    out_path = "/tmp/output.wav"
    scipy.io.wavfile.write(out_path, args['audio_rate'], pred_wav[:out_len])
    return out_path
140
+
141
# ---- Gradio front end -------------------------------------------------------
# Top-level script: builds the Blocks layout and wires the run button to
# `run_inference`. This executes at import time, so importing this module
# launches a publicly shared Gradio server (share=True).
with gr.Blocks(title="OmniSep UI") as iface:
    gr.Markdown("## 🎧 Upload Your Mixed Audio")
    # The mixture to separate; passed to run_inference as a file path.
    mixed_audio = gr.Audio(type="filepath", label="Mixed Input Audio")

    gr.Markdown("### ✅ Positive Query")
    # Describes the target source; any subset of the three modalities may be set.
    with gr.Row():
        pos_text = gr.Textbox(label="Text Query", placeholder="e.g. dog barking")
        pos_audio = gr.Audio(type="filepath", label="Audio Query")
        pos_image = gr.Image(type="filepath", label="Image Query")

    gr.Markdown("### ❌ Negative Query (Optional)")
    # Describes what to suppress; used only if at least one field is filled.
    with gr.Row():
        neg_text = gr.Textbox(label="Negative Text Query")
        neg_audio = gr.Audio(type="filepath", label="Negative Audio Query")
        neg_image = gr.Image(type="filepath", label="Negative Image Query")

    gr.Markdown("### 🎚️ Modality Weights")
    # Per-modality weights for the positive query, plus the strength of the
    # negative-embedding extrapolation used in run_inference.
    with gr.Row():
        text_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Text Weight")
        image_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Image Weight")
        audio_weight = gr.Slider(0, 5, value=1.0, step=0.1, label="Audio Weight")
        neg_weight = gr.Slider(0, 2, value=0.5, step=0.1, label="Negative Embedding Weight")

    # run_inference returns a file path, which this component plays back.
    output_audio = gr.Audio(type="filepath", label="Separated Output Audio")

    btn = gr.Button("Run OmniSep Inference")
    btn.click(fn=run_inference,
              inputs=[mixed_audio, pos_text, pos_audio, pos_image, neg_text, neg_audio, neg_image,
                      text_weight, image_weight, audio_weight, neg_weight],
              outputs=output_audio)

iface.launch(share=True)
exp/checkpoints/best_model.pt ADDED
@@ -0,0 +1 @@
 
 
1
+ /root/autodl-tmp/data/OmniSep/best_model.pt
exp/train-args.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"out_dir": "/root/autodl-tmp/OmniSep/omnisep/exp/vggsound/omnisep", "train_list": ["data/vggsound/test.csv"], "val_list": ["data/vggsound/test.csv"], "n_validation": null, "weights": null, "batch_size": 32, "drop_closest": null, "drop_closest_steps": 10000, "repeat": null, "frame_margin": null, "audio_only": false, "audio_len": 65535, "emb_dim": 1024, "audio_rate": 16000, "n_fft": 1024, "hop_len": 256, "win_len": 1024, "img_size": 224, "fps": 1, "train_mode": ["image", "text", "audio"], "n_mix": 2, "channels": 32, "layers": 7, "frames": 3, "stride_frames": 1, "binary_mask": true, "loss": "bce", "weighted_loss": true, "log_freq": true, "n_labels": null, "steps": 500000, "valid_steps": 10000, "lr": 0.001, "lr_warmup_steps": 5000, "lr_decay_steps": 100000, "lr_decay_multiplier": 0.1, "grad_norm_clip": 1.0, "pit_warmup_steps": 0, "seed": 1234, "gpus": 1, "workers": 20, "quiet": false, "is_feature": true, "is_neg": false, "feature_mode": "imagebind"}
omnisep.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Define the models."""
2
+ import functools
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import utils
9
+
10
+
11
def init_weights(net):
    """Initialize one layer's parameters; intended for use via ``module.apply``.

    Conv and Linear weights get small zero-mean Gaussians; BatchNorm weights
    are drawn around 1 and its biases are zeroed. Other layers are untouched.
    """
    layer_name = type(net).__name__
    if "Conv" in layer_name:
        net.weight.data.normal_(0.0, 0.001)
    elif "BatchNorm" in layer_name:
        net.weight.data.normal_(1.0, 0.02)
        net.bias.data.fill_(0)
    elif "Linear" in layer_name:
        net.weight.data.normal_(0.0, 0.0001)
20
+
21
+
22
class OmniSep(torch.nn.Module):
    """Query-conditioned audio separation model.

    A U-Net (``sound_net``) encodes the warped log-magnitude mixture
    spectrogram; a linear layer (``frame_net``) maps a query embedding
    (e.g. from ImageBind) to channel gates; ``synth_net`` combines the two
    into a per-source time-frequency mask.
    """

    def __init__(
        self,
        n_mix,
        layers=7,
        channels=32,
        use_log_freq=True,
        use_weighted_loss=True,
        use_binary_mask=True,
        emb_dim=512
    ):
        # n_mix: number of sources mixed together during training.
        # layers: U-Net depth; channels: U-Net output feature channels.
        # use_log_freq: warp spectrograms to a 256-bin log-frequency grid.
        # use_weighted_loss: weight BCE by mixture magnitude.
        # use_binary_mask: sigmoid-activated binary masks vs. ratio masks.
        # emb_dim: dimensionality of the incoming query embedding.
        super().__init__()
        self.n_mix = n_mix
        self.use_log_freq = use_log_freq
        self.use_weighted_loss = use_weighted_loss
        self.use_binary_mask = use_binary_mask

        # Create the neural net
        self.sound_net = UNet(in_dim=1, out_dim=channels, num_downs=layers)
        self.frame_net = nn.Linear(emb_dim, channels)
        self.synth_net = InnerProd(fc_dim=channels)

        # Initialize the weights
        self.sound_net.apply(init_weights)
        self.frame_net.apply(init_weights)
        self.synth_net.apply(init_weights)

    def forward(self, batch, img_emb, drop_closest=None):
        """Training pass: returns ``(loss, outputs)`` for a batch of mixtures.

        ``batch`` must contain ``mag_mix`` (B, 1, F, T) and ``mags`` (a list
        of N per-source magnitudes); ``img_emb`` is a list of N query
        embeddings. ``drop_closest`` > 0 drops that many most-similar query
        pairs from the batch; ``drop_closest == -1`` instead down-weights
        similar pairs in the loss.
        """
        N = self.n_mix
        mag_mix = batch["mag_mix"]
        mags = batch["mags"]

        # Pass through the frame net -> Bx1xC
        feat_frames_pre = [self.frame_net(img_emb[n]) for n in range(N)]
        feat_frames = [torch.sigmoid(feat) for feat in feat_frames_pre]

        # Compute similarities between the two query embeddings (only
        # meaningful for N == 2; reused by the drop_closest == -1 branch).
        if drop_closest is not None:
            assert N == 2, "N must be 2 when `drop_closest` is enabled."
            similarities = F.cosine_similarity(
                img_emb[0].detach(), img_emb[1].detach()
            )

        # Drop most similar pairs
        if drop_closest is not None and drop_closest > 0:
            # Sort the similarities
            sorted_indices = torch.argsort(similarities)

            # Keep only those with low similarities
            mag_mix = mag_mix[sorted_indices[:-drop_closest]]
            for n in range(N):
                mags[n] = mags[n][sorted_indices[:-drop_closest]]
                feat_frames[n] = feat_frames[n][sorted_indices[:-drop_closest]]
        # Avoid log(0) / division by zero downstream.
        mag_mix = mag_mix + 1e-10

        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # Warp the spectrogram onto a 256-bin log-frequency grid.
        if self.use_log_freq:
            grid_warp = torch.from_numpy(
                utils.warpgrid(B, 256, T, warp=True)
            )
            grid_warp = grid_warp.to(mag_mix.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp, align_corners=True)
        # Calculate loss weighting coefficient (magnitude of input mixture)
        if self.use_weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # Down-weight similar pairs instead of dropping them.
        if drop_closest is not None and drop_closest == -1:
            # Desired weight as a function of similarity:
            #     sim    -1 <-> 0.5 <---------------> 1
            #     weight  1    1     2 x (1 - sim)    0
            w = F.relu(1 - 2 * F.relu(similarities - 0.5))
            weight *= w.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        # Compute ground truth masks after warping!
        gt_masks = [None] * N
        for n in range(N):
            if self.use_binary_mask:
                # A bin belongs to source n if it dominates the mixture.
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                # NOTE(review): `sum(mags[n])` sums over the first axis of a
                # single tensor; a ratio mask over sources would be
                # `sum(mags)` -- confirm before relying on this path.
                gt_masks[n] = mags[n] / sum(mags[n])
                gt_masks[n].clamp_(0.0, 1.0)

        # Compute log magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # Pass through the sound net -> BxCxHxW
        feat_sound = self.sound_net(log_mag_mix)
        # Pass through the synth net
        pred_masks = [
            self.synth_net(feat_frames[n], feat_sound) for n in range(N)
        ]

        # Activate with Sigmoid function if using binary mask
        if self.use_binary_mask:
            pred_masks = [torch.sigmoid(mask) for mask in pred_masks]

        # Mean weighted BCE over all N sources.
        loss = torch.mean(
            torch.stack(
                [
                    F.binary_cross_entropy(pred_masks[n], gt_masks[n], weight)
                    for n in range(N)
                ]
            )
        )
        return (
            loss,
            {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
                "mag_mix": mag_mix,
                "mags": mags,
                "weight": weight,
            },
        )

    def infer(self, mag_mix, img_emb, n_mix=1):
        """Inference pass: predict masks for ``n_mix`` query embeddings.

        ``mag_mix`` is the (B, 1, F, T) mixture magnitude; ``img_emb`` is a
        list of at least ``n_mix`` query embeddings. Returns a list of
        predicted masks (on the log-frequency grid if ``use_log_freq``).
        """
        N = n_mix

        # Pass through the frame net -> Bx1xC
        feat_frames_pre = [self.frame_net(img_emb[n]) for n in range(N)]
        feat_frames = [torch.sigmoid(feat) for feat in feat_frames_pre]

        mag_mix = mag_mix + 1e-10

        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # Warp the spectrogram
        if self.use_log_freq:
            grid_warp = torch.from_numpy(
                utils.warpgrid(B, 256, T, warp=True)
            ).to(mag_mix.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)

        # Compute log magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # Pass through the sound net -> BxCxHxW
        feat_sound = self.sound_net(log_mag_mix)

        # Pass through the synth net
        pred_masks = [
            self.synth_net(feat_frames[n], feat_sound) for n in range(N)
        ]

        # Activate with Sigmoid function if using binary mask
        if self.use_binary_mask:
            pred_masks = [torch.sigmoid(mask) for mask in pred_masks]

        return pred_masks

    def infer2(self, batch, img_emb):
        """Evaluation pass: predict a single mask but also return the warped
        ground-truth masks and loss weights from ``batch`` for scoring."""
        N = self.n_mix
        mag_mix = batch["mag_mix"]
        mags = batch["mags"]

        # Pass through the frame net -> Bx1xC (first query only).
        feat_frames_pre = [self.frame_net(img_emb[0])]
        feat_frames = [torch.sigmoid(feat) for feat in feat_frames_pre]

        mag_mix = mag_mix + 1e-10

        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # Warp the spectrogram
        if self.use_log_freq:
            grid_warp = torch.from_numpy(
                utils.warpgrid(B, 256, T, warp=True)
            ).to(mag_mix.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)
            for n in range(N):
                mags[n] = F.grid_sample(mags[n], grid_warp, align_corners=True)

        # Calculate loss weighting coefficient (magnitude of input mixture)
        if self.use_weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # Compute ground truth masks after warping!
        gt_masks = [None] * N
        for n in range(N):
            if self.use_binary_mask:
                gt_masks[n] = (mags[n] > 0.5 * mag_mix).float()
            else:
                # NOTE(review): same `sum(mags[n])` concern as in forward().
                gt_masks[n] = mags[n] / sum(mags[n])
                gt_masks[n].clamp_(0.0, 1.0)

        # Compute log magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # Pass through the sound net -> BxCxHxW
        feat_sound = self.sound_net(log_mag_mix)

        # Pass through the synth net
        pred_masks = [self.synth_net(feat_frames[0], feat_sound)]

        # Activate with Sigmoid function if using binary mask
        if self.use_binary_mask:
            pred_masks = [torch.sigmoid(pred_masks[0])]

        return {
            "pred_masks": pred_masks,
            "gt_masks": gt_masks,
            "mag_mix": mag_mix,
            "mags": mags,
            "weight": weight,
        }

    def infer3(self, batch, img_emb):
        """Query-only inference: no ground truth needed; also reports the
        mean sigmoid activation of the predicted mask."""
        mag_mix = batch["mag_mix"]

        # Pass through the frame net -> Bx1xC (single query embedding).
        feat_frames_pre = [self.frame_net(img_emb)]
        feat_frames = [torch.sigmoid(feat) for feat in feat_frames_pre]

        mag_mix = mag_mix + 1e-10

        B = mag_mix.size(0)
        T = mag_mix.size(3)

        # Warp the spectrogram
        if self.use_log_freq:
            grid_warp = torch.from_numpy(
                utils.warpgrid(B, 256, T, warp=True)
            ).to(mag_mix.device)
            mag_mix = F.grid_sample(mag_mix, grid_warp, align_corners=True)

        # Calculate loss weighting coefficient (magnitude of input mixture)
        if self.use_weighted_loss:
            weight = torch.log1p(mag_mix)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = torch.ones_like(mag_mix)

        # Compute log magnitude
        log_mag_mix = torch.log(mag_mix).detach()

        # Pass through the sound net -> BxCxHxW
        feat_sound = self.sound_net(log_mag_mix)

        # Pass through the synth net
        pred_masks = [self.synth_net(feat_frames[0], feat_sound)]

        # Get the input to the PIT stream
        # mean_feat_frames_pre = feat_frames_pre[0]
        # feat_pit_pre = [net(mean_feat_frames_pre) for net in self.pit_nets]
        # feat_pit = [torch.sigmoid(feat) for feat in feat_pit_pre]

        # Pass through the synth net for the PIT stream
        # pit_masks = [self.synth_net(feat, feat_sound) for feat in feat_pit]

        # Mean activation
        mean_act = torch.mean(torch.sigmoid(pred_masks[0]))
        # mean_pit_act = torch.mean(
        #     torch.sigmoid(pit_masks[0]) + torch.sigmoid(pit_masks[1])
        # )

        return {
            "pred_masks": pred_masks,
            # "pit_masks": pit_masks,
            "mag_mix": mag_mix,
            "weight": weight,
            "mean_act": mean_act,
            # "mean_pit_act": mean_pit_act,
        }
303
+
304
+
305
class ResnetDilated(nn.Module):
    """Wrap a ResNet backbone, replacing late-stage strides with dilation.

    With ``dilate_scale=8`` both layer3 and layer4 are converted; with 16
    only layer4 is. The classifier head (last two children) is dropped.
    """

    def __init__(self, orig_resnet, pool_type="maxpool", dilate_scale=16):
        super().__init__()
        self.pool_type = pool_type

        dilate_plan = {
            8: ((orig_resnet.layer3, 2), (orig_resnet.layer4, 4)),
            16: ((orig_resnet.layer4, 2),),
        }
        for layer, dilation in dilate_plan.get(dilate_scale, ()):
            layer.apply(functools.partial(self._nostride_dilate, dilate=dilation))

        # Keep everything except the final pooling/classifier modules.
        self.features = nn.Sequential(*list(orig_resnet.children())[:-2])

    def _nostride_dilate(self, m, dilate):
        """Turn a strided conv into a stride-1 dilated conv in place."""
        if "Conv" not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            # Strided conv: remove the stride, halve the dilation.
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            # Plain 3x3 conv: apply the full dilation.
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x, pool=True):
        """Extract features; optionally global-pool to a (B, C) tensor."""
        x = self.features(x)
        if not pool:
            return x
        if self.pool_type == "avgpool":
            x = F.adaptive_avg_pool2d(x, 1)
        elif self.pool_type == "maxpool":
            x = F.adaptive_max_pool2d(x, 1)
        return x.view(x.size(0), x.size(1))
353
+
354
+
355
class UNetBlock(nn.Module):
    """A U-Net block that defines the submodule with skip connection.

    X ---------------------identity-------------------- X
      |-- downsampling --| submodule |-- upsampling --|

    Blocks are built inside-out: the innermost block has no submodule,
    middle blocks wrap the previously built block, and the outermost block
    takes the raw input channels and emits the final features without a
    skip concatenation.
    """

    def __init__(
        self,
        outer_nc,
        inner_input_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        use_dropout=False,
        inner_output_nc=None,
        noskip=False,
    ):
        # outer_nc: channels this block outputs (before skip concat).
        # inner_input_nc: channels produced by this block's downsampling conv.
        # input_nc: channels entering the block (defaults to outer_nc).
        # inner_output_nc: channels returned by the submodule; defaults to
        #   2*inner_input_nc because inner blocks concat their skip input.
        super().__init__()
        self.outermost = outermost
        self.noskip = noskip
        use_bias = False
        if input_nc is None:
            input_nc = outer_nc
        if innermost:
            # No submodule below, so no skip concat to double the channels.
            inner_output_nc = inner_input_nc
        elif inner_output_nc is None:
            inner_output_nc = 2 * inner_input_nc

        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(inner_input_nc)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(outer_nc)
        # Upsampling is bilinear resize + 3x3 conv (not transposed conv).
        upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

        if outermost:
            # Outermost: no pre-activation on the way down, no norm on the
            # way up; output is raw conv features.
            downconv = nn.Conv2d(
                input_nc,
                inner_input_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            upconv = nn.Conv2d(
                inner_output_nc, outer_nc, kernel_size=3, padding=1
            )

            down = [downconv]
            up = [uprelu, upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            # Innermost: bottleneck with no submodule in between.
            downconv = nn.Conv2d(
                input_nc,
                inner_input_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            upconv = nn.Conv2d(
                inner_output_nc,
                outer_nc,
                kernel_size=3,
                padding=1,
                bias=use_bias,
            )

            down = [downrelu, downconv]
            up = [uprelu, upsample, upconv, upnorm]
            model = down + up
        else:
            # Middle block: norm on both paths, wraps the given submodule.
            downconv = nn.Conv2d(
                input_nc,
                inner_input_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            upconv = nn.Conv2d(
                inner_output_nc,
                outer_nc,
                kernel_size=3,
                padding=1,
                bias=use_bias,
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upsample, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        # Outermost (or noskip) blocks return features directly; all others
        # concatenate the input as a skip connection along channels.
        if self.outermost or self.noskip:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
461
+
462
+
463
class UNet(nn.Module):
    """A U-Net built by nesting :class:`UNetBlock` modules inside-out."""

    def __init__(
        self,
        in_dim=1,
        out_dim=64,
        num_downs=5,
        ngf=64,
        use_dropout=False,
    ):
        """Args:
            in_dim: Input channels.
            out_dim: Output feature channels.
            num_downs: Number of downsampling levels (>= 5).
            ngf: Base channel width.
            use_dropout: Enable dropout in the extra deep blocks.
        """
        super().__init__()

        # Innermost bottleneck block.
        block = UNetBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, innermost=True
        )
        # Extra constant-width levels for nets deeper than 5.
        for _ in range(num_downs - 5):
            block = UNetBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=block,
                use_dropout=use_dropout,
            )
        # Progressively widen toward the outside.
        for outer, inner in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            block = UNetBlock(outer, inner, input_nc=None, submodule=block)
        # Outermost block maps raw input channels to the output features.
        block = UNetBlock(
            out_dim,
            ngf,
            input_nc=in_dim,
            submodule=block,
            outermost=True,
        )

        self.bn0 = nn.BatchNorm2d(in_dim)
        self.unet_block = block

    def forward(self, x):
        """Normalize the input, then run the nested U-Net."""
        return self.unet_block(self.bn0(x))
512
+
513
+
514
class CondUNetBlock(nn.Module):
    """A U-Net block that defines the submodule with skip connection.

    X ---------------------identity-------------------- X
      |-- downsampling --| submodule |-- upsampling --|

    Unlike :class:`UNetBlock`, the forward pass threads a conditioning
    vector ``cond`` down to the innermost block, where it is broadcast
    spatially and concatenated onto the bottleneck features.
    """

    def __init__(
        self,
        outer_nc,
        inner_input_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        inner_output_nc=None,
        noskip=False,
        cond_nc=None,
    ):
        # cond_nc: width of the conditioning vector; required (and only
        # used) by the innermost block.
        super().__init__()
        self.outermost = outermost
        self.innermost = innermost
        self.noskip = noskip
        self.cond_nc = cond_nc
        self.submodule = submodule

        use_bias = False
        if input_nc is None:
            input_nc = outer_nc
        if innermost:
            assert cond_nc > 0
            # The bottleneck concatenates the conditioning channels.
            inner_output_nc = inner_input_nc + cond_nc
        elif inner_output_nc is None:
            inner_output_nc = 2 * inner_input_nc

        self.downnorm = nn.BatchNorm2d(inner_input_nc)
        self.uprelu = nn.ReLU(True)
        # Bilinear resize + 3x3 conv upsampling (not transposed conv).
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

        if outermost:
            # Outermost: plain strided conv down, un-normalized conv up.
            self.downconv = nn.Conv2d(
                input_nc,
                inner_input_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            self.upconv = nn.Conv2d(
                inner_output_nc, outer_nc, kernel_size=3, padding=1
            )

        elif innermost:
            # Innermost: pre-activation down, normalized up; accepts the
            # widened (features + cond) channel count on the way up.
            self.downrelu = nn.LeakyReLU(0.2, True)
            self.downconv = nn.Conv2d(
                input_nc,
                inner_input_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            self.upconv = nn.Conv2d(
                inner_output_nc,
                outer_nc,
                kernel_size=3,
                padding=1,
                bias=use_bias,
            )
            self.upnorm = nn.BatchNorm2d(outer_nc)

        else:
            # Middle block: norm on both paths, wraps the given submodule.
            self.downrelu = nn.LeakyReLU(0.2, True)
            self.downconv = nn.Conv2d(
                input_nc,
                inner_input_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            self.upconv = nn.Conv2d(
                inner_output_nc,
                outer_nc,
                kernel_size=3,
                padding=1,
                bias=use_bias,
            )
            self.upnorm = nn.BatchNorm2d(outer_nc)

    def forward(self, x, cond):
        # ``cond`` is forwarded unchanged through every level and consumed
        # only by the innermost block.
        if self.outermost:
            x_ = self.downconv(x)
            x_ = self.submodule(x_, cond)
            x_ = self.upconv(self.upsample(self.uprelu(x_)))

        elif self.innermost:
            x_ = self.downconv(self.downrelu(x))

            # Broadcast the conditioning vector over the spatial grid and
            # concatenate it onto the bottleneck features.
            B, _, H, W = x_.size()
            cond_ = cond.unsqueeze(-1).unsqueeze(-1) * torch.ones(
                (B, self.cond_nc, H, W), device=x_.device
            )
            x_ = torch.concat((x_, cond_), 1)

            x_ = self.upnorm(self.upconv(self.upsample(self.uprelu(x_))))

        else:
            x_ = self.downnorm(self.downconv(self.downrelu(x)))
            x_ = self.submodule(x_, cond)
            x_ = self.upnorm(self.upconv(self.upsample(self.uprelu(x_))))

        # Outermost (or noskip) blocks return features directly; all others
        # concatenate the input as a skip connection along channels.
        if self.outermost or self.noskip:
            return x_
        else:
            return torch.cat([x, x_], 1)
633
+
634
+
635
class CondUNet(nn.Module):
    """A conditional U-Net: like :class:`UNet`, but every block forwards a
    conditioning vector that is injected at the bottleneck."""

    def __init__(
        self,
        in_dim=1,
        out_dim=64,
        cond_dim=32,
        num_downs=5,
        ngf=64,
        use_dropout=False,
    ):
        """Args:
            in_dim: Input channels.
            out_dim: Output feature channels.
            cond_dim: Width of the conditioning vector.
            num_downs: Number of downsampling levels (>= 5).
            ngf: Base channel width.
            use_dropout: Unused here; kept for signature parity with UNet.
        """
        super().__init__()

        # Innermost block receives the conditioning channels.
        block = CondUNetBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            innermost=True,
            cond_nc=cond_dim,
        )
        # Extra constant-width levels for nets deeper than 5.
        for _ in range(num_downs - 5):
            block = CondUNetBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=block
            )
        # Progressively widen toward the outside.
        for outer, inner in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            block = CondUNetBlock(outer, inner, input_nc=None, submodule=block)
        # Outermost block maps raw input channels to the output features.
        block = CondUNetBlock(
            out_dim,
            ngf,
            input_nc=in_dim,
            submodule=block,
            outermost=True,
        )

        self.bn0 = nn.BatchNorm2d(in_dim)
        self.unet_block = block

    def forward(self, x, cond):
        """Normalize the input, then run the conditioned U-Net."""
        return self.unet_block(self.bn0(x), cond)
686
+
687
+
688
class InnerProd(nn.Module):
    """Scaled inner product between a conditioning vector and a feature map.

    Produces a single-channel map ``z[b, 0, h, w] = <img[b] * scale,
    sound[b, :, h, w]> + bias`` with a learnable per-channel scale.
    """

    def __init__(self, fc_dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(fc_dim))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, feat_img, feat_sound):
        """Project (B, C) image features onto (B, C, H, W) sound features."""
        B, C = feat_sound.size(0), feat_sound.size(1)
        spatial = feat_sound.size()[2:]
        query = feat_img.view(B, 1, C) * self.scale
        out = torch.bmm(query, feat_sound.view(B, C, -1))
        return out.view(B, 1, *spatial) + self.bias

    def forward_nosum(self, feat_img, feat_sound):
        """Channel-wise gating without summing over channels -> (B, C, H, W)."""
        B, C, _, _ = feat_sound.size()
        gate = (feat_img.view(B, C) * self.scale).view(B, C, 1, 1)
        return gate * feat_sound + self.bias

    # inference purposes
    def forward_pixelwise(self, feats_img, feat_sound):
        """All-pairs products between image pixels and sound bins
        -> (B, HI, WI, HS, WS)."""
        B, C, HI, WI = feats_img.size()
        _, _, HS, WS = feat_sound.size()
        pixels = feats_img.view(B, C, HI * WI).transpose(1, 2) * self.scale
        out = torch.bmm(pixels, feat_sound.view(B, C, HS * WS))
        return out.view(B, HI, WI, HS, WS) + self.bias
723
+
724
+
725
class Bias(nn.Module):
    """Inner-product synthesizer with only a learnable scalar bias
    (no per-channel scale, unlike :class:`InnerProd`)."""

    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, feat_img, feat_sound):
        """Project (B, C) image features onto (B, C, H, W) sound features."""
        B, C, H, W = feat_sound.size()
        query = feat_img.view(B, 1, C)
        out = torch.bmm(query, feat_sound.view(B, C, H * W))
        return out.view(B, 1, H, W) + self.bias

    def forward_nosum(self, feat_img, feat_sound):
        """Channel-wise product without summing -> (B, C, H, W)."""
        B, C, _, _ = feat_sound.size()
        return feat_img.view(B, C, 1, 1) * feat_sound + self.bias

    # inference purposes
    def forward_pixelwise(self, feats_img, feat_sound):
        """All-pairs products between image pixels and sound bins
        -> (B, HI, WI, HS, WS)."""
        B, C, HI, WI = feats_img.size()
        _, _, HS, WS = feat_sound.size()
        pixels = feats_img.view(B, C, HI * WI).transpose(1, 2)
        out = torch.bmm(pixels, feat_sound.view(B, C, HS * WS))
        return out.view(B, HI, WI, HS, WS) + self.bias
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ librosa==0.9.2
2
+ numba==0.56.2
3
+ mir_eval==0.7
4
+ opencv-python
5
+ museval==0.4.0
6
+ pydub
7
+ gradio
8
+ imagebind @ git+https://github.com/facebookresearch/ImageBind.git
utils.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions."""
2
+ import contextlib
3
+ import csv
4
+ import json
5
+ import os
6
+ import pathlib
7
+ import subprocess as sp
8
+ import warnings
9
+ from threading import Timer
10
+
11
+ import cv2
12
+ import librosa
13
+ import numpy as np
14
+
15
+
16
def save_args(filename, args):
    """Serialize parsed command-line arguments to a JSON file.

    Path values and the path lists under ``train_list``/``val_list`` are
    converted to strings first, since JSON cannot encode Path objects.
    """
    serializable = {}
    for key, value in vars(args).items():
        if isinstance(value, pathlib.Path):
            serializable[key] = str(value)
        elif key in ("train_list", "val_list"):
            serializable[key] = [str(v) for v in value]
        else:
            serializable[key] = value
    save_json(filename, serializable)
27
+
28
+
29
def inverse_dict(d):
    """Return a dictionary mapping each value of *d* back to its key."""
    return dict(zip(d.values(), d.keys()))
32
+
33
+
34
def save_txt(filename, data):
    """Save a list to a TXT file, one item per line."""
    with open(filename, "w", encoding="utf8") as f:
        f.writelines(f"{item}\n" for item in data)
39
+
40
+
41
def load_txt(filename):
    """Load a TXT file as a list of whitespace-stripped lines."""
    with open(filename, encoding="utf8") as f:
        lines = f.readlines()
    return [line.strip() for line in lines]
45
+
46
+
47
def save_json(filename, data):
    """Write *data* to *filename* as JSON."""
    with open(filename, "w", encoding="utf8") as f:
        f.write(json.dumps(data))
51
+
52
+
53
def load_json(filename):
    """Read and parse a JSON file."""
    with open(filename, encoding="utf8") as f:
        content = f.read()
    return json.loads(content)
57
+
58
+
59
def save_csv(filename, data, fmt="%d", header=""):
    """Write *data* to a CSV file via ``numpy.savetxt``.

    ``comments=""`` keeps the header line from being prefixed with '#'.
    """
    np.savetxt(
        filename,
        data,
        fmt=fmt,
        delimiter=",",
        header=header,
        comments="",
    )
64
+
65
+
66
def load_csv(filename, skiprows=1):
    """Load integer data from a CSV file.

    The first *skiprows* rows (a header, by default) are ignored.
    """
    return np.loadtxt(
        filename,
        dtype=int,
        delimiter=",",
        skiprows=skiprows,
    )
69
+
70
+
71
def load_csv_text(filename, headerless=True):
    """Read a CSV file into a list of rows.

    With ``headerless=True`` each row is a plain list of strings;
    otherwise the first row is treated as a header and each following row
    becomes a dict keyed by the header fields.
    """
    with open(filename) as f:
        if headerless:
            return list(csv.reader(f))
        reader = csv.DictReader(f)
        rows = []
        for row in reader:
            rows.append({field: row[field] for field in reader.fieldnames})
        return rows
81
+
82
+
83
def ignore_exceptions(func):
    """Decorator that ignores all errors and warnings.

    The wrapped function returns ``None`` whenever the underlying call
    raises, and all warnings emitted during the call are suppressed.
    """
    import functools

    # functools.wraps preserves the wrapped function's name/docstring,
    # which the original decorator discarded.
    @functools.wraps(func)
    def inner(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                return func(*args, **kwargs)
            except Exception:
                return None

    return inner
95
+
96
+
97
def suppress_outputs(func):
    """Decorator that suppresses writing to stdout and stderr.

    The original implementation opened ``os.devnull`` without ever
    closing it, leaking one file handle per call; a ``with`` block now
    closes the sink when the call returns or raises.
    """

    def inner(*args, **kwargs):
        with open(os.devnull, "w") as devnull:
            with contextlib.redirect_stdout(devnull):
                with contextlib.redirect_stderr(devnull):
                    return func(*args, **kwargs)

    return inner
107
+
108
+
109
def resolve_paths(func):
    """Decorator that expands '~' and resolves every Path attribute of
    the parsed-arguments object returned by *func*."""

    def inner(*args, **kwargs):
        parsed = func(*args, **kwargs)
        for key, value in vars(parsed).items():
            if isinstance(value, pathlib.Path):
                setattr(parsed, key, value.expanduser().resolve())
        return parsed

    return inner
122
+
123
+
124
def warpgrid(bs, HO, WO, warp=True):
    """Build a (bs, HO, WO, 2) float32 sampling grid in [-1, 1] coords.

    The x axis is linear; the y axis is stretched exponentially when
    ``warp=True`` and mapped back logarithmically when ``warp=False``.
    """
    xs = np.linspace(-1, 1, WO)
    ys = np.linspace(-1, 1, HO)
    xv, yv = np.meshgrid(xs, ys)
    if warp:
        # Exponential stretch of the y axis (maps -1 -> -1 and 1 -> 1).
        yw = (np.power(21, (yv + 1) / 2) - 11) / 10
    else:
        # Logarithmic inverse of the stretch above.
        yw = np.log(yv * 10 + 11) / np.log(21) * 2 - 1
    grid = np.empty((bs, HO, WO, 2), dtype=np.float32)
    grid[..., 0] = xv  # broadcast over the batch dimension
    grid[..., 1] = yw
    return grid
139
+
140
+
141
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.initialized = False
        self.val = None    # most recent value (numpy array)
        self.avg = None    # running weighted average
        self.sum = None    # weighted sum of values
        self.count = None  # total weight seen so far

    def initialize(self, val, weight):
        """Seed the meter with its first observation."""
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record *val* with the given *weight*."""
        val = np.asarray(val)
        if self.initialized:
            self.add(val, weight)
        else:
            self.initialize(val, weight)

    def add(self, val, weight):
        """Fold *val* into the running statistics."""
        self.val = val
        self.sum = self.sum + val * weight
        self.count = self.count + weight
        self.avg = self.sum / self.count

    def value(self):
        """Latest value as a Python scalar/list (0.0 before any update)."""
        return 0.0 if self.val is None else self.val.tolist()

    def average(self):
        """Running average as a Python scalar/list (0.0 if empty)."""
        return 0.0 if self.avg is None else self.avg.tolist()
183
+
184
def recover_rgb(img):
    """Undo ImageNet-style normalization on a CHW tensor and return an
    HWC uint8 image. The input tensor is modified in place."""
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    for channel, mean, std in zip(img, means, stds):
        channel.mul_(std).add_(mean)
    return (img.numpy().transpose((1, 2, 0)) * 255).astype(np.uint8)
189
+
190
+
191
def recover_rgb_clip(img):
    """Undo CLIP-style normalization on a CHW tensor and return an HWC
    uint8 image. The input tensor is modified in place."""
    means = [0.48145466, 0.4578275, 0.40821073]
    stds = [0.26862954, 0.26130258, 0.27577711]
    for channel, mean, std in zip(img, means, stds):
        channel.mul_(std).add_(mean)
    return (img.numpy().transpose((1, 2, 0)) * 255).astype(np.uint8)
200
+
201
+
202
def magnitude2heatmap(mag, log=True, scale=200.0):
    """Map a magnitude array to an RGB heatmap image (uint8, HWC).

    NOTE(review): when ``log`` is False, ``mag *= scale`` mutates the
    caller's array in place — confirm callers do not reuse the input.
    """
    if log:
        # Compress dynamic range; +1 keeps log10 defined at zero.
        mag = np.log10(mag + 1.0)
    mag *= scale
    mag[mag > 255] = 255  # clip into the uint8 range
    mag = mag.astype(np.uint8)
    # mag_color = cv2.applyColorMap(mag, cv2.COLORMAP_JET)
    mag_color = cv2.applyColorMap(mag, cv2.COLORMAP_INFERNO)
    mag_color = mag_color[:, :, ::-1]  # OpenCV BGR -> RGB
    return mag_color
212
+
213
+
214
def istft_reconstruction(mag, phase, hop_len, win_len):
    """Reconstruct a waveform from magnitude and phase spectrograms.

    Returns a float32 waveform clipped to [-1, 1].
    """
    # np.complex (an alias of builtin complex) was removed in NumPy 1.24;
    # np.complex128 is the dtype it resolved to.
    spec = mag.astype(np.complex128) * np.exp(1j * phase)
    wav = librosa.istft(spec, hop_length=hop_len, win_length=win_len)
    return np.clip(wav, -1.0, 1.0).astype(np.float32)
218
+
219
+
220
class VideoWriter:
    """ Combine numpy frames into video using ffmpeg

    Arguments:
        filename: name of the output video
        fps: frame per second
        shape: shape of video frame (height, width)

    Properties:
        add_frame(frame):
            add a frame to the video
        add_frames(frames):
            add multiple frames to the video
        release():
            release writing pipe

    """

    def __init__(self, filename, fps, shape):
        self.file = filename
        self.fps = fps
        self.shape = shape

        # video codec, chosen from the output file extension
        ext = filename.split(".")[-1]
        if ext == "mp4":
            self.vcodec = "h264"
        else:
            # fixed typo in the original error message ("supoorted")
            raise RuntimeError("Video codec not supported.")

        # video writing pipe: raw RGB frames in, encoded video out
        cmd = [
            "ffmpeg",
            "-y",  # overwrite existing file
            "-f",
            "rawvideo",  # file format
            "-s",
            "{}x{}".format(shape[1], shape[0]),  # size of one frame
            "-pix_fmt",
            "rgb24",  # 3 channels
            "-r",
            str(self.fps),  # frames per second
            "-i",
            "-",  # input comes from a pipe
            "-an",  # not to expect any audio
            "-vcodec",
            self.vcodec,  # video codec
            "-pix_fmt",
            "yuv420p",  # output video in yuv420p
            self.file,
        ]

        self.pipe = sp.Popen(
            cmd, stdin=sp.PIPE, stderr=sp.PIPE, bufsize=10 ** 9
        )

    def release(self):
        """Close ffmpeg's stdin so it can finalize the output file."""
        self.pipe.stdin.close()

    def add_frame(self, frame):
        """Write one (H, W, 3) RGB frame to the encoder pipe."""
        assert len(frame.shape) == 3
        assert frame.shape[0] == self.shape[0]
        assert frame.shape[1] == self.shape[1]
        try:
            # tostring() was deprecated and removed from numpy;
            # tobytes() is the byte-identical replacement.
            self.pipe.stdin.write(frame.tobytes())
        except Exception:  # e.g. broken pipe — surface ffmpeg's error
            _, ffmpeg_error = self.pipe.communicate()
            print(ffmpeg_error)

    def add_frames(self, frames):
        """Write an iterable of frames in order."""
        for frame in frames:
            self.add_frame(frame)
292
+
293
+
294
def kill_proc(proc):
    """Kill *proc* and report that it exceeded its time budget."""
    proc.kill()
    print("Process running overtime! Killed.")
297
+
298
+
299
def run_proc_timeout(proc, timeout_sec):
    """Wait for *proc*, killing it if it runs past *timeout_sec* seconds."""
    watchdog = Timer(timeout_sec, kill_proc, [proc])
    try:
        watchdog.start()
        proc.communicate()
    finally:
        # Always disarm the watchdog, even if communicate() raises.
        watchdog.cancel()
307
+
308
+
309
def combine_video_audio(src_video, src_audio, dst_video, verbose=False):
    """Mux *src_video* and *src_audio* into *dst_video* via ffmpeg.

    The video stream is copied, the audio stream is re-encoded to AAC.
    Failures are reported on stdout instead of raising.
    """
    try:
        cmd = [
            "ffmpeg", "-y", "-loglevel", "quiet",
            "-i", src_video,
            "-i", src_audio,
            "-c:v", "copy",
            "-c:a", "aac",
            "-strict", "experimental",
            dst_video,
        ]
        proc = sp.Popen(cmd)
        # Give ffmpeg at most 10 seconds before it is killed.
        run_proc_timeout(proc, 10.0)

        if verbose:
            print("Processed:{}".format(dst_video))
    except Exception as e:
        print("Error:[{}] {}".format(dst_video, e))
335
+
336
+
337
# save video to the disk using ffmpeg
def save_video(path, tensor, fps=25):
    """Write a (L, H, W, C) numpy frame array to *path* as a video."""
    assert tensor.ndim == 4, "video should be in 4D numpy array"
    _, height, width, _ = tensor.shape
    writer = VideoWriter(path, fps=fps, shape=[height, width])
    # Iterating a 4D array yields one (H, W, C) frame per step.
    writer.add_frames(tensor)
    writer.release()
345
+
346
+
347
+ def save_audio(path, audio_numpy, sr):
348
+ librosa.output.write_wav(path, audio_numpy, sr)