darshankr committed
Commit 872e650 · verified · 1 Parent(s): 766f849

Upload 49 files

Files changed (50):
  1. .gitattributes +2 -0
  2. app.py +66 -0
  3. assets/id02548.0pAkJZmlFqc.00001_id04570.0YMGn6BI9rg.00001.gif +3 -0
  4. assets/website_gif_v2.gif +3 -0
  5. audio/__init__.py +0 -0
  6. audio/audio.py +136 -0
  7. audio/hparams.py +66 -0
  8. checkpoints/checkpoint.pt +3 -0
  9. dataset/LRW/lrw_fullpath.py +25 -0
  10. dataset/filelists/lrw_cross.txt +0 -0
  11. dataset/filelists/lrw_cross_relative_path.txt +0 -0
  12. dataset/filelists/lrw_reconstruction.txt +0 -0
  13. dataset/filelists/lrw_reconstruction_relative_path.txt +0 -0
  14. dataset/filelists/voxceleb2_test_n_5000_reconstruction_5k.txt +0 -0
  15. dataset/filelists/voxceleb2_test_n_5000_seed_797_cross_5K.txt +0 -0
  16. dataset/filelists/voxceleb2_test_n_500_reconstruction.txt +500 -0
  17. dataset/filelists/voxceleb2_test_n_500_seed_797_cross.txt +500 -0
  18. face_detection/README.md +1 -0
  19. face_detection/__init__.py +7 -0
  20. face_detection/api.py +98 -0
  21. face_detection/detection/__init__.py +1 -0
  22. face_detection/detection/core.py +130 -0
  23. face_detection/detection/sfd/__init__.py +1 -0
  24. face_detection/detection/sfd/bbox.py +129 -0
  25. face_detection/detection/sfd/detect.py +112 -0
  26. face_detection/detection/sfd/net_s3fd.py +129 -0
  27. face_detection/detection/sfd/sfd_detector.py +59 -0
  28. face_detection/models.py +261 -0
  29. face_detection/utils.py +313 -0
  30. generate.py +398 -0
  31. generate_dist.py +428 -0
  32. guided-diffusion/LICENSE +21 -0
  33. guided-diffusion/guided_diffusion/__init__.py +3 -0
  34. guided-diffusion/guided_diffusion/dist_util.py +94 -0
  35. guided-diffusion/guided_diffusion/fp16_util.py +237 -0
  36. guided-diffusion/guided_diffusion/gaussian_diffusion.py +843 -0
  37. guided-diffusion/guided_diffusion/image_datasets.py +167 -0
  38. guided-diffusion/guided_diffusion/logger.py +491 -0
  39. guided-diffusion/guided_diffusion/losses.py +77 -0
  40. guided-diffusion/guided_diffusion/lpips.py +20 -0
  41. guided-diffusion/guided_diffusion/nn.py +170 -0
  42. guided-diffusion/guided_diffusion/resample.py +154 -0
  43. guided-diffusion/guided_diffusion/respace.py +128 -0
  44. guided-diffusion/guided_diffusion/script_util.py +614 -0
  45. guided-diffusion/guided_diffusion/tfg_data_util.py +75 -0
  46. guided-diffusion/guided_diffusion/unet.py +1275 -0
  47. guided-diffusion/setup.py +7 -0
  48. requirements.txt +11 -0
  49. scripts/inference.sh +40 -0
  50. scripts/inference_single_video.sh +35 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/id02548.0pAkJZmlFqc.00001_id04570.0YMGn6BI9rg.00001.gif filter=lfs diff=lfs merge=lfs -text
+assets/website_gif_v2.gif filter=lfs diff=lfs merge=lfs -text
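The two new rules route the demo GIFs through Git LFS, so a checkout without LFS smudging sees a small text pointer rather than the binary. A minimal sketch for detecting such pointers (the helper name is ours, not part of this upload), matching the spec-v1 pointer format shown under checkpoints/checkpoint.pt below:

def is_lfs_pointer(path):
    # Git LFS pointer files begin with this version line; see the
    # checkpoints/checkpoint.pt entry below for a full example.
    with open(path, "rb") as f:
        return f.read(42).startswith(b"version https://git-lfs.github.com/spec/v1")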
app.py ADDED
@@ -0,0 +1,66 @@
import gradio as gr
import subprocess
import os

def process_video(audio_file, video_file):
    # Gradio hands over the uploads as temp files already on disk;
    # their .name attributes are the paths, so no extra save step is needed.
    audio_path = audio_file.name
    video_path = video_file.name
    out_path = "output_video.mp4"

    # Define command flags
    sample_mode = "cross"  # or "reconstruction"
    generate_from_filelist = 0
    model_path = "checkpoints/checkpoint.pt"
    pads = "0,0,0,0"

    if sample_mode == "reconstruction":
        sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
    elif sample_mode == "cross":
        sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
    else:
        return "Error: sample_mode can only be \"cross\" or \"reconstruction\""

    MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
    DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
    SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
    DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
    TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
    GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"

    # Combine all flags into one command
    command = f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"

    # Execute the command, cleaning up the uploaded temp files either way.
    # (The output video must outlive this function so Gradio can serve it.)
    try:
        subprocess.run(command, shell=True, check=True)
        return out_path
    except subprocess.CalledProcessError as e:
        return f"Error processing video: {e}"
    finally:
        os.remove(audio_path)
        os.remove(video_path)

# Create a Gradio interface (legacy gr.inputs/gr.outputs API)
iface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.inputs.Audio(label="Input Audio", type="file"),
        gr.inputs.Video(label="Input Video", type="file")
    ],
    outputs=gr.outputs.Video(label="Processed Video"),
    title="Audio-Video Processing",
    description="Upload an audio file and a video file to process the video based on the audio input."
)

# Launch the interface
iface.launch()
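Since the flag strings above are plain space-separated tokens, the subprocess call could equally avoid shell=True; a minimal variant sketch, assuming none of the paths contain spaces:

import shlex
import subprocess

# Tokenize the assembled flag string the way a POSIX shell would and run
# the generation script without spawning an intermediate shell.
subprocess.run(shlex.split(command), check=True)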
assets/id02548.0pAkJZmlFqc.00001_id04570.0YMGn6BI9rg.00001.gif ADDED

Git LFS Details

  • SHA256: e870f498b739b783cd69ade2991dd1b0021eab47a3c5a6fe4abf3d07c931dc73
  • Pointer size: 132 Bytes
  • Size of remote file: 9.61 MB
assets/website_gif_v2.gif ADDED

Git LFS Details

  • SHA256: 5ef17ae4f9de5b9397dfe97077d4a82aa592ad34fe8d7559e08189661ef38753
  • Pointer size: 132 Bytes
  • Size of remote file: 6.19 MB
audio/__init__.py ADDED
File without changes
audio/audio.py ADDED
@@ -0,0 +1,136 @@
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from .hparams import hparams as hp

def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    # proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    # requires librosa < 0.8, where the librosa.output module still exists
    librosa.output.write_wav(path, wav, sr=sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
        assert hp.frame_shift_ms is not None
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size

def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")

def _stft(y):
    if hp.use_lws:
        return _lws_processor().stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
# These are only correct when using lws! (This was messing with WaveNet quality for a long time.)
def num_frames(length, fsize, fshift):
    """Compute the number of time frames of a spectrogram."""
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding."""
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r
##########################################################

# Librosa-correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None

def _linear_to_mel(spectrogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectrogram)

def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
                           -hp.max_abs_value, hp.max_abs_value)
        else:
            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)

    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
    if hp.symmetric_mels:
        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))

def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return (((np.clip(D, -hp.max_abs_value,
                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
                    + hp.min_level_db)
        else:
            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

    if hp.symmetric_mels:
        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
    else:
        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
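A minimal usage sketch of this module, assuming a 16 kHz mono file at a hypothetical path sample.wav:

from audio import audio as audio_utils

wav = audio_utils.load_wav("sample.wav", sr=16000)  # float waveform at the hparams sample rate
mel = audio_utils.melspectrogram(wav)               # (num_mels, T) array, normalized per hparams
print(mel.shape)                                    # e.g. (80, T)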
audio/hparams.py ADDED
@@ -0,0 +1,66 @@
from glob import glob
import os


class HParams:
    def __init__(self, **kwargs):
        self.data = {}

        for key, value in kwargs.items():
            self.data[key] = value

    def __getattr__(self, key):
        if key not in self.data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return self.data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


# Default hyperparameters
hparams = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction.
    # It's preferable to set this to True when using https://github.com/r9y9/wavenet_vocoder.
    # Does not work if n_fft is not a multiple of hop_size!
    use_lws=False,

    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000 Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000 Hz, 800 = 50 ms (if None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000 Hz (corresponding to LibriSpeech) (sox --i <filename>)

    frame_shift_ms=None,  # Can replace the hop_size parameter (recommended: 12.5)

    # Mel and linear spectrogram normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to a predefined range (following the parameters below)
    allow_clipping_in_normalization=True,  # Only relevant if signal_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # for faster and cleaner convergence.)
    max_abs_value=4.,
    # Max absolute value of the data. If symmetric, data will be in [-max, max], else [0, max].
    # (Must not be too big, to avoid gradient explosion, nor too small, for fast convergence.)
    # Contribution by @begeekmyfriend
    # Spectrogram pre-emphasis (lfilter): reduces spectrogram noise and helps model certitude
    # levels; also allows for better Griffin-Lim phase reconstruction.
    preemphasize=True,  # whether to apply the filter
    preemphasis=0.97,  # filter coefficient

    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male; if female, 95 should help remove noise. (To be
    # tuned per dataset. Pitch info: male ~[65, 260], female ~[100, 525].)
    fmax=7600,  # To be increased/reduced depending on the data
)
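Since HParams only exposes what was passed to its constructor, overrides go through set_hparam; a short sketch:

from audio.hparams import hparams as hp

hp.set_hparam("fmin", 95)       # e.g. raise fmin for female speakers, per the comment above
print(hp.fmin, hp.sample_rate)  # -> 95 16000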
checkpoints/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c71166482d2b893f2f77450563a1bb31d805f3048c7213b974fd9201e9aa4b3
size 406815527
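The pointer resolves to a ~407 MB PyTorch checkpoint after an LFS pull. A quick way to inspect it, assuming standard torch serialization (the sampling flags above pass it via --model_path rather than loading it directly):

import torch

# Load on CPU just to look at the contents; inference goes through
# --model_path=checkpoints/checkpoint.pt in the generation scripts.
state = torch.load("checkpoints/checkpoint.pt", map_location="cpu")
print(type(state))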
dataset/LRW/lrw_fullpath.py ADDED
@@ -0,0 +1,25 @@
'''Converts the LRW video names in the filelists to LRW relative paths and dumps them into new filelists.'''
import os

def convert(filelist):
    filelist_split_path = filelist.replace(".txt", "_relative_path.txt")
    with open(filelist, 'r') as f:
        lines = f.readlines()
    with open(filelist_split_path, 'w') as f:
        for line in lines:
            # Each line holds an audio clip name and a video clip name; the word
            # label is the prefix before the first underscore.
            audio_name, video_name = line.split(' ')
            audio_word = audio_name.split('_')[0]
            video_word = video_name.split('_')[0]
            f.write(os.path.join(audio_word, 'test', audio_name) + ' ' + os.path.join(video_word, 'test', video_name))

convert("../filelists/lrw_cross.txt")
convert("../filelists/lrw_reconstruction.txt")
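As a worked example of the mapping (the word and clip names here are illustrative, not taken from the filelists):

import os

line = "ABOUT_00001.mp4 ABSOLUTELY_00002.mp4"  # hypothetical filelist line
audio_name, video_name = line.split(' ')
print(os.path.join(audio_name.split('_')[0], 'test', audio_name),
      os.path.join(video_name.split('_')[0], 'test', video_name))
# -> ABOUT/test/ABOUT_00001.mp4 ABSOLUTELY/test/ABSOLUTELY_00002.mp4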
dataset/filelists/lrw_cross.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/filelists/lrw_cross_relative_path.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/filelists/lrw_reconstruction.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/filelists/lrw_reconstruction_relative_path.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/filelists/voxceleb2_test_n_5000_reconstruction_5k.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/filelists/voxceleb2_test_n_5000_seed_797_cross_5K.txt ADDED
The diff for this file is too large to render. See raw diff
 
dataset/filelists/voxceleb2_test_n_500_reconstruction.txt ADDED
@@ -0,0 +1,500 @@
id09017/SjCgiXBHfNU/00111 id09017/SjCgiXBHfNU/00111
id05055/HobsYUHmgr0/00138 id05055/HobsYUHmgr0/00138
id01567/M47d5UckOV8/00099 id01567/M47d5UckOV8/00099
id01228/SH3eBLMsRwY/00211 id01228/SH3eBLMsRwY/00211
id07312/m1VY1sC_P_o/00093 id07312/m1VY1sC_P_o/00093
id08696/dJH9aBSs1nE/00370 id08696/dJH9aBSs1nE/00370
id07312/9PCY4xwxgcE/00006 id07312/9PCY4xwxgcE/00006
id07494/1A8ZDo11tzY/00006 id07494/1A8ZDo11tzY/00006
id00061/thHLZ8tDJ-M/00276 id00061/thHLZ8tDJ-M/00276
id03862/YQMdTzyG-P8/00297 id03862/YQMdTzyG-P8/00297
id04570/zsnG6eKzOGE/00406 id04570/zsnG6eKzOGE/00406
id07414/1m1C-CdhmZ0/00016 id07414/1m1C-CdhmZ0/00016
id01509/2uKpHd-euIo/00038 id01509/2uKpHd-euIo/00038
id04276/qHXwXqxL0mk/00401 id04276/qHXwXqxL0mk/00401
id04366/x-VQ6z2QC4w/00252 id04366/x-VQ6z2QC4w/00252
id07620/um6sY627GaE/00475 id07620/um6sY627GaE/00475
id01000/RvjbLfo3XDM/00052 id01000/RvjbLfo3XDM/00052
id07868/fnWDbUI_Zbg/00289 id07868/fnWDbUI_Zbg/00289
id01333/cymDCPEhalE/00351 id01333/cymDCPEhalE/00351
id02317/Mv16h1Bx7HE/00241 id02317/Mv16h1Bx7HE/00241
id02317/Vi4k3cuwfgc/00342 id02317/Vi4k3cuwfgc/00342
id01000/eGeGHhuOJJ0/00077 id01000/eGeGHhuOJJ0/00077
id03980/zaDLb12pDBQ/00130 id03980/zaDLb12pDBQ/00130
id05124/c-Pa7b81coQ/00354 id05124/c-Pa7b81coQ/00354
id04478/nhLuGj0vGb8/00234 id04478/nhLuGj0vGb8/00234
id01541/3su8tn9nwi4/00007 id01541/3su8tn9nwi4/00007
id06484/cmIyVotzXiE/00125 id06484/cmIyVotzXiE/00125
id06209/oxofNHGCj7s/00139 id06209/oxofNHGCj7s/00139
id02181/rxX3t2rzLbg/00146 id02181/rxX3t2rzLbg/00146
id02286/YL75-u9XYUM/00105 id02286/YL75-u9XYUM/00105
id04276/v9mSslwD0Kg/00470 id04276/v9mSslwD0Kg/00470
id07802/6qBSFfV_Mig/00042 id07802/6qBSFfV_Mig/00042
id04295/DtC2X1KG8TE/00057 id04295/DtC2X1KG8TE/00057
id00866/shG_183xFlw/00243 id00866/shG_183xFlw/00243
id03862/2nagLhV_Yvw/00012 id03862/2nagLhV_Yvw/00012
id04119/Yndoy1jgHWs/00042 id04119/Yndoy1jgHWs/00042
id04295/mTCDT_Fv5So/00203 id04295/mTCDT_Fv5So/00203
id08456/o-5hKwhGqac/00354 id08456/o-5hKwhGqac/00354
id07494/tv6GJkx_Wy4/00331 id07494/tv6GJkx_Wy4/00331
id04295/mClPHVzTCLI/00196 id04295/mClPHVzTCLI/00196
id04478/81Tb6kjlNIk/00019 id04478/81Tb6kjlNIk/00019
id00812/NeNXGI8mox8/00158 id00812/NeNXGI8mox8/00158
id04536/sfldoEPrFPI/00438 id04536/sfldoEPrFPI/00438
id07620/aJVbccKJwEw/00327 id07620/aJVbccKJwEw/00327
id02286/4LAIxvdvguc/00001 id02286/4LAIxvdvguc/00001
id07802/BfQUBDw7TiM/00080 id07802/BfQUBDw7TiM/00080
id01066/65k0p7fUBVI/00026 id01066/65k0p7fUBVI/00026
id03862/w97YzyPYm1k/00460 id03862/w97YzyPYm1k/00460
id05816/njFBkJSpUrY/00414 id05816/njFBkJSpUrY/00414
id05124/_Oxp6absIhY/00341 id05124/_Oxp6absIhY/00341
id07663/mUw-kxAavdM/00192 id07663/mUw-kxAavdM/00192
id05999/Ls440srvfR4/00127 id05999/Ls440srvfR4/00127
id02548/Hmlw5PIf64o/00098 id02548/Hmlw5PIf64o/00098
id04276/Pbo_nlcZ0Lc/00190 id04276/Pbo_nlcZ0Lc/00190
id07802/FhKML4dLE60/00115 id07802/FhKML4dLE60/00115
id07621/1L2IUy6gqaM/00012 id07621/1L2IUy6gqaM/00012
id05654/veGIQ7p2ZSk/00130 id05654/veGIQ7p2ZSk/00130
id04094/0z1JYPKGBI8/00007 id04094/0z1JYPKGBI8/00007
id02576/wWUREnOwYo0/00136 id02576/wWUREnOwYo0/00136
id09017/PLNK1g5w4FY/00099 id09017/PLNK1g5w4FY/00099
id06484/USbx34RUkVI/00096 id06484/USbx34RUkVI/00096
id03030/FXbzdRO7t98/00101 id03030/FXbzdRO7t98/00101
id02057/VCXnx-ozS8c/00263 id02057/VCXnx-ozS8c/00263
id02542/JUodrwt9ucI/00033 id02542/JUodrwt9ucI/00033
id03030/DM_Z5D2fkRA/00068 id03030/DM_Z5D2fkRA/00068
id08552/irj3SqKAe0c/00196 id08552/irj3SqKAe0c/00196
id03030/YxBoufnVIMw/00177 id03030/YxBoufnVIMw/00177
id07868/Eaf-dgA59Gs/00061 id07868/Eaf-dgA59Gs/00061
id08456/6xVSlQDr7-w/00031 id08456/6xVSlQDr7-w/00031
id06811/OYFkt_n18hg/00128 id06811/OYFkt_n18hg/00128
id00817/tCnW5E8cMow/00383 id00817/tCnW5E8cMow/00383
id02542/fXQbNcIbcek/00053 id02542/fXQbNcIbcek/00053
id01567/oi2g17EF55s/00377 id01567/oi2g17EF55s/00377
id04366/HsG3OGE22DY/00117 id04366/HsG3OGE22DY/00117
id01509/1y0aWmgYDtw/00006 id01509/1y0aWmgYDtw/00006
id04295/pYfyopS672Y/00213 id04295/pYfyopS672Y/00213
id01989/6JfW9CPAoGY/00006 id01989/6JfW9CPAoGY/00006
id04366/tbcKV-IjZdI/00243 id04366/tbcKV-IjZdI/00243
id01298/UY0fkYSUFrY/00208 id01298/UY0fkYSUFrY/00208
id00817/GAs8WnyFKJM/00120 id00817/GAs8WnyFKJM/00120
id06484/TCp2-XVatIE/00079 id06484/TCp2-XVatIE/00079
id08374/Kf9N5AWprG8/00150 id08374/Kf9N5AWprG8/00150
id01822/QDWgjZqOkvM/00065 id01822/QDWgjZqOkvM/00065
id03030/pTz652Dx_6w/00230 id03030/pTz652Dx_6w/00230
id01460/chrI43l2Nuw/00201 id01460/chrI43l2Nuw/00201
id08374/85f-qB_KJP8/00041 id08374/85f-qB_KJP8/00041
id07961/PoSkUxZ4ags/00172 id07961/PoSkUxZ4ags/00172
id01437/uFPYqotT7tU/00233 id01437/uFPYqotT7tU/00233
id07621/Aan8MoozxII/00095 id07621/Aan8MoozxII/00095
id08456/fWTULQWYVoA/00250 id08456/fWTULQWYVoA/00250
id05055/da7Z8oWhFPY/00351 id05055/da7Z8oWhFPY/00351
id02181/hIvctbfcBx8/00106 id02181/hIvctbfcBx8/00106
id01541/dEmuPb4A7do/00184 id01541/dEmuPb4A7do/00184
id00419/a3Y7pQzcn40/00305 id00419/a3Y7pQzcn40/00305
id07354/dsDxN33xvL0/00262 id07354/dsDxN33xvL0/00262
id04478/MZh3AEgJ9pc/00092 id04478/MZh3AEgJ9pc/00092
id05124/UBUFmICrT-I/00281 id05124/UBUFmICrT-I/00281
id03127/SmGJu-t24hY/00195 id03127/SmGJu-t24hY/00195
id02465/coOp_DnsmEI/00150 id02465/coOp_DnsmEI/00150
id01618/qrOl1aaXBH0/00187 id01618/qrOl1aaXBH0/00187
id03969/WZVnB-m0X9g/00038 id03969/WZVnB-m0X9g/00038
id05202/s0m_4-SCn44/00186 id05202/s0m_4-SCn44/00186
id04657/SYVkfHq-pro/00172 id04657/SYVkfHq-pro/00172
id05176/p2IOP5_s_LM/00093 id05176/p2IOP5_s_LM/00093
id04950/XJS6SLQuCNM/00169 id04950/XJS6SLQuCNM/00169
id02019/anSrwA_9RPE/00152 id02019/anSrwA_9RPE/00152
id04570/Q-faEy1VXxQ/00140 id04570/Q-faEy1VXxQ/00140
id07621/bMvG2mQMZZw/00303 id07621/bMvG2mQMZZw/00303
id06811/vC3yQiWuuOI/00354 id06811/vC3yQiWuuOI/00354
id03839/aWMP8xzq2BE/00292 id03839/aWMP8xzq2BE/00292
id04094/j1ajUkR6_Q4/00326 id04094/j1ajUkR6_Q4/00326
id08149/o0Zdr9Jla7U/00047 id08149/o0Zdr9Jla7U/00047
id00017/hcr4tT9y3xs/00117 id00017/hcr4tT9y3xs/00117
id04950/Cu4jGRmYa4c/00064 id04950/Cu4jGRmYa4c/00064
id01567/TMozlhoPMfI/00223 id01567/TMozlhoPMfI/00223
id08374/QltFme-lqeI/00226 id08374/QltFme-lqeI/00226
id06816/tHor4VN8090/00259 id06816/tHor4VN8090/00259
id07494/xQ0YMPe-9u8/00413 id07494/xQ0YMPe-9u8/00413
id08374/FwR1K1rL3QI/00110 id08374/FwR1K1rL3QI/00110
id06692/Hlahj5abifM/00257 id06692/Hlahj5abifM/00257
id00419/J2LscHjRX7Q/00154 id00419/J2LscHjRX7Q/00154
id02057/CI5-q_qTR5I/00112 id02057/CI5-q_qTR5I/00112
id03862/7IccaH4HXRs/00069 id03862/7IccaH4HXRs/00069
id04656/ar3rKrkbjqI/00257 id04656/ar3rKrkbjqI/00257
id07494/XMEIdqio6ic/00184 id07494/XMEIdqio6ic/00184
id04657/dn4XY5c6mEw/00265 id04657/dn4XY5c6mEw/00265
id04570/SFKt669qIqs/00156 id04570/SFKt669qIqs/00156
id01541/sMDYdAB0MPs/00306 id01541/sMDYdAB0MPs/00306
id08456/F2O-frqyr9c/00101 id08456/F2O-frqyr9c/00101
id08701/_Ysb9mVibbk/00253 id08701/_Ysb9mVibbk/00253
id01333/e4FoER8nqx0/00365 id01333/e4FoER8nqx0/00365
id05124/F0Xpd6OoiDY/00161 id05124/F0Xpd6OoiDY/00161
id01593/AVmZf6Kl1So/00071 id01593/AVmZf6Kl1So/00071
id01567/fOlxxDqdrgc/00299 id01567/fOlxxDqdrgc/00299
id06484/2KVWoftPf2o/00001 id06484/2KVWoftPf2o/00001
id01224/g4jVqkEm1Gs/00274 id01224/g4jVqkEm1Gs/00274
id02445/ZX_6RMrTEP0/00066 id02445/ZX_6RMrTEP0/00066
id04656/5TR-W77XgF4/00032 id04656/5TR-W77XgF4/00032
id01618/F_ExF9xDajc/00060 id01618/F_ExF9xDajc/00060
id08392/gPX4IC53KwI/00355 id08392/gPX4IC53KwI/00355
id00866/pNbDtfW1JW4/00221 id00866/pNbDtfW1JW4/00221
id00812/b3dBqOtzsx0/00276 id00812/b3dBqOtzsx0/00276
id08701/61Al05HARgA/00001 id08701/61Al05HARgA/00001
id07663/FFo4JwVXeUM/00119 id07663/FFo4JwVXeUM/00119
id02057/22zJ50ky7CQ/00013 id02057/22zJ50ky7CQ/00013
id05055/2onVoeSgouI/00028 id05055/2onVoeSgouI/00028
id04006/zvUZFL0NyhM/00260 id04006/zvUZFL0NyhM/00260
id04950/EpOnsaBin0A/00077 id04950/EpOnsaBin0A/00077
id05015/RhBpC9Fc7a4/00154 id05015/RhBpC9Fc7a4/00154
id04656/Z_JFBDW9eZE/00251 id04656/Z_JFBDW9eZE/00251
id01509/2sb83ZBlbJg/00034 id01509/2sb83ZBlbJg/00034
id04030/JbcD0P6KGe0/00036 id04030/JbcD0P6KGe0/00036
id02542/cwgUjse_REU/00040 id02542/cwgUjse_REU/00040
id07620/xFc9X6EXtRM/00478 id07620/xFc9X6EXtRM/00478
id07354/Qrg89rvtZ1k/00217 id07354/Qrg89rvtZ1k/00217
id03839/wSQMEZMxxx4/00461 id03839/wSQMEZMxxx4/00461
id03127/iWeklsXc0H8/00268 id03127/iWeklsXc0H8/00268
id07663/54qlJ2HZ08s/00096 id07663/54qlJ2HZ08s/00096
id07961/Orp8s5aHYc8/00158 id07961/Orp8s5aHYc8/00158
id03347/y_F4aAkN0d8/00417 id03347/y_F4aAkN0d8/00417
id06913/KNDyf594xQg/00056 id06913/KNDyf594xQg/00056
id04366/DIgAc22fq9c/00080 id04366/DIgAc22fq9c/00080
id07396/uJPtbxlXi2c/00187 id07396/uJPtbxlXi2c/00187
id07868/gVspdH-U2XE/00290 id07868/gVspdH-U2XE/00290
id05594/u7qCFBP1nH4/00184 id05594/u7qCFBP1nH4/00184
id01541/mDoT5mpo_2c/00241 id01541/mDoT5mpo_2c/00241
id07354/0y9b8qlM170/00011 id07354/0y9b8qlM170/00011
id01460/DnnphhTlRPE/00075 id01460/DnnphhTlRPE/00075
id02548/1CNhmMmirfA/00009 id02548/1CNhmMmirfA/00009
id03127/k8z6DxdyF9w/00291 id03127/k8z6DxdyF9w/00291
id01437/zLRJ_8_M5Wg/00263 id01437/zLRJ_8_M5Wg/00263
id02576/WnbNQuJzErQ/00086 id02576/WnbNQuJzErQ/00086
id01333/M0UD9g1x18c/00128 id01333/M0UD9g1x18c/00128
id04295/1fSjOItVYVg/00001 id04295/1fSjOItVYVg/00001
id08456/8tt1LbCoU0E/00054 id08456/8tt1LbCoU0E/00054
id07494/r-ToqH_EJNs/00318 id07494/r-ToqH_EJNs/00318
id06816/XBKj9XWlZCw/00123 id06816/XBKj9XWlZCw/00123
id03030/haoNit7a4W0/00201 id03030/haoNit7a4W0/00201
id03839/aeObhOJLQzQ/00293 id03839/aeObhOJLQzQ/00293
id07868/COb1gFHXsBQ/00059 id07868/COb1gFHXsBQ/00059
id01224/eYWcMCsgkLY/00255 id01224/eYWcMCsgkLY/00255
id04006/K5ueXBlS6rc/00049 id04006/K5ueXBlS6rc/00049
id07620/G5-1CUbaz0c/00107 id07620/G5-1CUbaz0c/00107
id06104/cj0TAnwndoc/00230 id06104/cj0TAnwndoc/00230
id00061/STX1ycPt8fU/00076 id00061/STX1ycPt8fU/00076
id04478/wMbobxEQ7j8/00336 id04478/wMbobxEQ7j8/00336
id01106/7X_xtnJhEc0/00031 id01106/7X_xtnJhEc0/00031
id08374/zaYzRbE_2C8/00494 id08374/zaYzRbE_2C8/00494
id04276/MgOqCfwKE70/00173 id04276/MgOqCfwKE70/00173
id03127/Lgd5qn2-kMo/00079 id03127/Lgd5qn2-kMo/00079
id00154/xH3Pp_5yxOk/00153 id00154/xH3Pp_5yxOk/00153
id04030/7mXUMuo5_NE/00001 id04030/7mXUMuo5_NE/00001
id02542/p7bvjcLbZm4/00097 id02542/p7bvjcLbZm4/00097
id04232/T7dROCqmwNQ/00235 id04232/T7dROCqmwNQ/00235
id02548/KrXU-_jrtxY/00147 id02548/KrXU-_jrtxY/00147
id01567/SZyTC5dxJOY/00219 id01567/SZyTC5dxJOY/00219
id03524/2DD4Np7SaWw/00007 id03524/2DD4Np7SaWw/00007
id04094/DRq5F2261Ko/00072 id04094/DRq5F2261Ko/00072
id07802/HrpJg06dowY/00152 id07802/HrpJg06dowY/00152
id06816/pBt-DxsTFc8/00231 id06816/pBt-DxsTFc8/00231
id00154/2pSNL5YdcoQ/00002 id00154/2pSNL5YdcoQ/00002
id01541/C29fUBtimOE/00038 id01541/C29fUBtimOE/00038
id06310/b6qPjJ0isPI/00155 id06310/b6qPjJ0isPI/00155
id05714/wFGNufaMbDY/00025 id05714/wFGNufaMbDY/00025
id03980/m-8Ffv2RqYs/00092 id03980/m-8Ffv2RqYs/00092
id01437/uXAe0vbNWeo/00238 id01437/uXAe0vbNWeo/00238
id04232/tPZ-zVT67gs/00479 id04232/tPZ-zVT67gs/00479
id06811/ImzUwwYU6SQ/00067 id06811/ImzUwwYU6SQ/00067
id05459/wq3Z0I944wU/00436 id05459/wq3Z0I944wU/00436
id03969/Evoldg-U2_c/00024 id03969/Evoldg-U2_c/00024
id08548/BSChFozahbU/00019 id08548/BSChFozahbU/00019
id04950/PQEAck-3wcA/00134 id04950/PQEAck-3wcA/00134
id04295/G4YnExZSzlM/00066 id04295/G4YnExZSzlM/00066
id05176/mc7rFp2B1j0/00092 id05176/mc7rFp2B1j0/00092
id00812/1Xfgvdu7oDo/00001 id00812/1Xfgvdu7oDo/00001
id05459/UPSPGawaVsg/00233 id05459/UPSPGawaVsg/00233
id04656/7nG3rOv0oBw/00050 id04656/7nG3rOv0oBw/00050
id02548/nvYBpt14BrQ/00309 id02548/nvYBpt14BrQ/00309
id02317/A3AvljK8Upk/00102 id02317/A3AvljK8Upk/00102
id04478/qLNvRwMkhik/00242 id04478/qLNvRwMkhik/00242
id01228/lCDMC8JvKyU/00295 id01228/lCDMC8JvKyU/00295
id03041/5CfnYwQCW48/00001 id03041/5CfnYwQCW48/00001
id04950/LnsriCjCIV4/00116 id04950/LnsriCjCIV4/00116
id04094/plxNYSFgDTM/00384 id04094/plxNYSFgDTM/00384
id01460/30_QmGw7lmE/00030 id01460/30_QmGw7lmE/00030
id04366/6rX7hCNSjaw/00056 id04366/6rX7hCNSjaw/00056
id01041/m-xolqIq8p4/00370 id01041/m-xolqIq8p4/00370
id04950/BG4CCg2RiuQ/00052 id04950/BG4CCg2RiuQ/00052
id01989/7g0A7pF94r0/00018 id01989/7g0A7pF94r0/00018
id03382/b_NJ2Xz3G4Y/00030 id03382/b_NJ2Xz3G4Y/00030
id00812/IteHRVKyzaE/00138 id00812/IteHRVKyzaE/00138
id00061/bdkqfVtDZVY/00121 id00061/bdkqfVtDZVY/00121
id03839/YkYIh4cYwwg/00275 id03839/YkYIh4cYwwg/00275
id07354/wyTuCRGjUIQ/00477 id07354/wyTuCRGjUIQ/00477
id02057/TddnW2TaXrc/00246 id02057/TddnW2TaXrc/00246
id01989/gHVHtKTQBsw/00128 id01989/gHVHtKTQBsw/00128
id08374/bXlUHb5hxxA/00266 id08374/bXlUHb5hxxA/00266
id03862/TE2zQc8_W-g/00252 id03862/TE2zQc8_W-g/00252
id08696/86-k8TuowAE/00033 id08696/86-k8TuowAE/00033
id05176/K8yZYHg_4ro/00050 id05176/K8yZYHg_4ro/00050
id04253/SKsPkHMGHYY/00240 id04253/SKsPkHMGHYY/00240
id07874/2KK4ozkjaEE/00002 id07874/2KK4ozkjaEE/00002
id08392/g-SJYYaaLgE/00352 id08392/g-SJYYaaLgE/00352
id02542/glhCf1hwJhE/00065 id02542/glhCf1hwJhE/00065
id00817/FsL-bTbDTyw/00112 id00817/FsL-bTbDTyw/00112
id04862/IuXPj9VhUVA/00100 id04862/IuXPj9VhUVA/00100
id06811/f9-8d3lNNcw/00237 id06811/f9-8d3lNNcw/00237
id04094/JUYMzfVp8zI/00113 id04094/JUYMzfVp8zI/00113
id03347/r-xJUB0A4ok/00346 id03347/r-xJUB0A4ok/00346
id07868/MNibTv_ODQ8/00148 id07868/MNibTv_ODQ8/00148
id08392/3e5zvNaT-eU/00020 id08392/3e5zvNaT-eU/00020
id04295/bKMKvAr440A/00141 id04295/bKMKvAr440A/00141
id04295/l62YPD0ZkZI/00185 id04295/l62YPD0ZkZI/00185
id07312/RO9DsspwXiE/00047 id07312/RO9DsspwXiE/00047
id03030/rmFsUV5ICKk/00267 id03030/rmFsUV5ICKk/00267
id03677/nVWTTopGQdU/00181 id03677/nVWTTopGQdU/00181
id00866/xQ1Yy0kjvjA/00256 id00866/xQ1Yy0kjvjA/00256
id01333/fRnqtJR0rws/00371 id01333/fRnqtJR0rws/00371
id05055/AZoIKG33E8s/00115 id05055/AZoIKG33E8s/00115
id01822/_CkfCmQXII8/00098 id01822/_CkfCmQXII8/00098
id01593/_gyaAyVi6SA/00344 id01593/_gyaAyVi6SA/00344
id04295/DS3RDwf2xI8/00049 id04295/DS3RDwf2xI8/00049
id00812/EjO-VORTv_o/00098 id00812/EjO-VORTv_o/00098
id04657/WdJ_DuU0ack/00236 id04657/WdJ_DuU0ack/00236
id04232/AB9fk1MH2rA/00035 id04232/AB9fk1MH2rA/00035
id00419/chfgCUm9-Mg/00364 id00419/chfgCUm9-Mg/00364
id02577/Az0BGrX_TwI/00021 id02577/Az0BGrX_TwI/00021
id01437/hyj4OYm0cvA/00195 id01437/hyj4OYm0cvA/00195
id01593/tLFWX-IdAwI/00431 id01593/tLFWX-IdAwI/00431
id04536/MNDmkEXRS7s/00312 id04536/MNDmkEXRS7s/00312
id03789/7qhkM8qY3Fw/00077 id03789/7qhkM8qY3Fw/00077
id01593/neAk6K8BvTA/00397 id01593/neAk6K8BvTA/00397
id06484/jTHSVo6NvS4/00151 id06484/jTHSVo6NvS4/00151
id07414/cAudd_5Yv2I/00256 id07414/cAudd_5Yv2I/00256
id00866/ADzqaRZtJNA/00087 id00866/ADzqaRZtJNA/00087
id06484/ZySpn0Aj09k/00108 id06484/ZySpn0Aj09k/00108
id07312/ZHBjHQENqW8/00053 id07312/ZHBjHQENqW8/00053
id04656/LDuq2UPHKoA/00157 id04656/LDuq2UPHKoA/00157
id01509/UZL8Obdt--8/00181 id01509/UZL8Obdt--8/00181
id05816/7jt8zGB27QQ/00017 id05816/7jt8zGB27QQ/00017
id08456/7PKsuBS5LQI/00050 id08456/7PKsuBS5LQI/00050
id06913/Tx0vAZhSPuE/00077 id06913/Tx0vAZhSPuE/00077
id02465/UEmI4r5G-5Y/00117 id02465/UEmI4r5G-5Y/00117
id01460/9sefvU9y4Kw/00046 id01460/9sefvU9y4Kw/00046
id01567/uYDx0vIVy_A/00429 id01567/uYDx0vIVy_A/00429
id07961/qott7SmhA-A/00351 id07961/qott7SmhA-A/00351
id00866/Awi1Q0yib1s/00092 id00866/Awi1Q0yib1s/00092
id02086/CqJKcn8m_Xo/00152 id02086/CqJKcn8m_Xo/00152
id05015/Obbv73CqtmQ/00137 id05015/Obbv73CqtmQ/00137
id01041/1UYZqPpavtk/00001 id01041/1UYZqPpavtk/00001
id01593/GiLxqKSI68o/00188 id01593/GiLxqKSI68o/00188
id02317/IR0psXbOjdc/00176 id02317/IR0psXbOjdc/00176
id01066/X33aJxc3Kt0/00112 id01066/X33aJxc3Kt0/00112
id08456/VU3fkD-QqPw/00206 id08456/VU3fkD-QqPw/00206
id04536/wat5sbCSs0k/00470 id04536/wat5sbCSs0k/00470
id01066/4KOSmyAMipc/00020 id01066/4KOSmyAMipc/00020
id02445/f5u3ktNPHAk/00074 id02445/f5u3ktNPHAk/00074
id03041/NJUcU7j30JI/00011 id03041/NJUcU7j30JI/00011
id00817/vUezvJDh_tA/00394 id00817/vUezvJDh_tA/00394
id04478/sw50KQMY8vw/00298 id04478/sw50KQMY8vw/00298
id04657/hMrgeYf5ToQ/00267 id04657/hMrgeYf5ToQ/00267
id02548/VdjlKRtLD_w/00206 id02548/VdjlKRtLD_w/00206
id06310/4oJF1NW2bIg/00006 id06310/4oJF1NW2bIg/00006
id01509/jqbtAt91alI/00329 id01509/jqbtAt91alI/00329
id07414/oXx9CvIeFFY/00407 id07414/oXx9CvIeFFY/00407
id04570/mwhiZtTZYX0/00271 id04570/mwhiZtTZYX0/00271
id00812/AzDjo0Uyk4Y/00061 id00812/AzDjo0Uyk4Y/00061
id05999/MJwLq17VoMA/00146 id05999/MJwLq17VoMA/00146
id07414/dsqrI97WQHE/00319 id07414/dsqrI97WQHE/00319
id05015/C3KsCD-pUgs/00046 id05015/C3KsCD-pUgs/00046
id06484/Gh6H7Md_L2k/00053 id06484/Gh6H7Md_L2k/00053
id00081/xlwJqdrzeMA/00291 id00081/xlwJqdrzeMA/00291
id05055/RLN5nKfza4A/00219 id05055/RLN5nKfza4A/00219
id05055/OKw_hph-hK8/00197 id05055/OKw_hph-hK8/00197
id03839/xtBkY9xYpjA/00464 id03839/xtBkY9xYpjA/00464
id07620/HEX00yF8LTs/00117 id07620/HEX00yF8LTs/00117
id05816/hjrZgsKuvpw/00349 id05816/hjrZgsKuvpw/00349
id02548/6LPbT49zy38/00050 id02548/6LPbT49zy38/00050
id01000/7eYakM6qrTs/00006 id01000/7eYakM6qrTs/00006
id02181/cNCj0pLxR24/00084 id02181/cNCj0pLxR24/00084
id02086/sSliWvu6Ufs/00453 id02086/sSliWvu6Ufs/00453
id03178/KHelFt1Jyyg/00057 id03178/KHelFt1Jyyg/00057
id05594/8dYcSoUAQO8/00014 id05594/8dYcSoUAQO8/00014
id05015/JmvJemqIeS0/00102 id05015/JmvJemqIeS0/00102
id00081/EvCyt2keqW4/00065 id00081/EvCyt2keqW4/00065
id07663/QWe7IIGrv5s/00146 id07663/QWe7IIGrv5s/00146
id01618/kzxW2WAFWLI/00126 id01618/kzxW2WAFWLI/00126
id00562/X7FJ3M3bz3c/00124 id00562/X7FJ3M3bz3c/00124
id07961/bvPOvzukTE4/00224 id07961/bvPOvzukTE4/00224
id03789/nv8sQplhvX0/00357 id03789/nv8sQplhvX0/00357
id04295/VUHarbuO_eE/00125 id04295/VUHarbuO_eE/00125
id01822/IaBziWYcwK4/00037 id01822/IaBziWYcwK4/00037
id05015/X1opVctkTE8/00170 id05015/X1opVctkTE8/00170
id01041/MMXznNig_iU/00248 id01041/MMXznNig_iU/00248
id02465/EZ_F0hUZdS4/00054 id02465/EZ_F0hUZdS4/00054
id04656/Bi7kCsbg5L0/00061 id04656/Bi7kCsbg5L0/00061
id07494/K4ndWNAHgdU/00093 id07494/K4ndWNAHgdU/00093
id07354/TKTT7fArInQ/00218 id07354/TKTT7fArInQ/00218
id05714/Lu4PPvWXGn8/00014 id05714/Lu4PPvWXGn8/00014
id05654/07pANazoyJg/00001 id05654/07pANazoyJg/00001
id01066/FDp-ZLCWrIc/00054 id01066/FDp-ZLCWrIc/00054
id05999/ZQJVmCJFjNs/00182 id05999/ZQJVmCJFjNs/00182
id04570/5Fg6CLuRntk/00041 id04570/5Fg6CLuRntk/00041
id08696/vqLNqYW4TQA/00476 id08696/vqLNqYW4TQA/00476
id04862/2uYHadPvHRU/00016 id04862/2uYHadPvHRU/00016
id03980/7MRUusImkno/00001 id03980/7MRUusImkno/00001
id02542/QJKFnt1lHeE/00035 id02542/QJKFnt1lHeE/00035
id04536/OYH-6uGB6jI/00322 id04536/OYH-6uGB6jI/00322
id06484/dOTMnYZcY9Q/00126 id06484/dOTMnYZcY9Q/00126
id04478/GZQGZOmFU5U/00063 id04478/GZQGZOmFU5U/00063
id01224/tELp6C7FELU/00421 id01224/tELp6C7FELU/00421
id03862/5m5iPZNJS6c/00022 id03862/5m5iPZNJS6c/00022
id05124/lcDhSnyeN5E/00381 id05124/lcDhSnyeN5E/00381
id08149/3V9V5sDAWTc/00001 id08149/3V9V5sDAWTc/00001
id02181/iEF0MWApQms/00108 id02181/iEF0MWApQms/00108
id04536/xrsxSF2qey8/00471 id04536/xrsxSF2qey8/00471
id03178/9AJzTUwGbRk/00005 id03178/9AJzTUwGbRk/00005
id01041/Izmh75CZNW0/00207 id01041/Izmh75CZNW0/00207
id03041/g5YLpUZBNKc/00018 id03041/g5YLpUZBNKc/00018
id03347/nSAKXYdEOOM/00297 id03347/nSAKXYdEOOM/00297
id03347/pPWGEPixOoM/00337 id03347/pPWGEPixOoM/00337
id07312/XBBpLMEjfUo/00048 id07312/XBBpLMEjfUo/00048
id08456/6QFe7cYnZk4/00023 id08456/6QFe7cYnZk4/00023
id05176/5Hk_hj0oXN8/00004 id05176/5Hk_hj0oXN8/00004
id07426/DBBfi7aKLx4/00038 id07426/DBBfi7aKLx4/00038
id07494/uhPKcTLLwcM/00347 id07494/uhPKcTLLwcM/00347
id02576/agxjz_O2Wfs/00088 id02576/agxjz_O2Wfs/00088
id01541/SvTz_Pn15Vk/00119 id01541/SvTz_Pn15Vk/00119
id07414/Uxggn91FBog/00214 id07414/Uxggn91FBog/00214
id04253/1HOlzefgLu8/00001 id04253/1HOlzefgLu8/00001
id01567/RPUd0ua7RR0/00216 id01567/RPUd0ua7RR0/00216
id04657/5DzZTPLgwTM/00044 id04657/5DzZTPLgwTM/00044
id04006/zSMWS35kYdQ/00253 id04006/zSMWS35kYdQ/00253
id03347/KT7B07WFWyM/00104 id03347/KT7B07WFWyM/00104
id02445/z5u4yO1EsZo/00109 id02445/z5u4yO1EsZo/00109
id00154/z1dLArSg5PQ/00190 id00154/z1dLArSg5PQ/00190
id07414/Cn6Ws4oK1jg/00095 id07414/Cn6Ws4oK1jg/00095
id02286/WHS1n7XUt_8/00103 id02286/WHS1n7XUt_8/00103
id01509/Zmmnr4iRsCM/00230 id01509/Zmmnr4iRsCM/00230
id04276/tGOA4fVnSgw/00448 id04276/tGOA4fVnSgw/00448
id00419/nu9cRW2J4Dk/00420 id00419/nu9cRW2J4Dk/00420
id07868/6RQX9l98N-g/00002 id07868/6RQX9l98N-g/00002
id03839/1lh57VnuaKE/00004 id03839/1lh57VnuaKE/00004
id03178/LT-BNQKA9NU/00075 id03178/LT-BNQKA9NU/00075
id01460/Es6CkRmkIBY/00080 id01460/Es6CkRmkIBY/00080
id06692/T2Xk7MO6m2g/00297 id06692/T2Xk7MO6m2g/00297
id01892/d8b9y_CRE3M/00102 id01892/d8b9y_CRE3M/00102
id07426/K_25cVSB-JU/00063 id07426/K_25cVSB-JU/00063
id01333/LI6eLfuTn6I/00127 id01333/LI6eLfuTn6I/00127
id00081/hIBFutPzn8s/00158 id00081/hIBFutPzn8s/00158
id04536/2j8I_WX5mhY/00009 id04536/2j8I_WX5mhY/00009
id04232/UElg0R7fmlk/00253 id04232/UElg0R7fmlk/00253
id01460/eZR__GGkVw4/00221 id01460/eZR__GGkVw4/00221
id01041/GymfYtTsKEU/00119 id01041/GymfYtTsKEU/00119
id07396/xK1gClL60tY/00191 id07396/xK1gClL60tY/00191
id05459/81o3ictaOnU/00075 id05459/81o3ictaOnU/00075
id02685/yN8ilDTW-o4/00114 id02685/yN8ilDTW-o4/00114
id02286/c8LjgwDQAkw/00137 id02286/c8LjgwDQAkw/00137
id01541/SWcGs-DbV9Q/00100 id01541/SWcGs-DbV9Q/00100
id01822/x4Fr2ceg_f8/00231 id01822/x4Fr2ceg_f8/00231
id03347/FKY5V8wmX5k/00043 id03347/FKY5V8wmX5k/00043
id00817/0GmSijZelGY/00001 id00817/0GmSijZelGY/00001
id06209/ahL3F1x5sE4/00091 id06209/ahL3F1x5sE4/00091
id06692/4k3Eo5s1Rwo/00057 id06692/4k3Eo5s1Rwo/00057
id09017/sduESYpj2-I/00297 id09017/sduESYpj2-I/00297
id07354/grg37qaxKjI/00329 id07354/grg37qaxKjI/00329
id07802/X8I5FN64_Oc/00199 id07802/X8I5FN64_Oc/00199
id07494/JV5S_SUcHmI/00088 id07494/JV5S_SUcHmI/00088
id03524/eHrI5bD8hSs/00282 id03524/eHrI5bD8hSs/00282
id01460/HNjuGz9ayBk/00109 id01460/HNjuGz9ayBk/00109
id04570/961AefP1-is/00056 id04570/961AefP1-is/00056
id00419/749eTxP4Us8/00061 id00419/749eTxP4Us8/00061
id00017/OLguY5ofUrY/00039 id00017/OLguY5ofUrY/00039
id08392/RogKVSjaAH0/00293 id08392/RogKVSjaAH0/00293
id01066/lI1wGa1UhEM/00205 id01066/lI1wGa1UhEM/00205
id07621/zSdriAuJUKo/00485 id07621/zSdriAuJUKo/00485
id03862/JBkaiUNeMmk/00166 id03862/JBkaiUNeMmk/00166
id00017/E6aqL_Nc410/00027 id00017/E6aqL_Nc410/00027
id03839/fi-g--cBwnU/00348 id03839/fi-g--cBwnU/00348
id05654/eLztZmvnk-k/00095 id05654/eLztZmvnk-k/00095
id02548/wF5HfFXZCBI/00349 id02548/wF5HfFXZCBI/00349
id02576/LAipS5WJ29s/00075 id02576/LAipS5WJ29s/00075
id06692/SEPs17_AkTI/00295 id06692/SEPs17_AkTI/00295
id05459/kkaYxtBZnNo/00348 id05459/kkaYxtBZnNo/00348
id04232/MEGVEqgGCME/00167 id04232/MEGVEqgGCME/00167
id01989/8CUktsB_2bA/00031 id01989/8CUktsB_2bA/00031
id01066/kqP_NZ1FRlM/00176 id01066/kqP_NZ1FRlM/00176
id03382/ockh8KdXJP8/00059 id03382/ockh8KdXJP8/00059
id01593/pO180haP_vo/00410 id01593/pO180haP_vo/00410
id07396/nTQDZrnGXXY/00179 id07396/nTQDZrnGXXY/00179
id03030/rg-VUeksKaU/00257 id03030/rg-VUeksKaU/00257
id08911/IddDkZwRflE/00053 id08911/IddDkZwRflE/00053
id02317/K2GT02zavxo/00193 id02317/K2GT02zavxo/00193
id01298/5P4ldDRuo5c/00065 id01298/5P4ldDRuo5c/00065
id01989/Evbf6fMJNmk/00060 id01989/Evbf6fMJNmk/00060
id05124/fNJI2A0v8yI/00357 id05124/fNJI2A0v8yI/00357
id02465/RLi2ItGherA/00098 id02465/RLi2ItGherA/00098
id07868/qMNfMcG6sh0/00346 id07868/qMNfMcG6sh0/00346
id04366/tmoYV4kPOGU/00246 id04366/tmoYV4kPOGU/00246
id06484/_ZkoebnFkVA/00110 id06484/_ZkoebnFkVA/00110
id04276/I9gCyrZWFn0/00097 id04276/I9gCyrZWFn0/00097
id03978/IMn6f0iDOtE/00032 id03978/IMn6f0iDOtE/00032
id00419/w_0sK8WuSsg/00472 id00419/w_0sK8WuSsg/00472
id04478/RwcHXQ3MvsQ/00109 id04478/RwcHXQ3MvsQ/00109
id08696/cUmyIjpOYlY/00360 id08696/cUmyIjpOYlY/00360
id04366/DqBQx6AZ1Nk/00083 id04366/DqBQx6AZ1Nk/00083
id05459/RhOon49C3g8/00201 id05459/RhOon49C3g8/00201
id04656/OzgjshkHUiA/00166 id04656/OzgjshkHUiA/00166
id03969/x38Sqv819yE/00110 id03969/x38Sqv819yE/00110
id00061/0G9G9oyFHI8/00001 id00061/0G9G9oyFHI8/00001
id06913/IreNhnVfTkQ/00043 id06913/IreNhnVfTkQ/00043
id01618/NqYUgbuImpk/00096 id01618/NqYUgbuImpk/00096
id08552/y05_B9NXizo/00237 id08552/y05_B9NXizo/00237
id01460/zcTt06bjKuA/00365 id01460/zcTt06bjKuA/00365
id00866/nI-zVYcQX40/00220 id00866/nI-zVYcQX40/00220
id08374/9eMfNJiKBPQ/00056 id08374/9eMfNJiKBPQ/00056
id03524/nKxz0LxKZ58/00344 id03524/nKxz0LxKZ58/00344
id09017/A3CAugN2cjk/00021 id09017/A3CAugN2cjk/00021
id02685/NtHmnSLaGCA/00036 id02685/NtHmnSLaGCA/00036
id01224/atjwjz0vAk8/00213 id01224/atjwjz0vAk8/00213
id07961/gvLf2DggTu0/00271 id07961/gvLf2DggTu0/00271
id01567/CCs8rZLCdVw/00043 id01567/CCs8rZLCdVw/00043
id03347/nbmPriSE9NY/00316 id03347/nbmPriSE9NY/00316
id06104/snzG1OymFgs/00273 id06104/snzG1OymFgs/00273
id02019/xsXm-MSuD-E/00290 id02019/xsXm-MSuD-E/00290
id00061/VugwXDj1ka4/00088 id00061/VugwXDj1ka4/00088
id01224/4z68GFZuYKU/00028 id01224/4z68GFZuYKU/00028
id03839/ajkGXKUvTWY/00296 id03839/ajkGXKUvTWY/00296
id07874/N7fMpS_yaF4/00047 id07874/N7fMpS_yaF4/00047
id05124/fRhAX7v_R6A/00365 id05124/fRhAX7v_R6A/00365
id02181/ci_22Oqhwtc/00088 id02181/ci_22Oqhwtc/00088
id07414/njxmqS9ncTA/00399 id07414/njxmqS9ncTA/00399
id05176/yEMRxKA0vSw/00101 id05176/yEMRxKA0vSw/00101
id03862/VVaxYHNmtA8/00269 id03862/VVaxYHNmtA8/00269
id07396/X6KkvYh6rPA/00148 id07396/X6KkvYh6rPA/00148
id06310/TkxTnoic67U/00130 id06310/TkxTnoic67U/00130
id08374/Yh9O9ETuF_0/00250 id08374/Yh9O9ETuF_0/00250
id02317/5moKZXlJTEs/00058 id02317/5moKZXlJTEs/00058
id04536/EDCwhtRFARA/00172 id04536/EDCwhtRFARA/00172
id03789/pz1jGMsPY9M/00381 id03789/pz1jGMsPY9M/00381
id03127/wzS06bKAZ48/00354 id03127/wzS06bKAZ48/00354
id08911/wedpC4fN4YY/00096 id08911/wedpC4fN4YY/00096
id01106/6SFpvp42pMA/00014 id01106/6SFpvp42pMA/00014
id02465/6jp5YsZYtHI/00021 id02465/6jp5YsZYtHI/00021
id01618/Ay_BKx5-JOc/00046 id01618/Ay_BKx5-JOc/00046
id04478/x07vvSVm2Yo/00363 id04478/x07vvSVm2Yo/00363
id01593/u5AgUWl3fFU/00437 id01593/u5AgUWl3fFU/00437
id03030/IpwcoJajjJI/00124 id03030/IpwcoJajjJI/00124
id01593/t9TUbyp3xfs/00423 id01593/t9TUbyp3xfs/00423
id07414/hUxcsEMKssA/00320 id07414/hUxcsEMKssA/00320
id04366/L-56A5RNeWg/00124 id04366/L-56A5RNeWg/00124
id07961/3EPjXGhfst4/00001 id07961/3EPjXGhfst4/00001
id00061/mMOd25Ag7XY/00239 id00061/mMOd25Ag7XY/00239
id01567/RQMG0K5AchU/00218 id01567/RQMG0K5AchU/00218
id08552/PL5vk3XeKRM/00114 id08552/PL5vk3XeKRM/00114
id04862/eX3wAZ0yr7w/00260 id04862/eX3wAZ0yr7w/00260
id02086/CBNOvx4Phxw/00146 id02086/CBNOvx4Phxw/00146
id01228/3wAkCYQR3fQ/00011 id01228/3wAkCYQR3fQ/00011
id06484/MXwPpo1Dg7U/00073 id06484/MXwPpo1Dg7U/00073
id01460/9fJy9zGdESI/00045 id01460/9fJy9zGdESI/00045
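Each line pairs an audio clip with a video clip, both given as speaker/video/clip relative paths; in the reconstruction list the two halves of every pair are identical. A minimal parsing sketch:

# Read the filelist into (audio, video) path pairs.
with open("dataset/filelists/voxceleb2_test_n_500_reconstruction.txt") as f:
    pairs = [tuple(line.split()) for line in f if line.strip()]

print(len(pairs), pairs[0])
# -> 500 ('id09017/SjCgiXBHfNU/00111', 'id09017/SjCgiXBHfNU/00111')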
dataset/filelists/voxceleb2_test_n_500_seed_797_cross.txt ADDED
@@ -0,0 +1,500 @@
id05459/18XmQEiGLnQ/00001 id07961/3EPjXGhfst4/00001
id03980/7MRUusImkno/00001 id08696/0H1PxInJCK0/00001
id05654/07pANazoyJg/00001 id04570/0YMGn6BI9rg/00001
id00817/0GmSijZelGY/00001 id07354/0NjekFZqaY0/00001
id05202/2gnLcAbAoSc/00001 id00817/0GmSijZelGY/00001
id03041/5CfnYwQCW48/00001 id07354/0NjekFZqaY0/00001
id03980/7MRUusImkno/00001 id07621/0CiFdFegqZM/00001
id05850/B8kp8ed48JE/00001 id04253/1HOlzefgLu8/00001
id01298/2K5F6xG-Rbs/00001 id05816/1dyCBbJ94iw/00001
id07494/0P1wPmgz0Bk/00001 id07621/0CiFdFegqZM/00001
id06913/4Ug7aJemzpg/00001 id04030/7mXUMuo5_NE/00001
id02286/4LAIxvdvguc/00001 id05850/B8kp8ed48JE/00001
id02548/0pAkJZmlFqc/00001 id05459/18XmQEiGLnQ/00001
id08456/29EhSZDqzas/00001 id04295/1fSjOItVYVg/00001
id04295/1fSjOItVYVg/00001 id02685/4JDRxqYC0a4/00001
id04276/5M8NmCwTHZ0/00001 id05654/07pANazoyJg/00001
id03030/5wOxV1wAgqA/00001 id03041/5CfnYwQCW48/00001
id04656/1tZYt8jey54/00001 id07961/3EPjXGhfst4/00001
id03980/7MRUusImkno/00001 id04536/0f_Yi_1CoeM/00001
id03980/7MRUusImkno/00001 id05202/2gnLcAbAoSc/00001
id01298/2K5F6xG-Rbs/00001 id04862/0zJh2FMTaDE/00001
id02548/0pAkJZmlFqc/00001 id04478/2grMtwdG93I/00001
id02685/4JDRxqYC0a4/00001 id01892/3vKPgjwFjbo/00001
id07494/0P1wPmgz0Bk/00001 id04656/1tZYt8jey54/00001
id00812/1Xfgvdu7oDo/00001 id00926/2Nd7f1yNQzE/00001
id07426/1KNFfOFEhyI/00001 id03030/5wOxV1wAgqA/00001
id00866/03SSllwNkGk/00001 id00812/1Xfgvdu7oDo/00001
id04570/0YMGn6BI9rg/00001 id01892/3vKPgjwFjbo/00001
id03041/5CfnYwQCW48/00001 id04030/7mXUMuo5_NE/00001
id07494/0P1wPmgz0Bk/00001 id00081/2xYrsnvtUWc/00001
id08392/0fwuibKviJU/00001 id05015/0Cu3AvWWOFI/00001
id06692/2ptBBNIZXtI/00001 id04536/0f_Yi_1CoeM/00001
id04253/1HOlzefgLu8/00001 id06104/02L1L9RFAgI/00001
id02725/37kUrf6RJdw/00001 id02685/4JDRxqYC0a4/00001
id04006/113VkmVVz1Q/00001 id04119/1uH67UruKlE/00001
id01567/1Lx_ZqrK1bM/00001 id04030/7mXUMuo5_NE/00001
id02445/3Rnk8eja3TU/00001 id05816/1dyCBbJ94iw/00001
id03041/5CfnYwQCW48/00001 id03347/4xXZ75_TeSM/00001
id04570/0YMGn6BI9rg/00001 id02317/0q4X8kPTlEY/00001
id07426/1KNFfOFEhyI/00001 id01822/0QcHowaLAF0/00001
id02577/0euHS_r5JH4/00001 id02725/37kUrf6RJdw/00001
id07354/0NjekFZqaY0/00001 id05459/18XmQEiGLnQ/00001
id06692/2ptBBNIZXtI/00001 id05850/B8kp8ed48JE/00001
id01822/0QcHowaLAF0/00001 id07961/3EPjXGhfst4/00001
id04366/0iG2Ub9zETM/00001 id03347/4xXZ75_TeSM/00001
id03030/5wOxV1wAgqA/00001 id03789/0kdVSujPa9g/00001
id04366/0iG2Ub9zETM/00001 id02286/4LAIxvdvguc/00001
id00926/2Nd7f1yNQzE/00001 id02548/0pAkJZmlFqc/00001
id03030/5wOxV1wAgqA/00001 id01298/2K5F6xG-Rbs/00001
id01892/3vKPgjwFjbo/00001 id02317/0q4X8kPTlEY/00001
id05202/2gnLcAbAoSc/00001 id04253/1HOlzefgLu8/00001
id05714/2gvpaZcvAY4/00001 id06692/2ptBBNIZXtI/00001
id07621/0CiFdFegqZM/00001 id07802/0RUpqvi3sPU/00001
id03030/5wOxV1wAgqA/00001 id02317/0q4X8kPTlEY/00001
id01822/0QcHowaLAF0/00001 id02445/3Rnk8eja3TU/00001
id07961/3EPjXGhfst4/00001 id05459/18XmQEiGLnQ/00001
id04950/2n4sGPqU9M8/00001 id07426/1KNFfOFEhyI/00001
id04862/0zJh2FMTaDE/00001 id02465/0Ocu8l1eAng/00001
id06104/02L1L9RFAgI/00001 id07312/0LWllHGohPY/00001
id07414/110UMQovTR0/00001 id06692/2ptBBNIZXtI/00001
id05015/0Cu3AvWWOFI/00001 id08696/0H1PxInJCK0/00001
id02181/02gIO4WrZLY/00001 id00812/1Xfgvdu7oDo/00001
id08392/0fwuibKviJU/00001 id01041/1UYZqPpavtk/00001
id03347/4xXZ75_TeSM/00001 id04950/2n4sGPqU9M8/00001
id07312/0LWllHGohPY/00001 id04950/2n4sGPqU9M8/00001
id05202/2gnLcAbAoSc/00001 id05654/07pANazoyJg/00001
id01041/1UYZqPpavtk/00001 id02317/0q4X8kPTlEY/00001
id02057/0xZU7Oi9nvM/00001 id03178/2CT-6fnBC_o/00001
id04006/113VkmVVz1Q/00001 id00817/0GmSijZelGY/00001
id05850/B8kp8ed48JE/00001 id01892/3vKPgjwFjbo/00001
id08696/0H1PxInJCK0/00001 id06692/2ptBBNIZXtI/00001
id02057/0xZU7Oi9nvM/00001 id01541/2P7hzPq5iDw/00001
id04006/113VkmVVz1Q/00001 id02057/0xZU7Oi9nvM/00001
id04276/5M8NmCwTHZ0/00001 id04570/0YMGn6BI9rg/00001
id07868/5YYJq3fSbH8/00001 id03030/5wOxV1wAgqA/00001
id00812/1Xfgvdu7oDo/00001 id00154/0hjW3eTGAy8/00001
id06692/2ptBBNIZXtI/00001 id05594/0ohBiepcHWI/00001
id04536/0f_Yi_1CoeM/00001 id05202/2gnLcAbAoSc/00001
id06310/1IAgr_CRnuE/00001 id05816/1dyCBbJ94iw/00001
id01541/2P7hzPq5iDw/00001 id00419/1zffAxBod_c/00001
id07354/0NjekFZqaY0/00001 id00866/03SSllwNkGk/00001
id03347/4xXZ75_TeSM/00001 id02577/0euHS_r5JH4/00001
id04119/1uH67UruKlE/00001 id04006/113VkmVVz1Q/00001
id05714/2gvpaZcvAY4/00001 id07961/3EPjXGhfst4/00001
id06104/02L1L9RFAgI/00001 id03178/2CT-6fnBC_o/00001
id07354/0NjekFZqaY0/00001 id02445/3Rnk8eja3TU/00001
id04030/7mXUMuo5_NE/00001 id03030/5wOxV1wAgqA/00001
id07312/0LWllHGohPY/00001 id04536/0f_Yi_1CoeM/00001
id03839/1jWHvl2qCq0/00001 id07802/0RUpqvi3sPU/00001
id07621/0CiFdFegqZM/00001 id05816/1dyCBbJ94iw/00001
id03839/1jWHvl2qCq0/00001 id03980/7MRUusImkno/00001
id03030/5wOxV1wAgqA/00001 id02445/3Rnk8eja3TU/00001
id03862/0w8W8jp7MJk/00001 id04253/1HOlzefgLu8/00001
id05714/2gvpaZcvAY4/00001 id04119/1uH67UruKlE/00001
id08392/0fwuibKviJU/00001 id07868/5YYJq3fSbH8/00001
id01298/2K5F6xG-Rbs/00001 id03030/5wOxV1wAgqA/00001
id05459/18XmQEiGLnQ/00001 id00817/0GmSijZelGY/00001
id05850/B8kp8ed48JE/00001 id06692/2ptBBNIZXtI/00001
id04295/1fSjOItVYVg/00001 id08456/29EhSZDqzas/00001
id04570/0YMGn6BI9rg/00001 id02057/0xZU7Oi9nvM/00001
id01541/2P7hzPq5iDw/00001 id00817/0GmSijZelGY/00001
id07426/1KNFfOFEhyI/00001 id07354/0NjekFZqaY0/00001
id04253/1HOlzefgLu8/00001 id06209/2zM9EAPsZZQ/00001
id05850/B8kp8ed48JE/00001 id08392/0fwuibKviJU/00001
id07802/0RUpqvi3sPU/00001 id02465/0Ocu8l1eAng/00001
id04119/1uH67UruKlE/00001 id04862/0zJh2FMTaDE/00001
id01541/2P7hzPq5iDw/00001 id08696/0H1PxInJCK0/00001
id08696/0H1PxInJCK0/00001 id07802/0RUpqvi3sPU/00001
id01228/2TIFacjgehY/00001 id07621/0CiFdFegqZM/00001
id03178/2CT-6fnBC_o/00001 id07868/5YYJq3fSbH8/00001
id05654/07pANazoyJg/00001 id01298/2K5F6xG-Rbs/00001
id01822/0QcHowaLAF0/00001 id02548/0pAkJZmlFqc/00001
id01618/0iFlmfmWVlY/00001 id08696/0H1PxInJCK0/00001
id00812/1Xfgvdu7oDo/00001 id08456/29EhSZDqzas/00001
id05594/0ohBiepcHWI/00001 id07312/0LWllHGohPY/00001
id05714/2gvpaZcvAY4/00001 id06104/02L1L9RFAgI/00001
id02445/3Rnk8eja3TU/00001 id07426/1KNFfOFEhyI/00001
id05714/2gvpaZcvAY4/00001 id00817/0GmSijZelGY/00001
id08696/0H1PxInJCK0/00001 id02317/0q4X8kPTlEY/00001
id04950/2n4sGPqU9M8/00001 id04478/2grMtwdG93I/00001
id01228/2TIFacjgehY/00001 id07414/110UMQovTR0/00001
id00926/2Nd7f1yNQzE/00001 id01541/2P7hzPq5iDw/00001
id05714/2gvpaZcvAY4/00001 id06913/4Ug7aJemzpg/00001
id01228/2TIFacjgehY/00001 id03862/0w8W8jp7MJk/00001
id03030/5wOxV1wAgqA/00001 id05015/0Cu3AvWWOFI/00001
id02548/0pAkJZmlFqc/00001 id06692/2ptBBNIZXtI/00001
id05202/2gnLcAbAoSc/00001 id04119/1uH67UruKlE/00001
id04656/1tZYt8jey54/00001 id07426/1KNFfOFEhyI/00001
id07312/0LWllHGohPY/00001 id03980/7MRUusImkno/00001
id04366/0iG2Ub9zETM/00001 id00817/0GmSijZelGY/00001
id07961/3EPjXGhfst4/00001 id01228/2TIFacjgehY/00001
id00154/0hjW3eTGAy8/00001 id04295/1fSjOItVYVg/00001
id04478/2grMtwdG93I/00001 id00154/0hjW3eTGAy8/00001
id04570/0YMGn6BI9rg/00001 id05202/2gnLcAbAoSc/00001
id04478/2grMtwdG93I/00001 id06913/4Ug7aJemzpg/00001
id06104/02L1L9RFAgI/00001 id04295/1fSjOItVYVg/00001
id05816/1dyCBbJ94iw/00001 id08392/0fwuibKviJU/00001
id00926/2Nd7f1yNQzE/00001 id04536/0f_Yi_1CoeM/00001
id00926/2Nd7f1yNQzE/00001 id02181/02gIO4WrZLY/00001
id05459/18XmQEiGLnQ/00001 id02317/0q4X8kPTlEY/00001
id05594/0ohBiepcHWI/00001 id01228/2TIFacjgehY/00001
id02181/02gIO4WrZLY/00001 id07312/0LWllHGohPY/00001
id00154/0hjW3eTGAy8/00001 id03839/1jWHvl2qCq0/00001
id04030/7mXUMuo5_NE/00001 id02725/37kUrf6RJdw/00001
id04295/1fSjOItVYVg/00001 id05714/2gvpaZcvAY4/00001
id02548/0pAkJZmlFqc/00001 id04570/0YMGn6BI9rg/00001
id04478/2grMtwdG93I/00001 id00866/03SSllwNkGk/00001
id03030/5wOxV1wAgqA/00001 id04366/0iG2Ub9zETM/00001
id02685/4JDRxqYC0a4/00001 id07426/1KNFfOFEhyI/00001
id07802/0RUpqvi3sPU/00001 id07312/0LWllHGohPY/00001
id02317/0q4X8kPTlEY/00001 id01892/3vKPgjwFjbo/00001
id00154/0hjW3eTGAy8/00001 id00866/03SSllwNkGk/00001
id02181/02gIO4WrZLY/00001 id02685/4JDRxqYC0a4/00001
id03178/2CT-6fnBC_o/00001 id05459/18XmQEiGLnQ/00001
id00926/2Nd7f1yNQzE/00001 id05202/2gnLcAbAoSc/00001
id03041/5CfnYwQCW48/00001 id03178/2CT-6fnBC_o/00001
id05850/B8kp8ed48JE/00001 id04006/113VkmVVz1Q/00001
id01822/0QcHowaLAF0/00001 id04570/0YMGn6BI9rg/00001
id04478/2grMtwdG93I/00001 id03839/1jWHvl2qCq0/00001
id01298/2K5F6xG-Rbs/00001 id01228/2TIFacjgehY/00001
id06310/1IAgr_CRnuE/00001 id04006/113VkmVVz1Q/00001
id00154/0hjW3eTGAy8/00001 id04006/113VkmVVz1Q/00001
id05816/1dyCBbJ94iw/00001 id01041/1UYZqPpavtk/00001
id04570/0YMGn6BI9rg/00001 id04862/0zJh2FMTaDE/00001
id06913/4Ug7aJemzpg/00001 id04862/0zJh2FMTaDE/00001
id03862/0w8W8jp7MJk/00001 id02465/0Ocu8l1eAng/00001
id04253/1HOlzefgLu8/00001 id01567/1Lx_ZqrK1bM/00001
id06209/2zM9EAPsZZQ/00001 id01298/2K5F6xG-Rbs/00001
id01822/0QcHowaLAF0/00001 id01541/2P7hzPq5iDw/00001
id07312/0LWllHGohPY/00001 id02317/0q4X8kPTlEY/00001
id06692/2ptBBNIZXtI/00001 id02445/3Rnk8eja3TU/00001
id07414/110UMQovTR0/00001 id00154/0hjW3eTGAy8/00001
id04478/2grMtwdG93I/00001 id03347/4xXZ75_TeSM/00001
id04656/1tZYt8jey54/00001 id07802/0RUpqvi3sPU/00001
id03839/1jWHvl2qCq0/00001 id06310/1IAgr_CRnuE/00001
id02057/0xZU7Oi9nvM/00001 id01228/2TIFacjgehY/00001
id00081/2xYrsnvtUWc/00001 id02057/0xZU7Oi9nvM/00001
id03862/0w8W8jp7MJk/00001 id01892/3vKPgjwFjbo/00001
id04570/0YMGn6BI9rg/00001 id06913/4Ug7aJemzpg/00001
id08392/0fwuibKviJU/00001 id01567/1Lx_ZqrK1bM/00001
id00081/2xYrsnvtUWc/00001 id07494/0P1wPmgz0Bk/00001
id04536/0f_Yi_1CoeM/00001 id00081/2xYrsnvtUWc/00001
id03839/1jWHvl2qCq0/00001 id05850/B8kp8ed48JE/00001
id07621/0CiFdFegqZM/00001 id08456/29EhSZDqzas/00001
id01822/0QcHowaLAF0/00001 id07868/5YYJq3fSbH8/00001
id05202/2gnLcAbAoSc/00001 id03178/2CT-6fnBC_o/00001
id06692/2ptBBNIZXtI/00001 id06913/4Ug7aJemzpg/00001
id01041/1UYZqPpavtk/00001 id03030/5wOxV1wAgqA/00001
id07426/1KNFfOFEhyI/00001 id08456/29EhSZDqzas/00001
id04478/2grMtwdG93I/00001 id02548/0pAkJZmlFqc/00001
id08392/0fwuibKviJU/00001 id01298/2K5F6xG-Rbs/00001
id03041/5CfnYwQCW48/00001 id08696/0H1PxInJCK0/00001
id04366/0iG2Ub9zETM/00001 id07426/1KNFfOFEhyI/00001
id04950/2n4sGPqU9M8/00001 id07494/0P1wPmgz0Bk/00001
id01822/0QcHowaLAF0/00001 id08392/0fwuibKviJU/00001
id02577/0euHS_r5JH4/00001 id06692/2ptBBNIZXtI/00001
id04570/0YMGn6BI9rg/00001 id00866/03SSllwNkGk/00001
id05850/B8kp8ed48JE/00001 id08456/29EhSZDqzas/00001
id01618/0iFlmfmWVlY/00001 id01041/1UYZqPpavtk/00001
id07414/110UMQovTR0/00001 id04536/0f_Yi_1CoeM/00001
id02057/0xZU7Oi9nvM/00001 id06913/4Ug7aJemzpg/00001
id04536/0f_Yi_1CoeM/00001 id01041/1UYZqPpavtk/00001
id04030/7mXUMuo5_NE/00001 id05850/B8kp8ed48JE/00001
id04656/1tZYt8jey54/00001 id05459/18XmQEiGLnQ/00001
id03789/0kdVSujPa9g/00001 id02057/0xZU7Oi9nvM/00001
id01041/1UYZqPpavtk/00001 id05594/0ohBiepcHWI/00001
id07494/0P1wPmgz0Bk/00001 id04006/113VkmVVz1Q/00001
id00812/1Xfgvdu7oDo/00001 id04295/1fSjOItVYVg/00001
id01541/2P7hzPq5iDw/00001 id02465/0Ocu8l1eAng/00001
id04862/0zJh2FMTaDE/00001 id05594/0ohBiepcHWI/00001
id05714/2gvpaZcvAY4/00001 id02286/4LAIxvdvguc/00001
id06209/2zM9EAPsZZQ/00001 id05816/1dyCBbJ94iw/00001
id05850/B8kp8ed48JE/00001 id00866/03SSllwNkGk/00001
id07494/0P1wPmgz0Bk/00001 id07312/0LWllHGohPY/00001
id04366/0iG2Ub9zETM/00001 id04570/0YMGn6BI9rg/00001
id00866/03SSllwNkGk/00001 id03347/4xXZ75_TeSM/00001
id02445/3Rnk8eja3TU/00001 id07802/0RUpqvi3sPU/00001
id08696/0H1PxInJCK0/00001 id06209/2zM9EAPsZZQ/00001
id02445/3Rnk8eja3TU/00001 id07621/0CiFdFegqZM/00001
id08392/0fwuibKviJU/00001 id05850/B8kp8ed48JE/00001
id00419/1zffAxBod_c/00001 id01228/2TIFacjgehY/00001
id07354/0NjekFZqaY0/00001 id01041/1UYZqPpavtk/00001
id04570/0YMGn6BI9rg/00001 id03347/4xXZ75_TeSM/00001
id01892/3vKPgjwFjbo/00001 id02445/3Rnk8eja3TU/00001
225
+ id00081/2xYrsnvtUWc/00001 id05459/18XmQEiGLnQ/00001
226
+ id06104/02L1L9RFAgI/00001 id04570/0YMGn6BI9rg/00001
227
+ id07961/3EPjXGhfst4/00001 id05654/07pANazoyJg/00001
228
+ id00926/2Nd7f1yNQzE/00001 id03839/1jWHvl2qCq0/00001
229
+ id02181/02gIO4WrZLY/00001 id08696/0H1PxInJCK0/00001
230
+ id07426/1KNFfOFEhyI/00001 id05459/18XmQEiGLnQ/00001
231
+ id03041/5CfnYwQCW48/00001 id06104/02L1L9RFAgI/00001
232
+ id01298/2K5F6xG-Rbs/00001 id01541/2P7hzPq5iDw/00001
233
+ id04570/0YMGn6BI9rg/00001 id01618/0iFlmfmWVlY/00001
234
+ id02685/4JDRxqYC0a4/00001 id02548/0pAkJZmlFqc/00001
235
+ id01822/0QcHowaLAF0/00001 id07426/1KNFfOFEhyI/00001
236
+ id07868/5YYJq3fSbH8/00001 id07494/0P1wPmgz0Bk/00001
237
+ id07802/0RUpqvi3sPU/00001 id03041/5CfnYwQCW48/00001
238
+ id04656/1tZYt8jey54/00001 id01541/2P7hzPq5iDw/00001
239
+ id03347/4xXZ75_TeSM/00001 id02445/3Rnk8eja3TU/00001
240
+ id02548/0pAkJZmlFqc/00001 id01298/2K5F6xG-Rbs/00001
241
+ id07354/0NjekFZqaY0/00001 id07426/1KNFfOFEhyI/00001
242
+ id03862/0w8W8jp7MJk/00001 id01298/2K5F6xG-Rbs/00001
243
+ id04536/0f_Yi_1CoeM/00001 id02465/0Ocu8l1eAng/00001
244
+ id00081/2xYrsnvtUWc/00001 id04366/0iG2Ub9zETM/00001
245
+ id04950/2n4sGPqU9M8/00001 id01822/0QcHowaLAF0/00001
246
+ id06692/2ptBBNIZXtI/00001 id03030/5wOxV1wAgqA/00001
247
+ id07312/0LWllHGohPY/00001 id04478/2grMtwdG93I/00001
248
+ id03862/0w8W8jp7MJk/00001 id03030/5wOxV1wAgqA/00001
249
+ id00081/2xYrsnvtUWc/00001 id08392/0fwuibKviJU/00001
250
+ id02317/0q4X8kPTlEY/00001 id00154/0hjW3eTGAy8/00001
251
+ id05594/0ohBiepcHWI/00001 id04536/0f_Yi_1CoeM/00001
252
+ id07868/5YYJq3fSbH8/00001 id03839/1jWHvl2qCq0/00001
253
+ id02577/0euHS_r5JH4/00001 id06913/4Ug7aJemzpg/00001
254
+ id08456/29EhSZDqzas/00001 id01541/2P7hzPq5iDw/00001
255
+ id01567/1Lx_ZqrK1bM/00001 id04119/1uH67UruKlE/00001
256
+ id04253/1HOlzefgLu8/00001 id01228/2TIFacjgehY/00001
257
+ id02445/3Rnk8eja3TU/00001 id02685/4JDRxqYC0a4/00001
258
+ id05015/0Cu3AvWWOFI/00001 id02465/0Ocu8l1eAng/00001
259
+ id07494/0P1wPmgz0Bk/00001 id05714/2gvpaZcvAY4/00001
260
+ id02548/0pAkJZmlFqc/00001 id04006/113VkmVVz1Q/00001
261
+ id00866/03SSllwNkGk/00001 id02317/0q4X8kPTlEY/00001
262
+ id07354/0NjekFZqaY0/00001 id04253/1HOlzefgLu8/00001
263
+ id00812/1Xfgvdu7oDo/00001 id03030/5wOxV1wAgqA/00001
264
+ id02465/0Ocu8l1eAng/00001 id07354/0NjekFZqaY0/00001
265
+ id04276/5M8NmCwTHZ0/00001 id03862/0w8W8jp7MJk/00001
266
+ id01567/1Lx_ZqrK1bM/00001 id04253/1HOlzefgLu8/00001
267
+ id01618/0iFlmfmWVlY/00001 id06913/4Ug7aJemzpg/00001
268
+ id03862/0w8W8jp7MJk/00001 id08392/0fwuibKviJU/00001
269
+ id07961/3EPjXGhfst4/00001 id00154/0hjW3eTGAy8/00001
270
+ id02577/0euHS_r5JH4/00001 id01228/2TIFacjgehY/00001
271
+ id05654/07pANazoyJg/00001 id03041/5CfnYwQCW48/00001
272
+ id03980/7MRUusImkno/00001 id08392/0fwuibKviJU/00001
273
+ id03178/2CT-6fnBC_o/00001 id04295/1fSjOItVYVg/00001
274
+ id02317/0q4X8kPTlEY/00001 id03347/4xXZ75_TeSM/00001
275
+ id02548/0pAkJZmlFqc/00001 id07426/1KNFfOFEhyI/00001
276
+ id03839/1jWHvl2qCq0/00001 id05654/07pANazoyJg/00001
277
+ id02548/0pAkJZmlFqc/00001 id07868/5YYJq3fSbH8/00001
278
+ id04570/0YMGn6BI9rg/00001 id01041/1UYZqPpavtk/00001
279
+ id07414/110UMQovTR0/00001 id00419/1zffAxBod_c/00001
280
+ id00154/0hjW3eTGAy8/00001 id01618/0iFlmfmWVlY/00001
281
+ id07494/0P1wPmgz0Bk/00001 id05654/07pANazoyJg/00001
282
+ id01822/0QcHowaLAF0/00001 id06310/1IAgr_CRnuE/00001
283
+ id05015/0Cu3AvWWOFI/00001 id05459/18XmQEiGLnQ/00001
284
+ id05816/1dyCBbJ94iw/00001 id02317/0q4X8kPTlEY/00001
285
+ id01541/2P7hzPq5iDw/00001 id05816/1dyCBbJ94iw/00001
286
+ id06104/02L1L9RFAgI/00001 id01892/3vKPgjwFjbo/00001
287
+ id04862/0zJh2FMTaDE/00001 id05850/B8kp8ed48JE/00001
288
+ id05202/2gnLcAbAoSc/00001 id04366/0iG2Ub9zETM/00001
289
+ id02286/4LAIxvdvguc/00001 id02725/37kUrf6RJdw/00001
290
+ id04276/5M8NmCwTHZ0/00001 id01541/2P7hzPq5iDw/00001
291
+ id02057/0xZU7Oi9nvM/00001 id03862/0w8W8jp7MJk/00001
292
+ id06104/02L1L9RFAgI/00001 id00419/1zffAxBod_c/00001
293
+ id04950/2n4sGPqU9M8/00001 id02181/02gIO4WrZLY/00001
294
+ id04478/2grMtwdG93I/00001 id02685/4JDRxqYC0a4/00001
295
+ id04006/113VkmVVz1Q/00001 id00081/2xYrsnvtUWc/00001
296
+ id06692/2ptBBNIZXtI/00001 id03347/4xXZ75_TeSM/00001
297
+ id03030/5wOxV1wAgqA/00001 id02465/0Ocu8l1eAng/00001
298
+ id07312/0LWllHGohPY/00001 id03839/1jWHvl2qCq0/00001
299
+ id04950/2n4sGPqU9M8/00001 id05654/07pANazoyJg/00001
300
+ id02465/0Ocu8l1eAng/00001 id01618/0iFlmfmWVlY/00001
301
+ id00419/1zffAxBod_c/00001 id02181/02gIO4WrZLY/00001
302
+ id07426/1KNFfOFEhyI/00001 id05202/2gnLcAbAoSc/00001
303
+ id07621/0CiFdFegqZM/00001 id08696/0H1PxInJCK0/00001
304
+ id04006/113VkmVVz1Q/00001 id08392/0fwuibKviJU/00001
305
+ id04478/2grMtwdG93I/00001 id02445/3Rnk8eja3TU/00001
306
+ id03347/4xXZ75_TeSM/00001 id00154/0hjW3eTGAy8/00001
307
+ id07312/0LWllHGohPY/00001 id02181/02gIO4WrZLY/00001
308
+ id06310/1IAgr_CRnuE/00001 id02057/0xZU7Oi9nvM/00001
309
+ id04366/0iG2Ub9zETM/00001 id05654/07pANazoyJg/00001
310
+ id00419/1zffAxBod_c/00001 id04570/0YMGn6BI9rg/00001
311
+ id04862/0zJh2FMTaDE/00001 id03862/0w8W8jp7MJk/00001
312
+ id04366/0iG2Ub9zETM/00001 id00154/0hjW3eTGAy8/00001
313
+ id00866/03SSllwNkGk/00001 id00081/2xYrsnvtUWc/00001
314
+ id01618/0iFlmfmWVlY/00001 id02725/37kUrf6RJdw/00001
315
+ id01892/3vKPgjwFjbo/00001 id07621/0CiFdFegqZM/00001
316
+ id05015/0Cu3AvWWOFI/00001 id00926/2Nd7f1yNQzE/00001
317
+ id06913/4Ug7aJemzpg/00001 id03839/1jWHvl2qCq0/00001
318
+ id07312/0LWllHGohPY/00001 id07802/0RUpqvi3sPU/00001
319
+ id06104/02L1L9RFAgI/00001 id02465/0Ocu8l1eAng/00001
320
+ id04295/1fSjOItVYVg/00001 id01298/2K5F6xG-Rbs/00001
321
+ id00866/03SSllwNkGk/00001 id05714/2gvpaZcvAY4/00001
322
+ id06104/02L1L9RFAgI/00001 id01541/2P7hzPq5iDw/00001
323
+ id02445/3Rnk8eja3TU/00001 id03789/0kdVSujPa9g/00001
324
+ id00081/2xYrsnvtUWc/00001 id05816/1dyCBbJ94iw/00001
325
+ id02548/0pAkJZmlFqc/00001 id03030/5wOxV1wAgqA/00001
326
+ id04276/5M8NmCwTHZ0/00001 id01041/1UYZqPpavtk/00001
327
+ id06913/4Ug7aJemzpg/00001 id07868/5YYJq3fSbH8/00001
328
+ id04656/1tZYt8jey54/00001 id06692/2ptBBNIZXtI/00001
329
+ id07494/0P1wPmgz0Bk/00001 id08696/0H1PxInJCK0/00001
330
+ id04119/1uH67UruKlE/00001 id02317/0q4X8kPTlEY/00001
331
+ id00419/1zffAxBod_c/00001 id04862/0zJh2FMTaDE/00001
332
+ id03862/0w8W8jp7MJk/00001 id02445/3Rnk8eja3TU/00001
333
+ id01892/3vKPgjwFjbo/00001 id04862/0zJh2FMTaDE/00001
334
+ id04950/2n4sGPqU9M8/00001 id01618/0iFlmfmWVlY/00001
335
+ id01228/2TIFacjgehY/00001 id01298/2K5F6xG-Rbs/00001
336
+ id01041/1UYZqPpavtk/00001 id07961/3EPjXGhfst4/00001
337
+ id07802/0RUpqvi3sPU/00001 id06913/4Ug7aJemzpg/00001
338
+ id04276/5M8NmCwTHZ0/00001 id03030/5wOxV1wAgqA/00001
339
+ id01567/1Lx_ZqrK1bM/00001 id05459/18XmQEiGLnQ/00001
340
+ id02465/0Ocu8l1eAng/00001 id02725/37kUrf6RJdw/00001
341
+ id05816/1dyCBbJ94iw/00001 id02181/02gIO4WrZLY/00001
342
+ id06913/4Ug7aJemzpg/00001 id04950/2n4sGPqU9M8/00001
343
+ id04276/5M8NmCwTHZ0/00001 id04253/1HOlzefgLu8/00001
344
+ id07414/110UMQovTR0/00001 id06209/2zM9EAPsZZQ/00001
345
+ id06310/1IAgr_CRnuE/00001 id03839/1jWHvl2qCq0/00001
346
+ id03347/4xXZ75_TeSM/00001 id04006/113VkmVVz1Q/00001
347
+ id01541/2P7hzPq5iDw/00001 id04253/1HOlzefgLu8/00001
348
+ id08456/29EhSZDqzas/00001 id07494/0P1wPmgz0Bk/00001
349
+ id07621/0CiFdFegqZM/00001 id05594/0ohBiepcHWI/00001
350
+ id02685/4JDRxqYC0a4/00001 id04536/0f_Yi_1CoeM/00001
351
+ id02317/0q4X8kPTlEY/00001 id08696/0H1PxInJCK0/00001
352
+ id04253/1HOlzefgLu8/00001 id01041/1UYZqPpavtk/00001
353
+ id01041/1UYZqPpavtk/00001 id03178/2CT-6fnBC_o/00001
354
+ id05654/07pANazoyJg/00001 id01892/3vKPgjwFjbo/00001
355
+ id04862/0zJh2FMTaDE/00001 id06310/1IAgr_CRnuE/00001
356
+ id01541/2P7hzPq5iDw/00001 id04478/2grMtwdG93I/00001
357
+ id02445/3Rnk8eja3TU/00001 id02057/0xZU7Oi9nvM/00001
358
+ id08392/0fwuibKviJU/00001 id04570/0YMGn6BI9rg/00001
359
+ id06692/2ptBBNIZXtI/00001 id02057/0xZU7Oi9nvM/00001
360
+ id04950/2n4sGPqU9M8/00001 id04862/0zJh2FMTaDE/00001
361
+ id03862/0w8W8jp7MJk/00001 id07621/0CiFdFegqZM/00001
362
+ id07312/0LWllHGohPY/00001 id04656/1tZYt8jey54/00001
363
+ id02577/0euHS_r5JH4/00001 id00866/03SSllwNkGk/00001
364
+ id01228/2TIFacjgehY/00001 id02685/4JDRxqYC0a4/00001
365
+ id00081/2xYrsnvtUWc/00001 id00419/1zffAxBod_c/00001
366
+ id00154/0hjW3eTGAy8/00001 id04656/1tZYt8jey54/00001
367
+ id03839/1jWHvl2qCq0/00001 id01618/0iFlmfmWVlY/00001
368
+ id03862/0w8W8jp7MJk/00001 id02286/4LAIxvdvguc/00001
369
+ id06310/1IAgr_CRnuE/00001 id08456/29EhSZDqzas/00001
370
+ id02317/0q4X8kPTlEY/00001 id04276/5M8NmCwTHZ0/00001
371
+ id06913/4Ug7aJemzpg/00001 id04366/0iG2Ub9zETM/00001
372
+ id06310/1IAgr_CRnuE/00001 id00926/2Nd7f1yNQzE/00001
373
+ id01228/2TIFacjgehY/00001 id02181/02gIO4WrZLY/00001
374
+ id07414/110UMQovTR0/00001 id05594/0ohBiepcHWI/00001
375
+ id03980/7MRUusImkno/00001 id03178/2CT-6fnBC_o/00001
376
+ id03347/4xXZ75_TeSM/00001 id04478/2grMtwdG93I/00001
377
+ id06692/2ptBBNIZXtI/00001 id05459/18XmQEiGLnQ/00001
378
+ id00154/0hjW3eTGAy8/00001 id02725/37kUrf6RJdw/00001
379
+ id01228/2TIFacjgehY/00001 id04006/113VkmVVz1Q/00001
380
+ id00866/03SSllwNkGk/00001 id00926/2Nd7f1yNQzE/00001
381
+ id05594/0ohBiepcHWI/00001 id04006/113VkmVVz1Q/00001
382
+ id04656/1tZYt8jey54/00001 id01822/0QcHowaLAF0/00001
383
+ id07354/0NjekFZqaY0/00001 id04536/0f_Yi_1CoeM/00001
384
+ id07354/0NjekFZqaY0/00001 id04656/1tZYt8jey54/00001
385
+ id04366/0iG2Ub9zETM/00001 id02057/0xZU7Oi9nvM/00001
386
+ id03789/0kdVSujPa9g/00001 id01822/0QcHowaLAF0/00001
387
+ id07621/0CiFdFegqZM/00001 id03347/4xXZ75_TeSM/00001
388
+ id04030/7mXUMuo5_NE/00001 id04366/0iG2Ub9zETM/00001
389
+ id00812/1Xfgvdu7oDo/00001 id07354/0NjekFZqaY0/00001
390
+ id04536/0f_Yi_1CoeM/00001 id07494/0P1wPmgz0Bk/00001
391
+ id04536/0f_Yi_1CoeM/00001 id05816/1dyCBbJ94iw/00001
392
+ id03862/0w8W8jp7MJk/00001 id07868/5YYJq3fSbH8/00001
393
+ id02685/4JDRxqYC0a4/00001 id05459/18XmQEiGLnQ/00001
394
+ id06209/2zM9EAPsZZQ/00001 id07426/1KNFfOFEhyI/00001
395
+ id07426/1KNFfOFEhyI/00001 id02317/0q4X8kPTlEY/00001
396
+ id00926/2Nd7f1yNQzE/00001 id05594/0ohBiepcHWI/00001
397
+ id00154/0hjW3eTGAy8/00001 id04950/2n4sGPqU9M8/00001
398
+ id03041/5CfnYwQCW48/00001 id01892/3vKPgjwFjbo/00001
399
+ id00419/1zffAxBod_c/00001 id00866/03SSllwNkGk/00001
400
+ id02725/37kUrf6RJdw/00001 id05202/2gnLcAbAoSc/00001
401
+ id04656/1tZYt8jey54/00001 id06913/4Ug7aJemzpg/00001
402
+ id03862/0w8W8jp7MJk/00001 id04006/113VkmVVz1Q/00001
403
+ id00419/1zffAxBod_c/00001 id04030/7mXUMuo5_NE/00001
404
+ id06692/2ptBBNIZXtI/00001 id01541/2P7hzPq5iDw/00001
405
+ id07354/0NjekFZqaY0/00001 id03041/5CfnYwQCW48/00001
406
+ id03347/4xXZ75_TeSM/00001 id07802/0RUpqvi3sPU/00001
407
+ id07354/0NjekFZqaY0/00001 id01298/2K5F6xG-Rbs/00001
408
+ id02725/37kUrf6RJdw/00001 id03980/7MRUusImkno/00001
409
+ id01618/0iFlmfmWVlY/00001 id02445/3Rnk8eja3TU/00001
410
+ id05816/1dyCBbJ94iw/00001 id00081/2xYrsnvtUWc/00001
411
+ id07354/0NjekFZqaY0/00001 id04478/2grMtwdG93I/00001
412
+ id03980/7MRUusImkno/00001 id04295/1fSjOItVYVg/00001
413
+ id02548/0pAkJZmlFqc/00001 id00081/2xYrsnvtUWc/00001
414
+ id05459/18XmQEiGLnQ/00001 id03347/4xXZ75_TeSM/00001
415
+ id04570/0YMGn6BI9rg/00001 id04006/113VkmVVz1Q/00001
416
+ id06209/2zM9EAPsZZQ/00001 id01041/1UYZqPpavtk/00001
417
+ id01228/2TIFacjgehY/00001 id02317/0q4X8kPTlEY/00001
418
+ id07802/0RUpqvi3sPU/00001 id01541/2P7hzPq5iDw/00001
419
+ id04862/0zJh2FMTaDE/00001 id01892/3vKPgjwFjbo/00001
420
+ id04253/1HOlzefgLu8/00001 id07802/0RUpqvi3sPU/00001
421
+ id06692/2ptBBNIZXtI/00001 id02286/4LAIxvdvguc/00001
422
+ id01228/2TIFacjgehY/00001 id07961/3EPjXGhfst4/00001
423
+ id05714/2gvpaZcvAY4/00001 id00812/1Xfgvdu7oDo/00001
424
+ id03789/0kdVSujPa9g/00001 id03862/0w8W8jp7MJk/00001
425
+ id04295/1fSjOItVYVg/00001 id07868/5YYJq3fSbH8/00001
426
+ id04276/5M8NmCwTHZ0/00001 id02057/0xZU7Oi9nvM/00001
427
+ id02286/4LAIxvdvguc/00001 id03862/0w8W8jp7MJk/00001
428
+ id04478/2grMtwdG93I/00001 id05816/1dyCBbJ94iw/00001
429
+ id08456/29EhSZDqzas/00001 id02725/37kUrf6RJdw/00001
430
+ id02577/0euHS_r5JH4/00001 id07961/3EPjXGhfst4/00001
431
+ id01618/0iFlmfmWVlY/00001 id00812/1Xfgvdu7oDo/00001
432
+ id07312/0LWllHGohPY/00001 id03789/0kdVSujPa9g/00001
433
+ id02685/4JDRxqYC0a4/00001 id03839/1jWHvl2qCq0/00001
434
+ id04030/7mXUMuo5_NE/00001 id07802/0RUpqvi3sPU/00001
435
+ id01567/1Lx_ZqrK1bM/00001 id04478/2grMtwdG93I/00001
436
+ id02577/0euHS_r5JH4/00001 id02548/0pAkJZmlFqc/00001
437
+ id04536/0f_Yi_1CoeM/00001 id03030/5wOxV1wAgqA/00001
438
+ id03347/4xXZ75_TeSM/00001 id00081/2xYrsnvtUWc/00001
439
+ id03980/7MRUusImkno/00001 id06209/2zM9EAPsZZQ/00001
440
+ id01567/1Lx_ZqrK1bM/00001 id00154/0hjW3eTGAy8/00001
441
+ id06104/02L1L9RFAgI/00001 id02057/0xZU7Oi9nvM/00001
442
+ id04570/0YMGn6BI9rg/00001 id03980/7MRUusImkno/00001
443
+ id08456/29EhSZDqzas/00001 id02286/4LAIxvdvguc/00001
444
+ id07312/0LWllHGohPY/00001 id04366/0iG2Ub9zETM/00001
445
+ id05654/07pANazoyJg/00001 id07426/1KNFfOFEhyI/00001
446
+ id03839/1jWHvl2qCq0/00001 id03347/4xXZ75_TeSM/00001
447
+ id04536/0f_Yi_1CoeM/00001 id04478/2grMtwdG93I/00001
448
+ id05816/1dyCBbJ94iw/00001 id04862/0zJh2FMTaDE/00001
449
+ id04950/2n4sGPqU9M8/00001 id00817/0GmSijZelGY/00001
450
+ id07426/1KNFfOFEhyI/00001 id04862/0zJh2FMTaDE/00001
451
+ id05459/18XmQEiGLnQ/00001 id00812/1Xfgvdu7oDo/00001
452
+ id00154/0hjW3eTGAy8/00001 id03178/2CT-6fnBC_o/00001
453
+ id04295/1fSjOItVYVg/00001 id07312/0LWllHGohPY/00001
454
+ id05594/0ohBiepcHWI/00001 id04862/0zJh2FMTaDE/00001
455
+ id03347/4xXZ75_TeSM/00001 id01541/2P7hzPq5iDw/00001
456
+ id04536/0f_Yi_1CoeM/00001 id02445/3Rnk8eja3TU/00001
457
+ id03862/0w8W8jp7MJk/00001 id04030/7mXUMuo5_NE/00001
458
+ id00154/0hjW3eTGAy8/00001 id01541/2P7hzPq5iDw/00001
459
+ id06913/4Ug7aJemzpg/00001 id03347/4xXZ75_TeSM/00001
460
+ id08696/0H1PxInJCK0/00001 id04478/2grMtwdG93I/00001
461
+ id04366/0iG2Ub9zETM/00001 id02445/3Rnk8eja3TU/00001
462
+ id07354/0NjekFZqaY0/00001 id01567/1Lx_ZqrK1bM/00001
463
+ id06913/4Ug7aJemzpg/00001 id05202/2gnLcAbAoSc/00001
464
+ id04862/0zJh2FMTaDE/00001 id08696/0H1PxInJCK0/00001
465
+ id03178/2CT-6fnBC_o/00001 id02685/4JDRxqYC0a4/00001
466
+ id01822/0QcHowaLAF0/00001 id04950/2n4sGPqU9M8/00001
467
+ id00081/2xYrsnvtUWc/00001 id06913/4Ug7aJemzpg/00001
468
+ id07868/5YYJq3fSbH8/00001 id02465/0Ocu8l1eAng/00001
469
+ id02181/02gIO4WrZLY/00001 id03862/0w8W8jp7MJk/00001
470
+ id07868/5YYJq3fSbH8/00001 id05202/2gnLcAbAoSc/00001
471
+ id02286/4LAIxvdvguc/00001 id03178/2CT-6fnBC_o/00001
472
+ id01298/2K5F6xG-Rbs/00001 id01618/0iFlmfmWVlY/00001
473
+ id03980/7MRUusImkno/00001 id04006/113VkmVVz1Q/00001
474
+ id03862/0w8W8jp7MJk/00001 id08456/29EhSZDqzas/00001
475
+ id01567/1Lx_ZqrK1bM/00001 id03041/5CfnYwQCW48/00001
476
+ id02465/0Ocu8l1eAng/00001 id00419/1zffAxBod_c/00001
477
+ id04570/0YMGn6BI9rg/00001 id04295/1fSjOItVYVg/00001
478
+ id03862/0w8W8jp7MJk/00001 id04295/1fSjOItVYVg/00001
479
+ id03789/0kdVSujPa9g/00001 id00866/03SSllwNkGk/00001
480
+ id05654/07pANazoyJg/00001 id00926/2Nd7f1yNQzE/00001
481
+ id05850/B8kp8ed48JE/00001 id02685/4JDRxqYC0a4/00001
482
+ id03347/4xXZ75_TeSM/00001 id08392/0fwuibKviJU/00001
483
+ id00926/2Nd7f1yNQzE/00001 id07312/0LWllHGohPY/00001
484
+ id05850/B8kp8ed48JE/00001 id01041/1UYZqPpavtk/00001
485
+ id03030/5wOxV1wAgqA/00001 id06913/4Ug7aJemzpg/00001
486
+ id02057/0xZU7Oi9nvM/00001 id01041/1UYZqPpavtk/00001
487
+ id03030/5wOxV1wAgqA/00001 id01041/1UYZqPpavtk/00001
488
+ id01618/0iFlmfmWVlY/00001 id04366/0iG2Ub9zETM/00001
489
+ id06310/1IAgr_CRnuE/00001 id04119/1uH67UruKlE/00001
490
+ id05594/0ohBiepcHWI/00001 id02317/0q4X8kPTlEY/00001
491
+ id01228/2TIFacjgehY/00001 id04119/1uH67UruKlE/00001
492
+ id02286/4LAIxvdvguc/00001 id02445/3Rnk8eja3TU/00001
493
+ id04030/7mXUMuo5_NE/00001 id00419/1zffAxBod_c/00001
494
+ id01298/2K5F6xG-Rbs/00001 id02445/3Rnk8eja3TU/00001
495
+ id07802/0RUpqvi3sPU/00001 id04862/0zJh2FMTaDE/00001
496
+ id04006/113VkmVVz1Q/00001 id03347/4xXZ75_TeSM/00001
497
+ id02317/0q4X8kPTlEY/00001 id05850/B8kp8ed48JE/00001
498
+ id08456/29EhSZDqzas/00001 id04656/1tZYt8jey54/00001
499
+ id04656/1tZYt8jey54/00001 id05816/1dyCBbJ94iw/00001
500
+ id05202/2gnLcAbAoSc/00001 id06209/2zM9EAPsZZQ/00001
face_detection/README.md ADDED
@@ -0,0 +1 @@
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ __author__ = """Adrian Bulat"""
4
+ __email__ = 'adrian.bulat@nottingham.ac.uk'
5
+ __version__ = '1.0.1'
6
+
7
+ from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py ADDED
@@ -0,0 +1,98 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - these points represent the projection of the 3D points onto a 2D plane
22
+ ``_3D`` - detect the points ``(x,y,z)`` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+ ROOT = os.path.dirname(os.path.abspath(__file__))
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+
59
+ # Get the face detector
60
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
61
+ globals(), locals(), [face_detector], 0)
62
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
63
+
64
+ def get_detections_for_batch(self, images):
65
+ images = images[..., ::-1]
66
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
67
+ results = []
68
+
69
+ for i, d in enumerate(detected_faces):
70
+ # print("Inside facedection:", i, len(d))
71
+ if len(d) == 0:
72
+ results.append(None)
73
+ continue
74
+ d = d[0]
75
+ d = np.clip(d, 0, None)
76
+
77
+ x1, y1, x2, y2 = map(int, d[:-1])
78
+ results.append((x1, y1, x2, y2))
79
+
80
+ return results
81
+
82
+ def get_all_detections_for_batch(self, images):
83
+ # for multi-face detection
84
+ images = images[..., ::-1]
85
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
86
+ results = []
87
+
88
+ for i, d in enumerate(detected_faces):
89
+ # print("Inside facedection:", i, len(d))
90
+ if len(d) == 0:
91
+ results.append(None)
92
+ continue
93
+ d = [np.clip(dd, 0, None) for dd in d]
94
+ # d = [map(int, dd[:-1]) for dd in d]
95
+ d = [[int(ddd) for ddd in dd[:-1]] for dd in d]
96
+ results.append(d)
97
+
98
+ return results
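
A minimal usage sketch of the batched API above (the frame filenames are placeholders; frames are passed as read by `cv2`, and `get_detections_for_batch` returns one `(x1, y1, x2, y2)` tuple, or `None`, per frame):

```python
# Hypothetical usage of FaceAlignment.get_detections_for_batch; paths are placeholders.
import cv2
import numpy as np
import face_detection

detector = face_detection.FaceAlignment(
    face_detection.LandmarksType._2D, device='cuda', flip_input=False)

# Stack same-sized frames (as read by cv2.imread) into a (B, H, W, 3) uint8 array.
frames = np.stack([cv2.imread(p) for p in ['frame0.jpg', 'frame1.jpg']])

# One (x1, y1, x2, y2) tuple per frame, or None when no face was found.
for i, box in enumerate(detector.get_detections_for_batch(frames)):
    if box is None:
        print(f'frame {i}: no face detected')
    else:
        print(f'frame {i}: face at {box}')
```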
face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .core import FaceDetector
face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, which returns a list of detected
14
+ bounding boxes. Optionally, implementing ``detect_from_path`` is
15
+ recommended for speed.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ logger = logging.getLogger(__name__)
24
+ if 'cpu' in device:
25
+ logger.warning("Detection running on CPU, this may be slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided image (usually
36
+ BGR). The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of strings containing the extensions to be
62
+ considered, in the format ``.extension_name`` (default:
63
+ {['.jpg', '.png']})
64
+ recursive {bool} -- whether to scan the folder recursively (default: {False})
65
+ show_progress_bar {bool} -- display a progress bar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at least one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call .cpu() in case the tensor is coming from CUDA
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
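
`FaceDetector` above is an abstract base: subclasses supply `detect_from_image` and can reuse the `tensor_or_path_to_ndarray` helper. A toy subclass sketch under that contract; the OpenCV Haar cascade backend is a stand-in for illustration, not part of this repo:

```python
# Hypothetical FaceDetector subclass; the Haar cascade backend is illustrative only.
import cv2
from face_detection.detection import FaceDetector

class HaarDetector(FaceDetector):
    def __init__(self, device='cpu', verbose=False):
        super().__init__(device, verbose)
        cascade = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        self.cascade = cv2.CascadeClassifier(cascade)

    def detect_from_image(self, tensor_or_path):
        # The base-class helper normalizes a path/tensor/ndarray input to an ndarray.
        image = self.tensor_or_path_to_ndarray(tensor_or_path, rgb=False)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Return boxes as (x1, y1, x2, y2), matching the rest of the package.
        return [(x, y, x + w, y + h)
                for (x, y, w, h) in self.cascade.detectMultiScale(gray, 1.3, 5)]
```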
face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
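
`nms` above sorts detections by score and greedily suppresses any box whose IoU with an already-kept box exceeds `thresh`. A small numeric sketch with made-up boxes:

```python
# Made-up detections illustrating nms() from bbox.py.
import numpy as np
from face_detection.detection.sfd.bbox import nms

dets = np.array([
    [10., 10., 110., 110., 0.9],    # highest score: always kept
    [12., 12., 112., 112., 0.8],    # IoU ~0.92 with the first: suppressed
    [200., 200., 300., 300., 0.7],  # disjoint: kept
])
keep = nms(dets, thresh=0.3)
print(keep)        # [0, 2]
print(dets[keep])  # the two surviving rows
```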
face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+
70
+ bboxlist = []
71
+ for i in range(len(olist) // 2):
72
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
73
+ olist = [oelem.data.cpu() for oelem in olist]
74
+ for i in range(len(olist) // 2):
75
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
76
+ FB, FC, FH, FW = ocls.size() # feature map size
77
+ stride = 2**(i + 2) # 4,8,16,32,64,128
78
+ anchor = stride * 4
79
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
80
+ for Iindex, hindex, windex in poss:
81
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
82
+ score = ocls[:, 1, hindex, windex]
83
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
84
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
85
+ variances = [0.1, 0.2]
86
+ box = batch_decode(loc, priors, variances)
87
+ box = box[:, 0] * 1.0
88
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
89
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
90
+ bboxlist = np.array(bboxlist)
91
+ if 0 == len(bboxlist):
92
+ bboxlist = np.zeros((1, BB, 5))
93
+
94
+ return bboxlist
95
+
96
+ def flip_detect(net, img, device):
97
+ img = cv2.flip(img, 1)
98
+ b = detect(net, img, device)
99
+
100
+ bboxlist = np.zeros(b.shape)
101
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
102
+ bboxlist[:, 1] = b[:, 1]
103
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
104
+ bboxlist[:, 3] = b[:, 3]
105
+ bboxlist[:, 4] = b[:, 4]
106
+ return bboxlist
107
+
108
+
109
+ def pts_to_bb(pts):
110
+ min_x, min_y = np.min(pts, axis=0)
111
+ max_x, max_y = np.max(pts, axis=0)
112
+ return np.array([min_x, min_y, max_x, max_y])
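
`batch_detect` subtracts the per-channel BGR means, runs the whole stack of frames through `s3fd` in one pass, and decodes every anchor whose face score clears 0.05. A shape-only sketch with randomly initialized weights (the boxes are meaningless noise; this just shows the data flow and output layout):

```python
# Shape sketch for batch_detect(); weights are untrained, so scores are noise.
import numpy as np
import torch
from face_detection.detection.sfd.net_s3fd import s3fd
from face_detection.detection.sfd.detect import batch_detect

net = s3fd().eval()
imgs = np.random.randint(0, 255, (2, 128, 128, 3)).astype(np.float64)
bboxlist = batch_detect(net, imgs, device='cpu')
# (num_candidates, batch, 5): x1, y1, x2, y2, score per candidate per image.
print(bboxlist.shape)
```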
face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
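
The last block of `forward` implements S3FD's "max-out background label": the stride-4 confidence map predicts three background scores plus one face score, and the background scores are collapsed with an elementwise max so every scale ends up with the usual two channels. The same step in isolation, on a dummy tensor:

```python
# The max-out background step from s3fd.forward, on a dummy tensor.
import torch

cls1 = torch.randn(1, 4, 64, 64)  # 3 background channels + 1 face channel
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])  # max over backgrounds
cls1 = torch.cat([bmax, chunk[3]], dim=1)
print(cls1.shape)  # torch.Size([1, 2, 64, 64])
```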
face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector, map_location=device)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
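
`SFDDetector` fetches the pretrained `s3fd.pth` weights on first use (via `load_url`) and keeps only NMS survivors with confidence above 0.5. A direct single-image sketch; the image path is a placeholder:

```python
# Hypothetical direct use of SFDDetector (exported as FaceDetector above).
from face_detection.detection.sfd import FaceDetector

detector = FaceDetector(device='cpu')            # downloads s3fd.pth if absent
boxes = detector.detect_from_image('face.jpg')   # placeholder path
for x1, y1, x2, y2, score in boxes:              # score > 0.5 by construction
    print(int(x1), int(y1), int(x2), int(y2), round(float(score), 3))
```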
face_detection/models.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
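
`FAN` stacks hourglass modules and emits one 68-channel landmark heatmap tensor per module, at a quarter of the input resolution (one stride-2 conv plus one average pool). A shape sketch with random weights, assuming the usual 256x256 input:

```python
# Output-shape sketch for FAN; weights are random, heatmaps are not meaningful.
import torch
from face_detection.models import FAN

fan = FAN(num_modules=1).eval()
with torch.no_grad():
    heatmaps = fan(torch.randn(1, 3, 256, 256))

print(len(heatmaps), heatmaps[0].shape)  # 1 torch.Size([1, 68, 64, 64])
```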
face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a target resolution, the
60
+ function generates an affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- whether the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ [type] -- [description]
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale is provided the function will return the points also in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
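The sub-pixel step above nudges each argmax peak a quarter pixel toward its larger neighbour. A minimal sketch of the returned coordinate convention, with a peak planted at a known location in a made-up toy heatmap:

```python
import torch

from face_detection.utils import get_preds_fromhm

# Toy batch: one 64x64 heatmap with a single peak at row 40, column 20.
hm = torch.zeros(1, 1, 64, 64)
hm[0, 0, 40, 20] = 1.0

preds, _ = get_preds_fromhm(hm)
print(preds[0, 0])  # tensor([20.5000, 40.5000]): x is the column, y is the row
```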
def get_preds_fromhm_batch(hm, centers=None, scales=None):
    """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
    and the scales are provided, the function will also return the points in
    the original coordinate frame.

    Arguments:
        hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]

    Keyword Arguments:
        centers {torch.tensor} -- the centers of the bounding box (default: {None})
        scales {float} -- face scales (default: {None})
    """
    max, idx = torch.max(
        hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
    idx += 1
    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
    preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
    preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)

    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
                diff = torch.FloatTensor(
                    [hm_[pY, pX + 1] - hm_[pY, pX - 1],
                     hm_[pY + 1, pX] - hm_[pY - 1, pX]])
                preds[i, j].add_(diff.sign_().mul_(.25))

    preds.add_(-.5)

    preds_orig = torch.zeros(preds.size())
    if centers is not None and scales is not None:
        for i in range(hm.size(0)):
            for j in range(hm.size(1)):
                preds_orig[i, j] = transform(
                    preds[i, j], centers[i], scales[i], hm.size(2), True)

    return preds, preds_orig


def shuffle_lr(parts, pairs=None):
    """Shuffle the points left-right according to the axis of symmetry
    of the object.

    Arguments:
        parts {torch.tensor} -- a 3D or 4D object containing the
        heatmaps.

    Keyword Arguments:
        pairs {list of integers} -- [order of the flipped points] (default: {None})
    """
    if pairs is None:
        pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
                 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
                 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
                 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
                 62, 61, 60, 67, 66, 65]
    if parts.ndimension() == 3:
        parts = parts[pairs, ...]
    else:
        parts = parts[:, pairs, ...]

    return parts


def flip(tensor, is_label=False):
    """Flip an image or a set of heatmaps left-right.

    Arguments:
        tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]

    Keyword Arguments:
        is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
    """
    if not torch.is_tensor(tensor):
        tensor = torch.from_numpy(tensor)

    if is_label:
        tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
    else:
        tensor = tensor.flip(tensor.ndimension() - 1)

    return tensor

# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)


def appdata_dir(appname=None, roaming=False):
    """ appdata_dir(appname=None, roaming=False)

    Get the path to the application directory, where applications are allowed
    to write user specific files (e.g. configurations). For non-user specific
    data, consider using common_appdata_dir().
    If appname is given, a subdir is appended (and created if necessary).
    If roaming is True, will prefer a roaming directory (Windows Vista/7).
    """

    # Define default user directory
    userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
    if userDir is None:
        userDir = os.path.expanduser('~')
        if not os.path.isdir(userDir):  # pragma: no cover
            userDir = '/var/tmp'  # issue #54

    # Get system app data dir
    path = None
    if sys.platform.startswith('win'):
        path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
        path = (path2 or path1) if roaming else (path1 or path2)
    elif sys.platform.startswith('darwin'):
        path = os.path.join(userDir, 'Library', 'Application Support')
    # On Linux and as fallback
    if not (path and os.path.isdir(path)):
        path = userDir

    # Maybe we should store things local to the executable (in case of a
    # portable distro or a frozen application that wants to be portable)
    prefix = sys.prefix
    if getattr(sys, 'frozen', None):
        prefix = os.path.abspath(os.path.dirname(sys.executable))
    for reldir in ('settings', '../settings'):
        localpath = os.path.abspath(os.path.join(prefix, reldir))
        if os.path.isdir(localpath):  # pragma: no cover
            try:
                open(os.path.join(localpath, 'test.write'), 'wb').close()
                os.remove(os.path.join(localpath, 'test.write'))
            except IOError:
                pass  # We cannot write in this directory
            else:
                path = localpath
                break

    # Get path specific for this app
    if appname:
        if path == userDir:
            appname = '.' + appname.lstrip('.')  # Make it a hidden directory
        path = os.path.join(path, appname)
        if not os.path.isdir(path):  # pragma: no cover
            os.mkdir(path)

    # Done
    return path
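A minimal, self-contained sketch of how `transform` and `crop` fit together; the frame, center, and scale below are made-up stand-ins, and the repo root is assumed to be on PYTHONPATH:

```python
import numpy as np
import torch

from face_detection.utils import crop, transform

image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB frame
center = torch.tensor([320.0, 240.0])            # assumed face center
scale = 1.2                                      # assumed face scale

face = crop(image, center, scale, resolution=256.0)  # 256x256 crop
# Map the crop's top-left corner back into frame coordinates:
top_left = transform([1, 1], center, scale, 256.0, invert=True)
print(face.shape, top_left)
```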
generate.py ADDED
@@ -0,0 +1,398 @@
''' consistent initial noise for video generation '''
import cv2
import os
from os.path import join, basename, dirname, splitext
import shutil
import argparse
import numpy as np
import random
import torch, torchvision
import subprocess
from audio import audio
import face_detection
from tqdm import tqdm

from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    tfg_model_and_diffusion_defaults,
    tfg_create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)

from guided_diffusion.tfg_data_util import (
    tfg_process_batch,
)

def get_frame_id(frame):
    return int(basename(frame).split('.')[0])

def crop_audio_window(spec, start_frame, args):
    if type(start_frame) == int:
        start_frame_num = start_frame
    else:
        start_frame_num = get_frame_id(start_frame)
    start_idx = int(args.mel_steps_per_sec * (start_frame_num / float(args.video_fps)))
    end_idx = start_idx + args.syncnet_mel_step_size
    return spec[start_idx : end_idx, :]

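`crop_audio_window` converts a frame index into a mel-spectrogram slice. A quick check of that arithmetic with the defaults declared in `create_argparser()` below (80 mel steps per second, 25 fps, 16-step windows); the frame indices are arbitrary:

```python
mel_steps_per_sec, video_fps, mel_step_size = 80.0, 25, 16

for frame in (0, 1, 25):
    start = int(mel_steps_per_sec * (frame / float(video_fps)))
    print(frame, (start, start + mel_step_size))
# 0 -> (0, 16), 1 -> (3, 19), 25 -> (80, 96): each frame maps to a
# 16-step window that advances 3.2 mel steps per video frame.
```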
def load_all_indiv_mels(path, args):
    in_path = path
    out_dir = join(args.sample_path, "temp", basename(in_path).replace(".mp4", ""))
    os.makedirs(out_dir, exist_ok=True)
    out_path = join(out_dir, "audio.wav")
    command2 = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(in_path, out_path)
    subprocess.call(command2, shell=True)
    wav = audio.load_wav(out_path, args.sample_rate)
    orig_mel = audio.melspectrogram(wav).T

    all_indiv_mels = []
    # i=0
    i = 1
    while True:
        m = crop_audio_window(orig_mel.copy(), max(i - args.syncnet_T // 2, 0), args)
        if (m.shape[0] != args.syncnet_mel_step_size):
            break
        all_indiv_mels.append(m.T)
        i += 1

    # clean up
    shutil.rmtree(join(args.sample_path, "temp"))

    return all_indiv_mels, wav

def load_video_frames(path, args):
    in_path = path
    out_dir = join(args.sample_path, "temp", basename(in_path).replace(".mp4", ""), "image")
    os.makedirs(out_dir, exist_ok=True)

    command = "ffmpeg -loglevel error -y -i {} -vf fps={} -q:v 2 -qmin 1 {}/%05d.jpg".format(in_path, args.video_fps, out_dir)
    subprocess.call(command, shell=True)

    video_frames = []
    for i, img_name in enumerate(sorted(os.listdir(out_dir))):
        img_path = join(out_dir, img_name)
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        video_frames.append(img)

    # clean up
    shutil.rmtree(join(args.sample_path, "temp"))

    return video_frames


def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def my_voxceleb2_crop(img):
    return img[:-int(img.shape[0]*2.36/8), int(img.shape[1]*1.8/8): -int(img.shape[1]*1.8/8)]

def my_voxceleb2_crop_bboxs(img):
    return 0, img.shape[0]-int(img.shape[0]*2.36/8), int(img.shape[1]*1.8/8), img.shape[1]-int(img.shape[1]*1.8/8)

def face_detect(images, detector, args, resize=False):
    batch_size = args.face_det_batch_size

    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU')
            batch_size //= 2
            args.face_det_batch_size = batch_size
            print('Recovering from OOM error; new batch size: {}'.format(batch_size))
            continue
        break

    results = []
    if type(args.pads) == str:
        args.pads = [int(x) for x in args.pads.split(",")]
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            raise ValueError('Face not detected!')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)

    if resize:
        if args.is_voxceleb2:
            results = [[cv2.resize(my_voxceleb2_crop(image), (args.image_size, args.image_size)), my_voxceleb2_crop_bboxs(image), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
        else:
            results = [[cv2.resize(image[y1: y2, x1:x2], (args.image_size, args.image_size)), (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    else:
        results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    return results

def normalise(tensor):
    """ [-1,1] -> [0,1] """
    return ((tensor + 1) * 0.5).clamp(0, 1)

def normalise2(tensor):
    """ [0,1] -> [-1,1] """
    return (tensor * 2 - 1).clamp(-1, 1)


def sample_batch(batch, model, diffusion, args):
    B, F, C, H, W = batch['image'].shape
    sample_shape = (B * F, C, H, W)

    # generate fixed noise
    init_noise = None
    if args.sampling_seed:
        state = torch.get_rng_state()
        torch.manual_seed(args.sampling_seed)
        torch.cuda.manual_seed_all(args.sampling_seed)
        init_noise = torch.randn((1, C, H, W))
        # repeat noise for all frames
        init_noise = init_noise.repeat(B * F, 1, 1, 1)
        torch.set_rng_state(state)

    img_batch, model_kwargs = tfg_process_batch(batch, args.face_hide_percentage,
                                                use_ref=args.use_ref,
                                                use_audio=args.use_audio,
                                                # sampling_use_gt_for_ref=args.sampling_use_gt_for_ref,
                                                noise=init_noise)

    img_batch = img_batch.to(dist_util.dev())
    model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
    init_noise = init_noise.to(dist_util.dev()) if init_noise is not None else None

    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model,
        sample_shape,
        clip_denoised=args.clip_denoised,
        model_kwargs=model_kwargs,
        noise=init_noise
    )
    return sample, img_batch, model_kwargs


def generate(video_path, audio_path, model, diffusion, detector, args, out_path=None, save_orig=True):
    video_frames = load_video_frames(video_path, args)
    try:
        face_det_results = face_detect(video_frames.copy(), detector, args, resize=True)
    except Exception as e:
        print("Error:", e, video_path, audio_path)
        import traceback
        print(traceback.format_exc())
        raise  # face_det_results is undefined beyond this point, so re-raise
    wrong_all_indiv_mels, wrong_audio_wavform = load_all_indiv_mels(audio_path, args)

    min_frames = min(len(video_frames), len(wrong_all_indiv_mels))
    video_frames = video_frames[:min_frames]
    face_det_results = face_det_results[:min_frames]
    face_bboxes = [face_det_results[i][1] for i in range(min_frames)]
    face_frames = torch.FloatTensor(np.transpose(np.asarray([face_det_results[i][0] for i in range(min_frames)], dtype=np.float32)/255., (0, 3, 1, 2)))  # [N, C, H, W]
    wrong_all_indiv_mels = torch.FloatTensor(np.asarray(wrong_all_indiv_mels[:min_frames])).unsqueeze(1)  # [N, 1, h, w]

    if save_orig:
        if out_path is None:
            out_path_orig = os.path.join(args.sample_path, splitext(basename(video_path))[0]+"_"+splitext(basename(audio_path))[0]+"_orig.mp4")
        else:
            out_path_orig = out_path.replace(".mp4", "_orig.mp4")
        torchvision.io.write_video(
            out_path_orig,
            video_array=torch.from_numpy(np.array(video_frames)), fps=args.video_fps, video_codec='libx264',
            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
        )

    if args.sampling_ref_type == 'gt':
        ref_frames = face_frames.clone()
    elif args.sampling_ref_type == 'first_frame':
        ref_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
    elif args.sampling_ref_type == 'random':
        rand_idx = random.Random(args.sampling_seed).randint(0, len(face_frames)-1)
        ref_frames = face_frames[rand_idx:rand_idx+1].repeat(len(face_frames), 1, 1, 1)

    if args.sampling_input_type == 'first_frame':
        face_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
        video_frames = np.array(video_frames[0:1]*len(video_frames))
        face_bboxes = np.array(face_bboxes[0:1]*len(face_bboxes))

    generated_video_frames = []
    b_s = args.sampling_batch_size
    for i in range(0, min_frames, b_s*args.nframes):
        video_frames_batch = video_frames[i:i+b_s*args.nframes]
        face_bboxes_batch = face_bboxes[i:i+b_s*args.nframes]

        try:
            img_batch = face_frames[i:i+b_s*args.nframes]  # [BF, C, H, W]
            img_batch = img_batch.reshape(-1, args.nframes, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:i+b_s*args.nframes]
            ref_batch = ref_batch.reshape(-1, args.nframes, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:i+b_s*args.nframes]  # [BF, 1, h, w]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, args.nframes, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))
        except RuntimeError:
            # for the last batch, if B*F % nframes != 0, the reshape above throws an error,
            # but internally everything is going to get converted to BF anyway,
            # i.e. (B, F, C, H, W) -> (B*F, C, H, W) and (B*F, 1, C, H, W) -> (B*F, C, H, W)
            img_batch = face_frames[i:i+b_s*args.nframes]  # [BF, C, H, W]
            img_batch = img_batch.reshape(-1, 1, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:i+b_s*args.nframes]
            ref_batch = ref_batch.reshape(-1, 1, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:i+b_s*args.nframes]  # [BF, 1, h, w]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, 1, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))

        batch = {"image": img_batch,
                 "ref_img": ref_batch,
                 "indiv_mels": wrong_indiv_mel_batch}

        sample, img_batch, model_kwargs = sample_batch(batch, model, diffusion, args)
        mask = model_kwargs['mask']
        recon_batch = sample * mask + (1. - mask) * img_batch  # [BF, C, H, W]
        recon_batch = (normalise(recon_batch)*255).cpu().numpy().transpose(0, 2, 3, 1)  # [-1,1] -> [0,255]

        for g, v, b in zip(recon_batch, video_frames_batch, face_bboxes_batch):
            y1, y2, x1, x2 = b
            g = cv2.resize(g.astype(np.uint8), (x2 - x1, y2 - y1))
            v[y1:y2, x1:x2] = g
            generated_video_frames.append(v)

    print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
    min_time = len(generated_video_frames)/args.video_fps  # the video is already shorter because it got chopped according to the mel array length
    wrong_audio_wavform = wrong_audio_wavform[:int(min_time*args.sample_rate)]
    print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
    if out_path is None:
        out_path = os.path.join(args.sample_path, splitext(basename(video_path))[0]+"_"+splitext(basename(audio_path))[0]+".mp4")
    torchvision.io.write_video(
        out_path,
        video_array=torch.from_numpy(np.array(generated_video_frames)), fps=args.video_fps, video_codec='libx264',
        audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
    )


def generate_from_filelist(test_video_dir, filelist, model, diffusion, detector, args):
    video_names = []
    audio_names = []
    with open(filelist, "r") as f:
        lines = f.readlines()
    for line in tqdm(lines):
        try:
            audio_name, video_name = line.strip().split()
            audio_path = join(test_video_dir, audio_name+'.mp4')
            video_path = join(test_video_dir, video_name+'.mp4')
            out_path = join(args.sample_path, audio_name.replace('/', '.')+"_"+video_name.replace('/', '.')+".mp4")
            generate(video_path, audio_path, model, diffusion, detector, args, out_path=out_path, save_orig=args.save_orig)
        except Exception as e:
            print("Error:", e, video_path, audio_path)
            import traceback
            print(traceback.format_exc())


def main():
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure(dir=args.sample_path, format_strs=["stdout", "log"])

    logger.log("creating model...")
    model, diffusion = tfg_create_model_and_diffusion(
        **args_to_dict(args, tfg_model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location='cpu')
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda' if torch.cuda.is_available() else 'cpu')

    if args.generate_from_filelist:
        generate_from_filelist(args.test_video_dir, args.filelist, model, diffusion, detector, args)
    else:
        generate(args.video_path, args.audio_path, model, diffusion, detector, args, out_path=args.out_path, save_orig=args.save_orig)


def create_argparser():
    defaults = dict(
        # generate from a single audio-video pair
        generate_from_filelist=False,
        video_path="",
        audio_path="",
        out_path=None,
        save_orig=True,

        # generate from filelist: generate_from_filelist = True
        test_video_dir="test_videos",
        filelist="test_filelist.txt",

        use_fp16=True,
        # tfg specific
        face_hide_percentage=0.5,
        use_ref=False,
        use_audio=False,
        audio_as_style=False,
        audio_as_style_encoder_mlp=False,

        # data args
        nframes=1,
        nrefer=0,
        image_size=128,
        syncnet_T=5,
        syncnet_mel_step_size=16,
        audio_frames_per_video=16,  # for the tfg model, we use sound corresponding to 5 frames centred at that frame
        audio_dim=80,
        is_voxceleb2=True,

        video_fps=25,
        sample_rate=16000,  # audio sampling rate
        mel_steps_per_sec=80.,

        # sampling args
        clip_denoised=True,  # not used in training
        sampling_batch_size=2,
        use_ddim=False,
        model_path="",
        sample_path="d2l_gen",
        sample_partition="",
        sampling_seed=None,
        sampling_use_gt_for_ref=False,
        sampling_ref_type='gt',  # one of ['gt', 'first_frame', 'random']
        sampling_input_type='gt',  # one of ['gt', 'first_frame']

        # face detection args
        face_det_batch_size=64,
        pads="0,0,0,0"
    )
    defaults.update(tfg_model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
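For reference, a hedged sketch of driving `generate()` programmatically rather than via the CLI; it simply mirrors the steps in `main()` above, and every path below is a placeholder:

```python
import torch

import face_detection
from generate import create_argparser, generate
from guided_diffusion import dist_util
from guided_diffusion.script_util import (
    tfg_model_and_diffusion_defaults, tfg_create_model_and_diffusion, args_to_dict,
)

args = create_argparser().parse_args([
    "--model_path", "path/to/checkpoint.pt",  # placeholder paths
    "--video_path", "input.mp4",
    "--audio_path", "driving_audio.mp4",
])
dist_util.setup_dist()
model, diffusion = tfg_create_model_and_diffusion(
    **args_to_dict(args, tfg_model_and_diffusion_defaults().keys()))
model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu"))
model.to(dist_util.dev())
model.eval()
detector = face_detection.FaceAlignment(
    face_detection.LandmarksType._2D, flip_input=False,
    device="cuda" if torch.cuda.is_available() else "cpu")
generate(args.video_path, args.audio_path, model, diffusion, detector, args,
         out_path="generated.mp4")
```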
generate_dist.py ADDED
@@ -0,0 +1,428 @@
''' consistent initial noise for video generation '''
import cv2
import os
from os.path import join, basename, dirname, splitext
import shutil
import argparse
import numpy as np
import random
import torch, torchvision
import subprocess
from audio import audio
import face_detection
from tqdm import tqdm
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    tfg_model_and_diffusion_defaults,
    tfg_create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from time import time
import torch.distributed as dist
from guided_diffusion.tfg_data_util import (
    tfg_process_batch,
)

def get_frame_id(frame):
    return int(basename(frame).split('.')[0])

def crop_audio_window(spec, start_frame, args):
    if type(start_frame) == int:
        start_frame_num = start_frame
    else:
        start_frame_num = get_frame_id(start_frame)
    start_idx = int(args.mel_steps_per_sec * (start_frame_num / float(args.video_fps)))
    end_idx = start_idx + args.syncnet_mel_step_size
    return spec[start_idx : end_idx, :]

def load_all_indiv_mels(path, args):
    in_path = path
    out_dir = join(args.sample_path, "temp", str(dist.get_rank()), basename(in_path).replace(".mp4", ""))
    os.makedirs(out_dir, exist_ok=True)
    out_path = join(out_dir, "audio.wav")
    command2 = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(in_path, out_path)
    subprocess.call(command2, shell=True)
    wav = audio.load_wav(out_path, args.sample_rate)
    orig_mel = audio.melspectrogram(wav).T

    all_indiv_mels = []
    # i=0
    i = 1
    while True:
        m = crop_audio_window(orig_mel.copy(), max(i - args.syncnet_T // 2, 0), args)
        if (m.shape[0] != args.syncnet_mel_step_size):
            break
        all_indiv_mels.append(m.T)
        i += 1

    # clean up (each rank uses its own temp subdirectory)
    shutil.rmtree(join(args.sample_path, "temp", str(dist.get_rank())))

    return all_indiv_mels, wav

def load_video_frames(path, args):
    in_path = path
    out_dir = join(args.sample_path, "temp", str(dist.get_rank()), basename(in_path).replace(".mp4", ""), "image")
    os.makedirs(out_dir, exist_ok=True)

    command = "ffmpeg -loglevel error -y -i {} -vf fps={} -q:v 2 -qmin 1 {}/%05d.jpg".format(in_path, args.video_fps, out_dir)
    subprocess.call(command, shell=True)

    video_frames = []
    for i, img_name in enumerate(sorted(os.listdir(out_dir))):
        img_path = join(out_dir, img_name)
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        video_frames.append(img)

    # clean up
    shutil.rmtree(join(args.sample_path, "temp", str(dist.get_rank())))

    return video_frames


def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def my_voxceleb2_crop(img):
    return img[:-int(img.shape[0]*2.36/8), int(img.shape[1]*1.8/8): -int(img.shape[1]*1.8/8)]

def my_voxceleb2_crop_bboxs(img):
    return 0, img.shape[0]-int(img.shape[0]*2.36/8), int(img.shape[1]*1.8/8), img.shape[1]-int(img.shape[1]*1.8/8)

def face_detect(images, detector, args, resize=False):
    batch_size = args.face_det_batch_size

    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU')
            batch_size //= 2
            args.face_det_batch_size = batch_size
            print('Recovering from OOM error; new batch size: {}'.format(batch_size))
            continue
        break

    results = []
    if type(args.pads) == str:
        args.pads = [int(x) for x in args.pads.split(",")]
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            raise ValueError('Face not detected!')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)

    if resize:
        if args.is_voxceleb2:
            results = [[cv2.resize(my_voxceleb2_crop(image), (args.image_size, args.image_size)), my_voxceleb2_crop_bboxs(image), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
        else:
            results = [[cv2.resize(image[y1: y2, x1:x2], (args.image_size, args.image_size)), (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    else:
        results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    return results

def normalise(tensor):
    """ [-1,1] -> [0,1] """
    return ((tensor + 1) * 0.5).clamp(0, 1)

def normalise2(tensor):
    """ [0,1] -> [-1,1] """
    return (tensor * 2 - 1).clamp(-1, 1)


def sample_batch(batch, model, diffusion, args):
    B, F, C, H, W = batch['image'].shape
    sample_shape = (B * F, C, H, W)

    # generate fixed noise
    init_noise = None
    if args.sampling_seed:
        state = torch.get_rng_state()
        torch.manual_seed(args.sampling_seed)
        torch.cuda.manual_seed_all(args.sampling_seed)
        init_noise = torch.randn((1, C, H, W))
        # repeat noise for all frames
        init_noise = init_noise.repeat(B * F, 1, 1, 1)
        torch.set_rng_state(state)

    img_batch, model_kwargs = tfg_process_batch(batch, args.face_hide_percentage,
                                                use_ref=args.use_ref,
                                                use_audio=args.use_audio,
                                                # sampling_use_gt_for_ref=args.sampling_use_gt_for_ref,
                                                noise=init_noise)

    img_batch = img_batch.to(dist_util.dev())
    model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
    init_noise = init_noise.to(dist_util.dev()) if init_noise is not None else None

    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model,
        sample_shape,
        clip_denoised=args.clip_denoised,
        model_kwargs=model_kwargs,
        noise=init_noise
    )
    return sample, img_batch, model_kwargs


def generate(video_path, audio_path, model, diffusion, detector, args, out_path=None, save_orig=True):
    video_frames = load_video_frames(video_path, args)
    try:
        face_det_results = face_detect(video_frames.copy(), detector, args, resize=True)
    except Exception as e:
        print("Error:", e, video_path, audio_path)
        import traceback
        print(traceback.format_exc())
        raise  # face_det_results is undefined beyond this point, so re-raise
    wrong_all_indiv_mels, wrong_audio_wavform = load_all_indiv_mels(audio_path, args)

    min_frames = min(len(video_frames), len(wrong_all_indiv_mels))
    video_frames = video_frames[:min_frames]
    face_det_results = face_det_results[:min_frames]
    face_bboxes = [face_det_results[i][1] for i in range(min_frames)]
    face_frames = torch.FloatTensor(np.transpose(np.asarray([face_det_results[i][0] for i in range(min_frames)], dtype=np.float32)/255., (0, 3, 1, 2)))  # [N, C, H, W]
    wrong_all_indiv_mels = torch.FloatTensor(np.asarray(wrong_all_indiv_mels[:min_frames])).unsqueeze(1)  # [N, 1, h, w]

    if save_orig:
        if out_path is None:
            out_path_orig = os.path.join(args.sample_path, splitext(basename(video_path))[0]+"_"+splitext(basename(audio_path))[0]+"_orig.mp4")
        else:
            out_path_orig = out_path.replace(".mp4", "_orig.mp4")
        torchvision.io.write_video(
            out_path_orig,
            video_array=torch.from_numpy(np.array(video_frames)), fps=args.video_fps, video_codec='libx264',
            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
        )

    if args.sampling_ref_type == 'gt':
        ref_frames = face_frames.clone()
    elif args.sampling_ref_type == 'first_frame':
        ref_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
    elif args.sampling_ref_type == 'random':
        rand_idx = random.Random(args.sampling_seed).randint(0, len(face_frames)-1)
        ref_frames = face_frames[rand_idx:rand_idx+1].repeat(len(face_frames), 1, 1, 1)

    if args.sampling_input_type == 'first_frame':
        face_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
        video_frames = np.array(video_frames[0:1]*len(video_frames))
        face_bboxes = np.array(face_bboxes[0:1]*len(face_bboxes))

    # split the frames evenly across ranks; each rank samples its own chunk
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    chunk_size = int(np.ceil(min_frames/world_size))
    start_idx = rank * chunk_size
    end_idx = min(start_idx + chunk_size, min_frames)
    generated_video_frames = []
    b_s = args.sampling_batch_size

    # print(rank, "/", world_size, "chunk: [", start_idx, "-", end_idx, "/", min_frames, "]")

    dist.barrier()
    torch.cuda.synchronize()
    t1 = time()
    for i in range(start_idx, end_idx, b_s*args.nframes):
        slice_end = min(i+b_s*args.nframes, end_idx)
        # if rank == 0:
        #     print("rank 0: slice:", i, ":", slice_end)
        video_frames_batch = video_frames[i:slice_end]
        face_bboxes_batch = face_bboxes[i:slice_end]

        if (slice_end-i) % args.nframes == 0:
            img_batch = face_frames[i:slice_end]  # [BF, C, H, W]
            img_batch = img_batch.reshape(-1, args.nframes, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:slice_end]
            ref_batch = ref_batch.reshape(-1, args.nframes, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:slice_end]  # [BF, 1, h, w]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, args.nframes, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))
        else:
            # for the last batch, if (slice_end - i) % nframes != 0, the reshape above would fail,
            # but internally everything is going to get converted to BF anyway,
            # i.e. (B, F, C, H, W) -> (B*F, C, H, W) and (B*F, 1, C, H, W) -> (B*F, C, H, W)
            img_batch = face_frames[i:slice_end]  # [BF, C, H, W]
            img_batch = img_batch.reshape(-1, 1, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:slice_end]
            ref_batch = ref_batch.reshape(-1, 1, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:slice_end]  # [BF, 1, h, w]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, 1, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))

        batch = {"image": img_batch,
                 "ref_img": ref_batch,
                 "indiv_mels": wrong_indiv_mel_batch}

        sample, img_batch, model_kwargs = sample_batch(batch, model, diffusion, args)
        mask = model_kwargs['mask']
        recon_batch = sample * mask + (1. - mask) * img_batch  # [BF, C, H, W]
        recon_batch = (normalise(recon_batch)*255).cpu().numpy().transpose(0, 2, 3, 1)  # [-1,1] -> [0,255]

        for g, v, b in zip(recon_batch, video_frames_batch, face_bboxes_batch):
            y1, y2, x1, x2 = b
            g = cv2.resize(g.astype(np.uint8), (x2 - x1, y2 - y1))
            v[y1:y2, x1:x2] = g
            generated_video_frames.append(v)

    torch.cuda.synchronize()
    t3 = time()
    all_generated_video_frames = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(all_generated_video_frames, generated_video_frames)  # gather not supported with NCCL
    all_generated_video_frames_combined = []
    [all_generated_video_frames_combined.extend(gvf) for gvf in all_generated_video_frames]
    generated_video_frames = all_generated_video_frames_combined

    torch.cuda.synchronize()
    t2 = time()

    if dist.get_rank() == 0:
        print("Time taken for sampling:", t2-t1, "; time without all_gather:", t3-t1,
              "; gathered frames:", len(generated_video_frames), "; total frames:", min_frames)
        print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
        min_time = len(generated_video_frames)/args.video_fps  # the video is already shorter because it got chopped according to the mel array length
        wrong_audio_wavform = wrong_audio_wavform[:int(min_time*args.sample_rate)]
        print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
        if out_path is None:
            out_path = os.path.join(args.sample_path, splitext(basename(video_path))[0]+"_"+splitext(basename(audio_path))[0]+".mp4")
        torchvision.io.write_video(
            out_path,
            video_array=torch.from_numpy(np.array(generated_video_frames)), fps=args.video_fps, video_codec='libx264',
            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
        )
    dist.barrier()


def generate_from_filelist(test_video_dir, filelist, model, diffusion, detector, args):
    video_names = []
    audio_names = []
    with open(filelist, "r") as f:
        lines = f.readlines()
    for line in tqdm(lines):
        try:
            audio_name, video_name = line.strip().split()
            audio_path = join(test_video_dir, audio_name+'.mp4')
            video_path = join(test_video_dir, video_name+'.mp4')
            out_path = join(args.sample_path, audio_name.replace('/', '.')+"_"+video_name.replace('/', '.')+".mp4")
            generate(video_path, audio_path, model, diffusion, detector, args, out_path=out_path, save_orig=args.save_orig)
        except Exception as e:
            print("Error:", e, video_path, audio_path)
            import traceback
            print(traceback.format_exc())


def main():
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure(dir=args.sample_path, format_strs=["stdout", "log"])

    logger.log("creating model...")
    model, diffusion = tfg_create_model_and_diffusion(
        **args_to_dict(args, tfg_model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location='cpu')
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda' if torch.cuda.is_available() else 'cpu')

    if args.generate_from_filelist:
        generate_from_filelist(args.test_video_dir, args.filelist, model, diffusion, detector, args)
    else:
        generate(args.video_path, args.audio_path, model, diffusion, detector, args, out_path=args.out_path, save_orig=args.save_orig)


def create_argparser():
    defaults = dict(
        # generate from a single audio-video pair
        generate_from_filelist=False,
        video_path="",
        audio_path="",
        out_path=None,
        save_orig=True,

        # generate from filelist: generate_from_filelist = True
        test_video_dir="test_videos",
        filelist="test_filelist.txt",

        use_fp16=True,
        # tfg specific
        face_hide_percentage=0.5,
        use_ref=False,
        use_audio=False,
        audio_as_style=False,
        audio_as_style_encoder_mlp=False,

        # data args
        nframes=1,
        nrefer=0,
        image_size=128,
        syncnet_T=5,
        syncnet_mel_step_size=16,
        audio_frames_per_video=16,  # for the tfg model, we use sound corresponding to 5 frames centred at that frame
        audio_dim=80,
        is_voxceleb2=True,

        video_fps=25,
        sample_rate=16000,  # audio sampling rate
        mel_steps_per_sec=80.,

        # sampling args
        clip_denoised=True,  # not used in training
        sampling_batch_size=2,
        use_ddim=False,
        model_path="",
        sample_path="d2l_gen",
        sample_partition="",
        sampling_seed=None,
        sampling_use_gt_for_ref=False,
        sampling_ref_type='gt',  # one of ['gt', 'first_frame', 'random']
        sampling_input_type='gt',  # one of ['gt', 'first_frame']

        # face detection args
        face_det_batch_size=64,
        pads="0,0,0,0"
    )
    defaults.update(tfg_model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
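The only substantive change from generate.py is the rank-based chunking inside `generate()`: each rank samples a contiguous slice of frames and `all_gather_object` reassembles them in rank order. A small sketch of that arithmetic with made-up numbers:

```python
import numpy as np

min_frames, world_size = 143, 4
chunk_size = int(np.ceil(min_frames / world_size))  # 36 frames per rank

for rank in range(world_size):
    start_idx = rank * chunk_size
    end_idx = min(start_idx + chunk_size, min_frames)
    print(rank, (start_idx, end_idx))
# 0 -> (0, 36), 1 -> (36, 72), 2 -> (72, 108), 3 -> (108, 143):
# the last rank simply gets the remainder.
```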
guided-diffusion/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
guided-diffusion/guided_diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""
guided-diffusion/guided_diffusion/dist_util.py ADDED
@@ -0,0 +1,94 @@
"""
Helpers for distributed training.
"""

import io
import os
import socket

import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8

SETUP_RETRY_COUNT = 3


def setup_dist():
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return
    print("MPI.COMM_WORLD.Get_rank()", MPI.COMM_WORLD.Get_rank())
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
    print('os.environ["CUDA_VISIBLE_DEVICES"]', os.environ["CUDA_VISIBLE_DEVICES"])
    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    chunk_size = 2 ** 30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
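A hedged sketch of the intended call pattern: only rank 0 reads the checkpoint from disk, and every other rank receives the bytes through the chunked (<= 1 GiB) MPI broadcast above. The checkpoint path is a placeholder:

```python
from guided_diffusion import dist_util

dist_util.setup_dist()                # one process group per MPI rank
state = dist_util.load_state_dict(    # rank 0 reads; other ranks receive
    "path/to/checkpoint.pt", map_location="cpu")
print(dist_util.dev(), len(state), "tensors loaded")
```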
guided-diffusion/guided_diffusion/fp16_util.py ADDED
@@ -0,0 +1,237 @@
"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        for p in self.master_params:
            p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
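A minimal sketch of the training-step protocol `MixedPrecisionTrainer` expects (zero_grad, backward, optimize). The toy model and data are stand-ins; `use_fp16=True` would additionally require the model to expose `convert_to_fp16()`, as the repo's UNet does, so the sketch stays in fp32:

```python
import torch as th
import torch.nn as nn

from guided_diffusion.fp16_util import MixedPrecisionTrainer

model = nn.Linear(16, 1)  # stand-in for the UNet
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

x, y = th.randn(8, 16), th.randn(8, 1)  # dummy batch
trainer.zero_grad()
loss = ((model(x) - y) ** 2).mean()
trainer.backward(loss)              # scales the loss when use_fp16=True
took_step = trainer.optimize(opt)   # returns False on an fp16 overflow
print(took_step, loss.item())
```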
guided-diffusion/guided_diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,843 @@
1
+ """
2
+ This code started out as a PyTorch port of Ho et al's diffusion models:
3
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
+
5
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
+ """
7
+
8
+ import enum
9
+ import math
10
+
11
+ import numpy as np
12
+ import torch as th
13
+ import os
14
+
15
+ from . import dist_util
16
+ from .nn import mean_flat
17
+ from .losses import normal_kl, discretized_gaussian_log_likelihood
18
+
19
+
20
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
21
+ """
22
+ Get a pre-defined beta schedule for the given name.
23
+
24
+ The beta schedule library consists of beta schedules which remain similar
25
+ in the limit of num_diffusion_timesteps.
26
+ Beta schedules may be added, but should not be removed or changed once
27
+ they are committed to maintain backwards compatibility.
28
+ """
29
+ if schedule_name == "linear":
30
+ # Linear schedule from Ho et al, extended to work for any number of
31
+ # diffusion steps.
32
+ scale = 1000 / num_diffusion_timesteps
33
+ beta_start = scale * 0.0001
34
+ beta_end = scale * 0.02
35
+ return np.linspace(
36
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
37
+ )
38
+ elif schedule_name == "cosine":
39
+ return betas_for_alpha_bar(
40
+ num_diffusion_timesteps,
41
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
42
+ )
43
+ else:
44
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
45
+
46
+
47
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
48
+ """
49
+ Create a beta schedule that discretizes the given alpha_t_bar function,
50
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
51
+
52
+ :param num_diffusion_timesteps: the number of betas to produce.
53
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
54
+ produces the cumulative product of (1-beta) up to that
55
+ part of the diffusion process.
56
+ :param max_beta: the maximum beta to use; use values lower than 1 to
57
+ prevent singularities.
58
+ """
59
+ betas = []
60
+ for i in range(num_diffusion_timesteps):
61
+ t1 = i / num_diffusion_timesteps
62
+ t2 = (i + 1) / num_diffusion_timesteps
63
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
64
+ return np.array(betas)
65
+
66
+
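A quick illustrative check (not part of the file): the cosine schedule produces tiny betas at the start of the chain and betas clipped at `max_beta` near the end.

betas = get_named_beta_schedule("cosine", 1000)
assert betas.shape == (1000,)
# betas[0] is a few times 1e-5; betas[-1] hits the max_beta = 0.999 cap.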
67
+ class ModelMeanType(enum.Enum):
68
+ """
69
+ Which type of output the model predicts.
70
+ """
71
+
72
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
73
+ START_X = enum.auto() # the model predicts x_0
74
+ EPSILON = enum.auto() # the model predicts epsilon
75
+
76
+
77
+ class ModelVarType(enum.Enum):
78
+ """
79
+ What is used as the model's output variance.
80
+
81
+ The LEARNED_RANGE option has been added to allow the model to predict
82
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
83
+ """
84
+
85
+ LEARNED = enum.auto()
86
+ FIXED_SMALL = enum.auto()
87
+ FIXED_LARGE = enum.auto()
88
+ LEARNED_RANGE = enum.auto()
89
+
90
+
91
+ class LossType(enum.Enum):
92
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
93
+ RESCALED_MSE = (
94
+ enum.auto()
95
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
96
+ KL = enum.auto() # use the variational lower-bound
97
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
98
+
99
+ def is_vb(self):
100
+ return self == LossType.KL or self == LossType.RESCALED_KL
101
+
102
+
103
+ class GaussianDiffusion:
104
+ """
105
+ Utilities for training and sampling diffusion models.
106
+
107
+ Ported directly from the repository below, and then adapted over time for further experimentation.
108
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
109
+
110
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
111
+ starting at T and going to 1.
112
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
113
+ :param model_var_type: a ModelVarType determining how variance is output.
114
+ :param loss_type: a LossType determining the loss function to use.
115
+ :param rescale_timesteps: if True, pass floating point timesteps into the
116
+ model so that they are always scaled like in the
117
+ original paper (0 to 1000).
118
+ :param loss_variation: if True, use a composite loss.
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ *,
124
+ betas,
125
+ model_mean_type,
126
+ model_var_type,
127
+ loss_type,
128
+ rescale_timesteps=False,
129
+ loss_variation=False,
130
+ ):
131
+ self.model_mean_type = model_mean_type
132
+ self.model_var_type = model_var_type
133
+ self.loss_type = loss_type
134
+ self.rescale_timesteps = rescale_timesteps
135
+ self.loss_variation = loss_variation
136
+
137
+ # Use float64 for accuracy.
138
+ betas = np.array(betas, dtype=np.float64)
139
+ self.betas = betas
140
+ assert len(betas.shape) == 1, "betas must be 1-D"
141
+ assert (betas > 0).all() and (betas <= 1).all()
142
+
143
+ self.num_timesteps = int(betas.shape[0])
144
+
145
+ alphas = 1.0 - betas
146
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
147
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
148
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
149
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
150
+
151
+ # calculations for diffusion q(x_t | x_{t-1}) and others
152
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
153
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
154
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
155
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
156
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
157
+
158
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
159
+ self.posterior_variance = (
160
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
161
+ )
162
+ # log calculation clipped because the posterior variance is 0 at the
163
+ # beginning of the diffusion chain.
164
+ self.posterior_log_variance_clipped = np.log(
165
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
166
+ )
167
+ self.posterior_mean_coef1 = (
168
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
169
+ )
170
+ self.posterior_mean_coef2 = (
171
+ (1.0 - self.alphas_cumprod_prev)
172
+ * np.sqrt(alphas)
173
+ / (1.0 - self.alphas_cumprod)
174
+ )
175
+
176
+ def q_mean_variance(self, x_start, t):
177
+ """
178
+ Get the distribution q(x_t | x_0).
179
+
180
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
181
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
182
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
183
+ """
184
+ mean = (
185
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
186
+ )
187
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
188
+ log_variance = _extract_into_tensor(
189
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
190
+ )
191
+ return mean, variance, log_variance
192
+
193
+ def q_sample(self, x_start, t, noise=None):
194
+ """
195
+ Diffuse the data for a given number of diffusion steps.
196
+
197
+ In other words, sample from q(x_t | x_0).
198
+
199
+ :param x_start: the initial data batch.
200
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
201
+ :param noise: if specified, the split-out normal noise.
202
+ :return: A noisy version of x_start.
203
+ """
204
+ if noise is None:
205
+ noise = th.randn_like(x_start)
206
+ assert noise.shape == x_start.shape
207
+ return (
208
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
209
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
210
+ * noise
211
+ )
212
+
213
+ def q_posterior_mean_variance(self, x_start, x_t, t):
214
+ """
215
+ Compute the mean and variance of the diffusion posterior:
216
+
217
+ q(x_{t-1} | x_t, x_0)
218
+
219
+ """
220
+ assert x_start.shape == x_t.shape
221
+ posterior_mean = (
222
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
223
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
224
+ )
225
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
226
+ posterior_log_variance_clipped = _extract_into_tensor(
227
+ self.posterior_log_variance_clipped, t, x_t.shape
228
+ )
229
+ assert (
230
+ posterior_mean.shape[0]
231
+ == posterior_variance.shape[0]
232
+ == posterior_log_variance_clipped.shape[0]
233
+ == x_start.shape[0]
234
+ )
235
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
236
+
237
+ def p_mean_variance(
238
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
239
+ ):
240
+ """
241
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
242
+ the initial x, x_0.
243
+
244
+ :param model: the model, which takes a signal and a batch of timesteps
245
+ as input.
246
+ :param x: the [N x C x ...] tensor at time t.
247
+ :param t: a 1-D Tensor of timesteps.
248
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
249
+ :param denoised_fn: if not None, a function which applies to the
250
+ x_start prediction before it is used to sample. Applies before
251
+ clip_denoised.
252
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
253
+ pass to the model. This can be used for conditioning.
254
+ :return: a dict with the following keys:
255
+ - 'mean': the model mean output.
256
+ - 'variance': the model variance output.
257
+ - 'log_variance': the log of 'variance'.
258
+ - 'pred_xstart': the prediction for x_0.
259
+ """
260
+ if model_kwargs is None:
261
+ model_kwargs = {}
262
+
263
+ B, C = x.shape[:2]
264
+ assert t.shape == (B,)
265
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
266
+
267
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
268
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
269
+ model_output, model_var_values = th.split(model_output, C, dim=1)
270
+ if self.model_var_type == ModelVarType.LEARNED:
271
+ model_log_variance = model_var_values
272
+ model_variance = th.exp(model_log_variance)
273
+ else:
274
+ min_log = _extract_into_tensor(
275
+ self.posterior_log_variance_clipped, t, x.shape
276
+ )
277
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
278
+ # The model_var_values is [-1, 1] for [min_var, max_var].
279
+ frac = (model_var_values + 1) / 2
280
+ model_log_variance = frac * max_log + (1 - frac) * min_log
281
+ model_variance = th.exp(model_log_variance)
282
+ else:
283
+ model_variance, model_log_variance = {
284
+ # for fixedlarge, we set the initial (log-)variance like so
285
+ # to get a better decoder log likelihood.
286
+ ModelVarType.FIXED_LARGE: (
287
+ np.append(self.posterior_variance[1], self.betas[1:]),
288
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
289
+ ),
290
+ ModelVarType.FIXED_SMALL: (
291
+ self.posterior_variance,
292
+ self.posterior_log_variance_clipped,
293
+ ),
294
+ }[self.model_var_type]
295
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
296
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
297
+
298
+ def process_xstart(x):
299
+ if denoised_fn is not None:
300
+ x = denoised_fn(x)
301
+ if clip_denoised:
302
+ return x.clamp(-1, 1)
303
+ return x
304
+
305
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
306
+ pred_xstart = process_xstart(
307
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
308
+ )
309
+ model_mean = model_output
310
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
311
+ if self.model_mean_type == ModelMeanType.START_X:
312
+ pred_xstart = process_xstart(model_output)
313
+ else:
314
+ pred_xstart = process_xstart(
315
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
316
+ )
317
+ model_mean, _, _ = self.q_posterior_mean_variance(
318
+ x_start=pred_xstart, x_t=x, t=t
319
+ )
320
+ else:
321
+ raise NotImplementedError(self.model_mean_type)
322
+
323
+ assert (
324
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
325
+ )
326
+ return {
327
+ "mean": model_mean,
328
+ "variance": model_variance,
329
+ "log_variance": model_log_variance,
330
+ "pred_xstart": pred_xstart,
331
+ }
332
+
333
+ def _predict_xstart_from_eps(self, x_t, t, eps):
334
+ assert x_t.shape == eps.shape
335
+ return (
336
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
337
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
338
+ )
339
+
340
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
341
+ assert x_t.shape == xprev.shape
342
+ return ( # (xprev - coef2*x_t) / coef1
343
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
344
+ - _extract_into_tensor(
345
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
346
+ )
347
+ * x_t
348
+ )
349
+
350
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
351
+ return (
352
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
353
+ - pred_xstart
354
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
355
+
356
+ def _scale_timesteps(self, t):
357
+ if self.rescale_timesteps:
358
+ return t.float() * (1000.0 / self.num_timesteps)
359
+ return t
360
+
361
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
362
+ """
363
+ Compute the mean for the previous step, given a function cond_fn that
364
+ computes the gradient of a conditional log probability with respect to
365
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
366
+ condition on y.
367
+
368
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
369
+ """
370
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
371
+ new_mean = (
372
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
373
+ )
374
+ return new_mean
375
+
376
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
377
+ """
378
+ Compute what the p_mean_variance output would have been, should the
379
+ model's score function be conditioned by cond_fn.
380
+
381
+ See condition_mean() for details on cond_fn.
382
+
383
+ Unlike condition_mean(), this instead uses the conditioning strategy
384
+ from Song et al (2020).
385
+ """
386
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
387
+
388
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
389
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
390
+ x, self._scale_timesteps(t), **model_kwargs
391
+ )
392
+
393
+ out = p_mean_var.copy()
394
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
395
+ out["mean"], _, _ = self.q_posterior_mean_variance(
396
+ x_start=out["pred_xstart"], x_t=x, t=t
397
+ )
398
+ return out
399
+
400
+ def p_sample(
401
+ self,
402
+ model,
403
+ x,
404
+ t,
405
+ clip_denoised=True,
406
+ denoised_fn=None,
407
+ cond_fn=None,
408
+ model_kwargs=None,
409
+ ):
410
+ """
411
+ Sample x_{t-1} from the model at the given timestep.
412
+
413
+ :param model: the model to sample from.
414
+ :param x: the current tensor at x_{t-1}.
415
+ :param t: the value of t, starting at 0 for the first diffusion step.
416
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
417
+ :param denoised_fn: if not None, a function which applies to the
418
+ x_start prediction before it is used to sample.
419
+ :param cond_fn: if not None, this is a gradient function that acts
420
+ similarly to the model.
421
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
422
+ pass to the model. This can be used for conditioning.
423
+ :return: a dict containing the following keys:
424
+ - 'sample': a random sample from the model.
425
+ - 'pred_xstart': a prediction of x_0.
426
+ """
427
+ out = self.p_mean_variance(
428
+ model,
429
+ x,
430
+ t,
431
+ clip_denoised=clip_denoised,
432
+ denoised_fn=denoised_fn,
433
+ model_kwargs=model_kwargs,
434
+ )
435
+ noise = th.randn_like(x)
436
+ nonzero_mask = (
437
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
438
+ ) # no noise when t == 0
439
+ if cond_fn is not None:
440
+ out["mean"] = self.condition_mean(
441
+ cond_fn, out, x, t, model_kwargs=model_kwargs
442
+ )
443
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
444
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
445
+
446
+ def p_sample_loop(
447
+ self,
448
+ model,
449
+ shape,
450
+ noise=None,
451
+ clip_denoised=True,
452
+ denoised_fn=None,
453
+ cond_fn=None,
454
+ model_kwargs=None,
455
+ device=None,
456
+ progress=False,
457
+ ):
458
+ """
459
+ Generate samples from the model.
460
+
461
+ :param model: the model module.
462
+ :param shape: the shape of the samples, (N, C, H, W).
463
+ :param noise: if specified, the noise from the encoder to sample.
464
+ Should be of the same shape as `shape`.
465
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
466
+ :param denoised_fn: if not None, a function which applies to the
467
+ x_start prediction before it is used to sample.
468
+ :param cond_fn: if not None, this is a gradient function that acts
469
+ similarly to the model.
470
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
471
+ pass to the model. This can be used for conditioning.
472
+ :param device: if specified, the device to create the samples on.
473
+ If not specified, use a model parameter's device.
474
+ :param progress: if True, show a tqdm progress bar.
475
+ :return: a non-differentiable batch of samples.
476
+ """
477
+ final = None
478
+ for sample in self.p_sample_loop_progressive(
479
+ model,
480
+ shape,
481
+ noise=noise,
482
+ clip_denoised=clip_denoised,
483
+ denoised_fn=denoised_fn,
484
+ cond_fn=cond_fn,
485
+ model_kwargs=model_kwargs,
486
+ device=device,
487
+ progress=progress,
488
+ ):
489
+ final = sample
490
+ return final["sample"]
491
+
492
+ def p_sample_loop_progressive(
493
+ self,
494
+ model,
495
+ shape,
496
+ noise=None,
497
+ clip_denoised=True,
498
+ denoised_fn=None,
499
+ cond_fn=None,
500
+ model_kwargs=None,
501
+ device=None,
502
+ progress=False,
503
+ ):
504
+ """
505
+ Generate samples from the model and yield intermediate samples from
506
+ each timestep of diffusion.
507
+
508
+ Arguments are the same as p_sample_loop().
509
+ Returns a generator over dicts, where each dict is the return value of
510
+ p_sample().
511
+ """
512
+ if device is None:
513
+ device = next(model.parameters()).device
514
+ assert isinstance(shape, (tuple, list))
515
+ if noise is not None:
516
+ img = noise
517
+ else:
518
+ img = th.randn(*shape, device=device)
519
+ indices = list(range(self.num_timesteps))[::-1]
520
+
521
+ if progress:
522
+ # Lazy import so that we don't depend on tqdm.
523
+ from tqdm.auto import tqdm
524
+
525
+ indices = tqdm(indices)
526
+
527
+ for i in indices:
528
+ t = th.tensor([i] * shape[0], device=device)
529
+ with th.no_grad():
530
+ out = self.p_sample(
531
+ model,
532
+ img,
533
+ t,
534
+ clip_denoised=clip_denoised,
535
+ denoised_fn=denoised_fn,
536
+ cond_fn=cond_fn,
537
+ model_kwargs=model_kwargs,
538
+ )
539
+ yield out
540
+ img = out["sample"]
541
+
542
+ def ddim_sample(
543
+ self,
544
+ model,
545
+ x,
546
+ t,
547
+ clip_denoised=True,
548
+ denoised_fn=None,
549
+ cond_fn=None,
550
+ model_kwargs=None,
551
+ eta=0.0,
552
+ ):
553
+ """
554
+ Sample x_{t-1} from the model using DDIM.
555
+
556
+ Same usage as p_sample().
557
+ """
558
+ out = self.p_mean_variance(
559
+ model,
560
+ x,
561
+ t,
562
+ clip_denoised=clip_denoised,
563
+ denoised_fn=denoised_fn,
564
+ model_kwargs=model_kwargs,
565
+ )
566
+ if cond_fn is not None:
567
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
568
+
569
+ # Usually our model outputs epsilon, but we re-derive it
570
+ # in case we used x_start or x_prev prediction.
571
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
572
+
573
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
574
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
575
+ sigma = (
576
+ eta
577
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
578
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
579
+ )
580
+ # Equation 12 of the DDIM paper (Song et al., 2020).
581
+ noise = th.randn_like(x)
582
+ mean_pred = (
583
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
584
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
585
+ )
586
+ nonzero_mask = (
587
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
588
+ ) # no noise when t == 0
589
+ sample = mean_pred + nonzero_mask * sigma * noise
590
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
591
+
592
+ def ddim_reverse_sample(
593
+ self,
594
+ model,
595
+ x,
596
+ t,
597
+ clip_denoised=True,
598
+ denoised_fn=None,
599
+ model_kwargs=None,
600
+ eta=0.0,
601
+ ):
602
+ """
603
+ Sample x_{t+1} from the model using DDIM reverse ODE.
604
+ """
605
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
606
+ out = self.p_mean_variance(
607
+ model,
608
+ x,
609
+ t,
610
+ clip_denoised=clip_denoised,
611
+ denoised_fn=denoised_fn,
612
+ model_kwargs=model_kwargs,
613
+ )
614
+ # Usually our model outputs epsilon, but we re-derive it
615
+ # in case we used x_start or x_prev prediction.
616
+ eps = (
617
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
618
+ - out["pred_xstart"]
619
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
620
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
621
+
622
+ # Equation 12 of the DDIM paper, reversed.
623
+ mean_pred = (
624
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
625
+ + th.sqrt(1 - alpha_bar_next) * eps
626
+ )
627
+
628
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
629
+
630
+ def ddim_sample_loop(
631
+ self,
632
+ model,
633
+ shape,
634
+ noise=None,
635
+ clip_denoised=True,
636
+ denoised_fn=None,
637
+ cond_fn=None,
638
+ model_kwargs=None,
639
+ device=None,
640
+ progress=False,
641
+ eta=0.0,
642
+ ):
643
+ """
644
+ Generate samples from the model using DDIM.
645
+
646
+ Same usage as p_sample_loop().
647
+ """
648
+ final = None
649
+ for sample in self.ddim_sample_loop_progressive(
650
+ model,
651
+ shape,
652
+ noise=noise,
653
+ clip_denoised=clip_denoised,
654
+ denoised_fn=denoised_fn,
655
+ cond_fn=cond_fn,
656
+ model_kwargs=model_kwargs,
657
+ device=device,
658
+ progress=progress,
659
+ eta=eta,
660
+ ):
661
+ final = sample
662
+ return final["sample"]
663
+
664
+ def ddim_sample_loop_progressive(
665
+ self,
666
+ model,
667
+ shape,
668
+ noise=None,
669
+ clip_denoised=True,
670
+ denoised_fn=None,
671
+ cond_fn=None,
672
+ model_kwargs=None,
673
+ device=None,
674
+ progress=False,
675
+ eta=0.0,
676
+ ):
677
+ """
678
+ Use DDIM to sample from the model and yield intermediate samples from
679
+ each timestep of DDIM.
680
+
681
+ Same usage as p_sample_loop_progressive().
682
+ """
683
+ if device is None:
684
+ device = next(model.parameters()).device
685
+ assert isinstance(shape, (tuple, list))
686
+ if noise is not None:
687
+ img = noise
688
+ else:
689
+ img = th.randn(*shape, device=device)
690
+ indices = list(range(self.num_timesteps))[::-1]
691
+
692
+ if progress:
693
+ # Lazy import so that we don't depend on tqdm.
694
+ from tqdm.auto import tqdm
695
+
696
+ indices = tqdm(indices)
697
+
698
+ for i in indices:
699
+ t = th.tensor([i] * shape[0], device=device)
700
+ with th.no_grad():
701
+ out = self.ddim_sample(
702
+ model,
703
+ img,
704
+ t,
705
+ clip_denoised=clip_denoised,
706
+ denoised_fn=denoised_fn,
707
+ cond_fn=cond_fn,
708
+ model_kwargs=model_kwargs,
709
+ eta=eta,
710
+ )
711
+ yield out
712
+ img = out["sample"]
713
+
714
+ def _vb_terms_bpd(
715
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
716
+ ):
717
+ """
718
+ Get a term for the variational lower-bound.
719
+
720
+ The resulting units are bits (rather than nats, as one might expect).
721
+ This allows for comparison to other papers.
722
+
723
+ :return: a dict with the following keys:
724
+ - 'output': a shape [N] tensor of NLLs or KLs.
725
+ - 'pred_xstart': the x_0 predictions.
726
+ """
727
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
728
+ x_start=x_start, x_t=x_t, t=t
729
+ )
730
+ out = self.p_mean_variance(
731
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
732
+ )
733
+ kl = normal_kl(
734
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
735
+ )
736
+ if ("cond_img" in model_kwargs) and ("mask" in model_kwargs): #added by soumik
737
+ kl = kl*model_kwargs["mask"]
738
+ kl = mean_flat(kl) / np.log(2.0)
739
+
740
+ decoder_nll = -discretized_gaussian_log_likelihood(
741
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
742
+ )
743
+ assert decoder_nll.shape == x_start.shape
744
+ if ("cond_img" in model_kwargs) and ("mask" in model_kwargs): #added by soumik
745
+ decoder_nll=decoder_nll*model_kwargs["mask"]
746
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
747
+
748
+ # At the first timestep return the decoder NLL,
749
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
750
+ output = th.where((t == 0), decoder_nll, kl)
751
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
752
+
753
+
754
+ def _prior_bpd(self, x_start):
755
+ """
756
+ Get the prior KL term for the variational lower-bound, measured in
757
+ bits-per-dim.
758
+
759
+ This term can't be optimized, as it only depends on the encoder.
760
+
761
+ :param x_start: the [N x C x ...] tensor of inputs.
762
+ :return: a batch of [N] KL values (in bits), one per batch element.
763
+ """
764
+ batch_size = x_start.shape[0]
765
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
766
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
767
+ kl_prior = normal_kl(
768
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
769
+ )
770
+ return mean_flat(kl_prior) / np.log(2.0)
771
+
772
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
773
+ """
774
+ Compute the entire variational lower-bound, measured in bits-per-dim,
775
+ as well as other related quantities.
776
+
777
+ :param model: the model to evaluate loss on.
778
+ :param x_start: the [N x C x ...] tensor of inputs.
779
+ :param clip_denoised: if True, clip denoised samples.
780
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
781
+ pass to the model. This can be used for conditioning.
782
+
783
+ :return: a dict containing the following keys:
784
+ - total_bpd: the total variational lower-bound, per batch element.
785
+ - prior_bpd: the prior term in the lower-bound.
786
+ - vb: an [N x T] tensor of terms in the lower-bound.
787
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
788
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
789
+ """
790
+ device = x_start.device
791
+ batch_size = x_start.shape[0]
792
+
793
+ vb = []
794
+ xstart_mse = []
795
+ mse = []
796
+ for t in list(range(self.num_timesteps))[::-1]:
797
+ t_batch = th.tensor([t] * batch_size, device=device)
798
+ noise = th.randn_like(x_start)
799
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
800
+ # Calculate VLB term at the current timestep
801
+ with th.no_grad():
802
+ out = self._vb_terms_bpd(
803
+ model,
804
+ x_start=x_start,
805
+ x_t=x_t,
806
+ t=t_batch,
807
+ clip_denoised=clip_denoised,
808
+ model_kwargs=model_kwargs,
809
+ )
810
+ vb.append(out["output"])
811
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
812
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
813
+ mse.append(mean_flat((eps - noise) ** 2))
814
+
815
+ vb = th.stack(vb, dim=1)
816
+ xstart_mse = th.stack(xstart_mse, dim=1)
817
+ mse = th.stack(mse, dim=1)
818
+
819
+ prior_bpd = self._prior_bpd(x_start)
820
+ total_bpd = vb.sum(dim=1) + prior_bpd
821
+ return {
822
+ "total_bpd": total_bpd,
823
+ "prior_bpd": prior_bpd,
824
+ "vb": vb,
825
+ "xstart_mse": xstart_mse,
826
+ "mse": mse,
827
+ }
828
+
829
+
830
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
831
+ """
832
+ Extract values from a 1-D numpy array for a batch of indices.
833
+
834
+ :param arr: the 1-D numpy array.
835
+ :param timesteps: a tensor of indices into the array to extract.
836
+ :param broadcast_shape: a larger shape of K dimensions with the batch
837
+ dimension equal to the length of timesteps.
838
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
839
+ """
840
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
841
+ while len(res.shape) < len(broadcast_shape):
842
+ res = res[..., None]
843
+ return res.expand(broadcast_shape)
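A minimal usage sketch for the class above, under the assumption of an epsilon-predicting network; `eps_model` is hypothetical and only needs to accept `(x, t)` and return a tensor shaped like `x`. This is not code from the repo, just an illustration of how the pieces fit together:

import torch as th

diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("linear", 1000),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)

# Forward process: q_sample progressively destroys x_0; at t = 999 the result
# is close to pure Gaussian noise.
x0 = th.randn(2, 3, 64, 64)
x_t = diffusion.q_sample(x0, th.full((2,), 999, dtype=th.long))

# Reverse process (ancestral sampling); eps_model is a placeholder network.
# samples = diffusion.p_sample_loop(eps_model, (2, 3, 64, 64), progress=True)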
guided-diffusion/guided_diffusion/image_datasets.py ADDED
@@ -0,0 +1,167 @@
1
+ import math
2
+ import random
3
+
4
+ from PIL import Image
5
+ import blobfile as bf
6
+ from mpi4py import MPI
7
+ import numpy as np
8
+ from torch.utils.data import DataLoader, Dataset
9
+
10
+
11
+ def load_data(
12
+ *,
13
+ data_dir,
14
+ batch_size,
15
+ image_size,
16
+ class_cond=False,
17
+ deterministic=False,
18
+ random_crop=False,
19
+ random_flip=True,
20
+ ):
21
+ """
22
+ For a dataset, create a generator over (images, kwargs) pairs.
23
+
24
+ Each batch of images is an NCHW float tensor, and the kwargs dict contains zero or
25
+ more keys, each of which maps to a batched Tensor of its own.
26
+ The kwargs dict can be used for class labels, in which case the key is "y"
27
+ and the values are integer tensors of class labels.
28
+
29
+ :param data_dir: a dataset directory.
30
+ :param batch_size: the batch size of each returned pair.
31
+ :param image_size: the size to which images are resized.
32
+ :param class_cond: if True, include a "y" key in returned dicts for class
33
+ label. If classes are not available and this is true, an
34
+ exception will be raised.
35
+ :param deterministic: if True, yield results in a deterministic order.
36
+ :param random_crop: if True, randomly crop the images for augmentation.
37
+ :param random_flip: if True, randomly flip the images for augmentation.
38
+ """
39
+ if not data_dir:
40
+ raise ValueError("unspecified data directory")
41
+ all_files = _list_image_files_recursively(data_dir)
42
+ classes = None
43
+ if class_cond:
44
+ # Assume classes are the first part of the filename,
45
+ # before an underscore.
46
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
47
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
48
+ classes = [sorted_classes[x] for x in class_names]
49
+ dataset = ImageDataset(
50
+ image_size,
51
+ all_files,
52
+ classes=classes,
53
+ shard=MPI.COMM_WORLD.Get_rank(),
54
+ num_shards=MPI.COMM_WORLD.Get_size(),
55
+ random_crop=random_crop,
56
+ random_flip=random_flip,
57
+ )
58
+ if deterministic:
59
+ loader = DataLoader(
60
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
61
+ )
62
+ else:
63
+ loader = DataLoader(
64
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
65
+ )
66
+ while True:
67
+ yield from loader
68
+
69
+
70
+ def _list_image_files_recursively(data_dir):
71
+ results = []
72
+ for entry in sorted(bf.listdir(data_dir)):
73
+ full_path = bf.join(data_dir, entry)
74
+ ext = entry.split(".")[-1]
75
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
76
+ results.append(full_path)
77
+ elif bf.isdir(full_path):
78
+ results.extend(_list_image_files_recursively(full_path))
79
+ return results
80
+
81
+
82
+ class ImageDataset(Dataset):
83
+ def __init__(
84
+ self,
85
+ resolution,
86
+ image_paths,
87
+ classes=None,
88
+ shard=0,
89
+ num_shards=1,
90
+ random_crop=False,
91
+ random_flip=True,
92
+ ):
93
+ super().__init__()
94
+ self.resolution = resolution
95
+ self.local_images = image_paths[shard:][::num_shards]
96
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
97
+ self.random_crop = random_crop
98
+ self.random_flip = random_flip
99
+
100
+ def __len__(self):
101
+ return len(self.local_images)
102
+
103
+ def __getitem__(self, idx):
104
+ path = self.local_images[idx]
105
+ with bf.BlobFile(path, "rb") as f:
106
+ pil_image = Image.open(f)
107
+ pil_image.load()
108
+ pil_image = pil_image.convert("RGB")
109
+
110
+ if self.random_crop:
111
+ arr = random_crop_arr(pil_image, self.resolution)
112
+ else:
113
+ arr = center_crop_arr(pil_image, self.resolution)
114
+
115
+ if self.random_flip and random.random() < 0.5:
116
+ arr = arr[:, ::-1]
117
+
118
+ arr = arr.astype(np.float32) / 127.5 - 1
119
+
120
+ out_dict = {}
121
+ if self.local_classes is not None:
122
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
123
+ return np.transpose(arr, [2, 0, 1]), out_dict
124
+
125
+
126
+ def center_crop_arr(pil_image, image_size):
127
+ # We are not on a new enough PIL to support the `reducing_gap`
128
+ # argument, which uses BOX downsampling at powers of two first.
129
+ # Thus, we do it by hand to improve downsample quality.
130
+ while min(*pil_image.size) >= 2 * image_size:
131
+ pil_image = pil_image.resize(
132
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
133
+ )
134
+
135
+ scale = image_size / min(*pil_image.size)
136
+ pil_image = pil_image.resize(
137
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
138
+ )
139
+
140
+ arr = np.array(pil_image)
141
+ crop_y = (arr.shape[0] - image_size) // 2
142
+ crop_x = (arr.shape[1] - image_size) // 2
143
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
144
+
145
+
146
+ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
147
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
148
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
149
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
150
+
151
+ # We are not on a new enough PIL to support the `reducing_gap`
152
+ # argument, which uses BOX downsampling at powers of two first.
153
+ # Thus, we do it by hand to improve downsample quality.
154
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
155
+ pil_image = pil_image.resize(
156
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
157
+ )
158
+
159
+ scale = smaller_dim_size / min(*pil_image.size)
160
+ pil_image = pil_image.resize(
161
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
162
+ )
163
+
164
+ arr = np.array(pil_image)
165
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
166
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
167
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
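Illustrative usage of `load_data` (the directory path is a placeholder; note that the module imports mpi4py, so an MPI installation is needed even for single-process runs). Because the generator loops forever, it pairs with a step-based training loop rather than epoch iteration:

data = load_data(
    data_dir="/path/to/images",
    batch_size=16,
    image_size=64,
    class_cond=False,
)
for step in range(1000):
    batch, cond = next(data)   # batch: (16, 3, 64, 64) float32 in [-1, 1]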
guided-diffusion/guided_diffusion/logger.py ADDED
@@ -0,0 +1,491 @@
1
+ """
2
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import os.path as osp
10
+ import json
11
+ import time
12
+ import datetime
13
+ import tempfile
14
+ import warnings
15
+ from collections import defaultdict
16
+ from contextlib import contextmanager
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ DEBUG = 10
20
+ INFO = 20
21
+ WARN = 30
22
+ ERROR = 40
23
+
24
+ DISABLED = 50
25
+
26
+
27
+ class KVWriter(object):
28
+ def writekvs(self, kvs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class SeqWriter(object):
33
+ def writeseq(self, seq):
34
+ raise NotImplementedError
35
+
36
+
37
+ class HumanOutputFormat(KVWriter, SeqWriter):
38
+ def __init__(self, filename_or_file):
39
+ if isinstance(filename_or_file, str):
40
+ self.file = open(filename_or_file, "wt")
41
+ self.own_file = True
42
+ else:
43
+ assert hasattr(filename_or_file, "read"), (
44
+ "expected file or str, got %s" % filename_or_file
45
+ )
46
+ self.file = filename_or_file
47
+ self.own_file = False
48
+
49
+ def writekvs(self, kvs):
50
+ # Create strings for printing
51
+ key2str = {}
52
+ for (key, val) in sorted(kvs.items()):
53
+ if hasattr(val, "__float__"):
54
+ valstr = "%-8.3g" % val
55
+ else:
56
+ valstr = str(val)
57
+ key2str[self._truncate(key)] = self._truncate(valstr)
58
+
59
+ # Find max widths
60
+ if len(key2str) == 0:
61
+ print("WARNING: tried to write empty key-value dict")
62
+ return
63
+ else:
64
+ keywidth = max(map(len, key2str.keys()))
65
+ valwidth = max(map(len, key2str.values()))
66
+
67
+ # Write out the data
68
+ dashes = "-" * (keywidth + valwidth + 7)
69
+ lines = [dashes]
70
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
71
+ lines.append(
72
+ "| %s%s | %s%s |"
73
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
74
+ )
75
+ lines.append(dashes)
76
+ self.file.write("\n".join(lines) + "\n")
77
+
78
+ # Flush the output to the file
79
+ self.file.flush()
80
+
81
+ def _truncate(self, s):
82
+ maxlen = 30
83
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
84
+
85
+ def writeseq(self, seq):
86
+ seq = list(seq)
87
+ for (i, elem) in enumerate(seq):
88
+ self.file.write(elem)
89
+ if i < len(seq) - 1: # add space unless this is the last one
90
+ self.file.write(" ")
91
+ self.file.write("\n")
92
+ self.file.flush()
93
+
94
+ def close(self):
95
+ if self.own_file:
96
+ self.file.close()
97
+
98
+
99
+ class JSONOutputFormat(KVWriter):
100
+ def __init__(self, filename):
101
+ self.file = open(filename, "wt")
102
+
103
+ def writekvs(self, kvs):
104
+ for k, v in sorted(kvs.items()):
105
+ if hasattr(v, "dtype"):
106
+ kvs[k] = float(v)
107
+ self.file.write(json.dumps(kvs) + "\n")
108
+ self.file.flush()
109
+
110
+ def close(self):
111
+ self.file.close()
112
+
113
+
114
+ class CSVOutputFormat(KVWriter):
115
+ def __init__(self, filename):
116
+ self.file = open(filename, "w+t")
117
+ self.keys = []
118
+ self.sep = ","
119
+
120
+ def writekvs(self, kvs):
121
+ # Add our current row to the history
122
+ extra_keys = list(kvs.keys() - self.keys)
123
+ extra_keys.sort()
124
+ if extra_keys:
125
+ self.keys.extend(extra_keys)
126
+ self.file.seek(0)
127
+ lines = self.file.readlines()
128
+ self.file.seek(0)
129
+ for (i, k) in enumerate(self.keys):
130
+ if i > 0:
131
+ self.file.write(",")
132
+ self.file.write(k)
133
+ self.file.write("\n")
134
+ for line in lines[1:]:
135
+ self.file.write(line[:-1])
136
+ self.file.write(self.sep * len(extra_keys))
137
+ self.file.write("\n")
138
+ for (i, k) in enumerate(self.keys):
139
+ if i > 0:
140
+ self.file.write(",")
141
+ v = kvs.get(k)
142
+ if v is not None:
143
+ self.file.write(str(v))
144
+ self.file.write("\n")
145
+ self.file.flush()
146
+
147
+ def close(self):
148
+ self.file.close()
149
+
150
+
151
+ class TensorBoardOutputFormat(KVWriter):
152
+ """
153
+ Dumps key/value pairs into TensorBoard's numeric format.
154
+ """
155
+
156
+ def __init__(self, dir):
157
+ os.makedirs(dir, exist_ok=True)
158
+ self.dir = dir
159
+ self.step = -1
160
+ self.writer = SummaryWriter(self.dir)
161
+
162
+ def writekvs(self, kvs):
163
+ self.step = int(kvs["step"])
164
+ for k, v in sorted(kvs.items()):
165
+ self.writer.add_scalar(k, float(v), self.step)
166
+ self.writer.flush()
167
+
168
+ def writeimage(self, key, image_tensor):
169
+ self.writer.add_image(key, image_tensor, self.step)
170
+ self.writer.flush()
171
+
172
+ def close(self):
173
+ if self.writer:
174
+ self.writer.close()
175
+ self.writer = None
176
+
177
+
178
+ def make_output_format(format, ev_dir, log_suffix=""):
179
+ os.makedirs(ev_dir, exist_ok=True)
180
+ if format == "stdout":
181
+ return HumanOutputFormat(sys.stdout)
182
+ elif format == "log":
183
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
184
+ elif format == "json":
185
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
186
+ elif format == "csv":
187
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
188
+ elif format == "tensorboard":
189
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
190
+ else:
191
+ raise ValueError("Unknown format specified: %s" % (format,))
192
+
193
+
194
+ # ================================================================
195
+ # API
196
+ # ================================================================
197
+
198
+ def logimage(key, image_tensor):
199
+ """
200
+ Log one image to tensorboard
201
+ """
202
+ for fmt in get_current().output_formats:
203
+ if isinstance(fmt, TensorBoardOutputFormat):
204
+ tb_logger = fmt
205
+ tb_logger.writeimage(key, image_tensor)
206
+
207
+
208
+ def logkv(key, val):
209
+ """
210
+ Log a value of some diagnostic
211
+ Call this once for each diagnostic quantity, each iteration
212
+ If called many times, last value will be used.
213
+ """
214
+ get_current().logkv(key, val)
215
+
216
+
217
+ def logkv_mean(key, val):
218
+ """
219
+ The same as logkv(), but if called many times, values averaged.
220
+ """
221
+ get_current().logkv_mean(key, val)
222
+
223
+
224
+ def logkvs(d):
225
+ """
226
+ Log a dictionary of key-value pairs
227
+ """
228
+ for (k, v) in d.items():
229
+ logkv(k, v)
230
+
231
+
232
+ def dumpkvs():
233
+ """
234
+ Write all of the diagnostics from the current iteration
235
+ """
236
+ return get_current().dumpkvs()
237
+
238
+
239
+ def getkvs():
240
+ return get_current().name2val
241
+
242
+
243
+ def log(*args, level=INFO):
244
+ """
245
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
246
+ """
247
+ get_current().log(*args, level=level)
248
+
249
+
250
+ def debug(*args):
251
+ log(*args, level=DEBUG)
252
+
253
+
254
+ def info(*args):
255
+ log(*args, level=INFO)
256
+
257
+
258
+ def warn(*args):
259
+ log(*args, level=WARN)
260
+
261
+
262
+ def error(*args):
263
+ log(*args, level=ERROR)
264
+
265
+
266
+ def set_level(level):
267
+ """
268
+ Set logging threshold on current logger.
269
+ """
270
+ get_current().set_level(level)
271
+
272
+
273
+ def set_comm(comm):
274
+ get_current().set_comm(comm)
275
+
276
+
277
+ def get_dir():
278
+ """
279
+ Get directory that log files are being written to.
280
+ will be None if there is no output directory (i.e., if you didn't call start)
281
+ """
282
+ return get_current().get_dir()
283
+
284
+
285
+ record_tabular = logkv
286
+ dump_tabular = dumpkvs
287
+
288
+
289
+ @contextmanager
290
+ def profile_kv(scopename):
291
+ logkey = "wait_" + scopename
292
+ tstart = time.time()
293
+ try:
294
+ yield
295
+ finally:
296
+ get_current().name2val[logkey] += time.time() - tstart
297
+
298
+
299
+ def profile(n):
300
+ """
301
+ Usage:
302
+ @profile("my_func")
303
+ def my_func(): code
304
+ """
305
+
306
+ def decorator_with_name(func):
307
+ def func_wrapper(*args, **kwargs):
308
+ with profile_kv(n):
309
+ return func(*args, **kwargs)
310
+
311
+ return func_wrapper
312
+
313
+ return decorator_with_name
314
+
315
+
316
+ # ================================================================
317
+ # Backend
318
+ # ================================================================
319
+
320
+
321
+ def get_current():
322
+ if Logger.CURRENT is None:
323
+ _configure_default_logger()
324
+
325
+ return Logger.CURRENT
326
+
327
+
328
+ class Logger(object):
329
+ DEFAULT = None # A logger with no output files. (See right below class definition)
330
+ # So that you can still log to the terminal without setting up any output files
331
+ CURRENT = None # Current logger being used by the free functions above
332
+
333
+ def __init__(self, dir, output_formats, comm=None):
334
+ self.name2val = defaultdict(float) # values this iteration
335
+ self.name2cnt = defaultdict(int)
336
+ self.level = INFO
337
+ self.dir = dir
338
+ self.output_formats = output_formats
339
+ self.comm = comm
340
+
341
+ # Logging API, forwarded
342
+ # ----------------------------------------
343
+ def logkv(self, key, val):
344
+ self.name2val[key] = val
345
+
346
+ def logkv_mean(self, key, val):
347
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
348
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
349
+ self.name2cnt[key] = cnt + 1
350
+
351
+ def dumpkvs(self):
352
+ if self.comm is None:
353
+ d = self.name2val
354
+ else:
355
+ d = mpi_weighted_mean(
356
+ self.comm,
357
+ {
358
+ name: (val, self.name2cnt.get(name, 1))
359
+ for (name, val) in self.name2val.items()
360
+ },
361
+ )
362
+ if self.comm.rank != 0:
363
+ d["dummy"] = 1 # so we don't get a warning about empty dict
364
+ out = d.copy() # Return the dict for unit testing purposes
365
+ for fmt in self.output_formats:
366
+ if isinstance(fmt, KVWriter):
367
+ fmt.writekvs(d)
368
+ self.name2val.clear()
369
+ self.name2cnt.clear()
370
+ return out
371
+
372
+ def log(self, *args, level=INFO):
373
+ if self.level <= level:
374
+ self._do_log(args)
375
+
376
+ # Configuration
377
+ # ----------------------------------------
378
+ def set_level(self, level):
379
+ self.level = level
380
+
381
+ def set_comm(self, comm):
382
+ self.comm = comm
383
+
384
+ def get_dir(self):
385
+ return self.dir
386
+
387
+ def close(self):
388
+ for fmt in self.output_formats:
389
+ fmt.close()
390
+
391
+ # Misc
392
+ # ----------------------------------------
393
+ def _do_log(self, args):
394
+ for fmt in self.output_formats:
395
+ if isinstance(fmt, SeqWriter):
396
+ fmt.writeseq(map(str, args))
397
+
398
+
399
+ def get_rank_without_mpi_import():
400
+ # check environment variables here instead of importing mpi4py
401
+ # to avoid calling MPI_Init() when this module is imported
402
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
403
+ if varname in os.environ:
404
+ return int(os.environ[varname])
405
+ return 0
406
+
407
+
408
+ def mpi_weighted_mean(comm, local_name2valcount):
409
+ """
410
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
411
+ Perform a weighted average over dicts that are each on a different node
412
+ Input: local_name2valcount: dict mapping key -> (value, count)
413
+ Returns: key -> mean
414
+ """
415
+ all_name2valcount = comm.gather(local_name2valcount)
416
+ if comm.rank == 0:
417
+ name2sum = defaultdict(float)
418
+ name2count = defaultdict(float)
419
+ for n2vc in all_name2valcount:
420
+ for (name, (val, count)) in n2vc.items():
421
+ try:
422
+ val = float(val)
423
+ except ValueError:
424
+ if comm.rank == 0:
425
+ warnings.warn(
426
+ "WARNING: tried to compute mean on non-float {}={}".format(
427
+ name, val
428
+ )
429
+ )
430
+ else:
431
+ name2sum[name] += val * count
432
+ name2count[name] += count
433
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
434
+ else:
435
+ return {}
436
+
437
+
438
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
439
+ """
440
+ If comm is provided, average all numerical stats across that comm
441
+ """
442
+ if dir is None:
443
+ dir = os.getenv("OPENAI_LOGDIR")
444
+ if dir is None:
445
+ dir = osp.join(
446
+ tempfile.gettempdir(),
447
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
448
+ )
449
+ assert isinstance(dir, str)
450
+ dir = os.path.expanduser(dir)
451
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
452
+
453
+ rank = get_rank_without_mpi_import()
454
+ if rank > 0:
455
+ log_suffix = log_suffix + "-rank%03i" % rank
456
+
457
+ if format_strs is None:
458
+ if rank == 0:
459
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv,tensorboard").split(",")
460
+ else:
461
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
462
+ format_strs = filter(None, format_strs)
463
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
464
+
465
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
466
+ if output_formats:
467
+ log("Logging to %s" % dir)
468
+
469
+
470
+ def _configure_default_logger():
471
+ configure()
472
+ Logger.DEFAULT = Logger.CURRENT
473
+
474
+
475
+ def reset():
476
+ if Logger.CURRENT is not Logger.DEFAULT:
477
+ Logger.CURRENT.close()
478
+ Logger.CURRENT = Logger.DEFAULT
479
+ log("Reset logger")
480
+
481
+
482
+ @contextmanager
483
+ def scoped_configure(dir=None, format_strs=None, comm=None):
484
+ prevlogger = Logger.CURRENT
485
+ configure(dir=dir, format_strs=format_strs, comm=comm)
486
+ try:
487
+ yield
488
+ finally:
489
+ Logger.CURRENT.close()
490
+ Logger.CURRENT = prevlogger
491
+
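A short usage sketch of the logger API above (the log directory is a placeholder). `logkv` keeps the last value per key, `logkv_mean` a running average, and `dumpkvs` flushes one row to every configured writer:

configure(dir="/tmp/diffusion_logs", format_strs=["stdout", "csv"])
for step in range(3):
    logkv("step", step)
    logkv_mean("loss", 0.1 * step)
    dumpkvs()   # writes the accumulated key/value pairs and clears them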
guided-diffusion/guided_diffusion/losses.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Helpers for various likelihood-based losses. These are ported from the original
3
+ Ho et al. diffusion models codebase:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two Gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
53
+ given image.
54
+
55
+ :param x: the target images. It is assumed that this was uint8 values,
56
+ rescaled to the range [-1, 1].
57
+ :param means: the Gaussian mean Tensor.
58
+ :param log_scales: the Gaussian log stddev Tensor.
59
+ :return: a tensor like x of log probabilities (in nats).
60
+ """
61
+ assert x.shape == means.shape == log_scales.shape
62
+ centered_x = x - means
63
+ inv_stdv = th.exp(-log_scales)
64
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65
+ cdf_plus = approx_standard_normal_cdf(plus_in)
66
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67
+ cdf_min = approx_standard_normal_cdf(min_in)
68
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70
+ cdf_delta = cdf_plus - cdf_min
71
+ log_probs = th.where(
72
+ x < -0.999,
73
+ log_cdf_plus,
74
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75
+ )
76
+ assert log_probs.shape == x.shape
77
+ return log_probs
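Two quick numeric checks of `normal_kl` against the closed form; both follow directly from the expression above (KL of identical Gaussians is zero, and KL(N(mu, 1) || N(0, 1)) = mu^2 / 2):

import torch as th

zero = th.tensor(0.0)
assert normal_kl(zero, zero, zero, zero).item() == 0.0        # identical Gaussians
mu = th.tensor(2.0)
assert th.isclose(normal_kl(mu, zero, zero, zero), mu ** 2 / 2)  # unit variances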
guided-diffusion/guided_diffusion/lpips.py ADDED
@@ -0,0 +1,20 @@
1
+ from lpips_pytorch import LPIPS
2
+ import torch
3
+
4
+ class LPIPS1(LPIPS):
5
+ r"""
6
+ Overriding LPIPS to return the loss without reducing over the batch.
7
+ Arguments:
8
+ net_type (str): the network type to compare the features:
9
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
10
+ version (str): the version of LPIPS. Default: 0.1.
11
+ """
12
+ def __init__(self, net_type: str = 'alex', version: str = '0.1'):
13
+ super(LPIPS1, self).__init__(net_type=net_type, version=version)  # forward the arguments instead of hard-coding 'alex'/'0.1'
14
+
15
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
16
+ feat_x, feat_y = self.net(x), self.net(y)
17
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
18
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
19
+ # return torch.sum(torch.cat(res, 0), 0, True)
20
+ return torch.sum(torch.cat(res, 1), 1, True)
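Hypothetical usage of `LPIPS1` (requires the `lpips_pytorch` package): unlike the parent class, it returns one distance per batch element with shape (N, 1, 1, 1), because the per-layer terms are concatenated and summed along dim 1 rather than over the batch:

import torch

lpips = LPIPS1(net_type='alex')
x = torch.rand(4, 3, 64, 64) * 2 - 1   # inputs are expected in [-1, 1]
y = torch.rand(4, 3, 64, 64) * 2 - 1
per_example = lpips(x, y)              # shape (4, 1, 1, 1)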
guided-diffusion/guided_diffusion/nn.py ADDED
@@ -0,0 +1,170 @@
+ """
+ Various utilities for neural networks.
+ """
+
+ import math
+
+ import torch as th
+ import torch.nn as nn
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * th.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def update_ema(target_params, source_params, rate=0.99):
+     """
+     Update target parameters to be closer to those of source parameters using
+     an exponential moving average.
+
+     :param target_params: the target parameter sequence.
+     :param source_params: the source parameter sequence.
+     :param rate: the EMA rate (closer to 1 means slower).
+     """
+     for targ, src in zip(target_params, source_params):
+         targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(32, channels)
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = th.exp(
+         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+     ).to(device=timesteps.device)
+     args = timesteps[:, None].float() * freqs[None]
+     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
+
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(th.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+         with th.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with th.enable_grad():
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = th.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
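Two quick sketches of the helpers above (shapes and rates are arbitrary): `timestep_embedding` builds the sinusoidal features fed into the time MLP, and `update_ema` nudges a target parameter list toward the source parameters.

# Illustrative calls to timestep_embedding and update_ema.
import torch as th
from guided_diffusion.nn import timestep_embedding, update_ema

t = th.tensor([0, 10, 999])              # one timestep index per batch element
emb = timestep_embedding(t, dim=128)     # -> [3, 128] sinusoidal embedding
assert emb.shape == (3, 128)

target = [th.zeros(5)]                   # EMA (target) copy of the weights
source = [th.ones(5)]                    # current (source) weights
update_ema(target, source, rate=0.99)    # target <- 0.99*target + 0.01*source
assert th.allclose(target[0], th.full((5,), 0.01))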
guided-diffusion/guided_diffusion/resample.py ADDED
@@ -0,0 +1,154 @@
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+ import torch as th
+ import torch.distributed as dist
+
+
+ def create_named_schedule_sampler(name, diffusion):
+     """
+     Create a ScheduleSampler from a library of pre-defined samplers.
+
+     :param name: the name of the sampler.
+     :param diffusion: the diffusion object to sample for.
+     """
+     if name == "uniform":
+         return UniformSampler(diffusion)
+     elif name == "loss-second-moment":
+         return LossSecondMomentResampler(diffusion)
+     else:
+         raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+ class ScheduleSampler(ABC):
+     """
+     A distribution over timesteps in the diffusion process, intended to reduce
+     variance of the objective.
+
+     By default, samplers perform unbiased importance sampling, in which the
+     objective's mean is unchanged.
+     However, subclasses may override sample() to change how the resampled
+     terms are reweighted, allowing for actual changes in the objective.
+     """
+
+     @abstractmethod
+     def weights(self):
+         """
+         Get a numpy array of weights, one per diffusion step.
+
+         The weights needn't be normalized, but must be positive.
+         """
+
+     def sample(self, batch_size, device):
+         """
+         Importance-sample timesteps for a batch.
+
+         :param batch_size: the number of timesteps.
+         :param device: the torch device to save to.
+         :return: a tuple (timesteps, weights):
+                  - timesteps: a tensor of timestep indices.
+                  - weights: a tensor of weights to scale the resulting losses.
+         """
+         w = self.weights()
+         p = w / np.sum(w)
+         indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+         indices = th.from_numpy(indices_np).long().to(device)
+         weights_np = 1 / (len(p) * p[indices_np])
+         weights = th.from_numpy(weights_np).float().to(device)
+         return indices, weights
+
+
+ class UniformSampler(ScheduleSampler):
+     def __init__(self, diffusion):
+         self.diffusion = diffusion
+         self._weights = np.ones([diffusion.num_timesteps])
+
+     def weights(self):
+         return self._weights
+
+
+ class LossAwareSampler(ScheduleSampler):
+     def update_with_local_losses(self, local_ts, local_losses):
+         """
+         Update the reweighting using losses from a model.
+
+         Call this method from each rank with a batch of timesteps and the
+         corresponding losses for each of those timesteps.
+         This method will perform synchronization to make sure all of the ranks
+         maintain the exact same reweighting.
+
+         :param local_ts: an integer Tensor of timesteps.
+         :param local_losses: a 1D Tensor of losses.
+         """
+         batch_sizes = [
+             th.tensor([0], dtype=th.int32, device=local_ts.device)
+             for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(
+             batch_sizes,
+             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+         )
+
+         # Pad all_gather batches to be the maximum batch size.
+         batch_sizes = [x.item() for x in batch_sizes]
+         max_bs = max(batch_sizes)
+
+         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+         dist.all_gather(timestep_batches, local_ts)
+         dist.all_gather(loss_batches, local_losses)
+         timesteps = [
+             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+         ]
+         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+         self.update_with_all_losses(timesteps, losses)
+
+     @abstractmethod
+     def update_with_all_losses(self, ts, losses):
+         """
+         Update the reweighting using losses from a model.
+
+         Sub-classes should override this method to update the reweighting
+         using losses from the model.
+
+         This method directly updates the reweighting without synchronizing
+         between workers. It is called by update_with_local_losses from all
+         ranks with identical arguments. Thus, it should have deterministic
+         behavior to maintain state across workers.
+
+         :param ts: a list of int timesteps.
+         :param losses: a list of float losses, one per timestep.
+         """
+
+
+ class LossSecondMomentResampler(LossAwareSampler):
+     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+         self.diffusion = diffusion
+         self.history_per_term = history_per_term
+         self.uniform_prob = uniform_prob
+         self._loss_history = np.zeros(
+             [diffusion.num_timesteps, history_per_term], dtype=np.float64
+         )
+         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+     def weights(self):
+         if not self._warmed_up():
+             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+         weights /= np.sum(weights)
+         weights *= 1 - self.uniform_prob
+         weights += self.uniform_prob / len(weights)
+         return weights
+
+     def update_with_all_losses(self, ts, losses):
+         for t, loss in zip(ts, losses):
+             if self._loss_counts[t] == self.history_per_term:
+                 # Shift out the oldest loss term.
+                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
+                 self._loss_history[t, -1] = loss
+             else:
+                 self._loss_history[t, self._loss_counts[t]] = loss
+                 self._loss_counts[t] += 1
+
+     def _warmed_up(self):
+         return (self._loss_counts == self.history_per_term).all()
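A minimal sketch of drawing training timesteps with the uniform sampler (the stand-in diffusion object is hypothetical; `UniformSampler` only reads its `num_timesteps` attribute):

# Hypothetical demo of UniformSampler.sample.
import torch as th
from guided_diffusion.resample import UniformSampler

class _FakeDiffusion:
    num_timesteps = 1000

sampler = UniformSampler(_FakeDiffusion())
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
assert t.shape == weights.shape == (8,)
assert th.allclose(weights, th.ones(8))  # uniform case: every importance weight is 1.0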
guided-diffusion/guided_diffusion/respace.py ADDED
@@ -0,0 +1,128 @@
+ import numpy as np
+ import torch as th
+
+ from .gaussian_diffusion import GaussianDiffusion
+
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+
+     For example, if there are 300 timesteps and the section counts are
+     [10, 15, 20], then the first 100 timesteps are strided to be 10 timesteps,
+     the second 100 are strided to be 15 timesteps, and the final 100 are
+     strided to be 20.
+
+     If the stride is a string starting with "ddim", then the fixed striding
+     from the DDIM paper is used, and only one section is allowed.
+
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim") :])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(
+                 f"cannot create exactly {desired_count} steps with an integer stride"
+             )
+         section_counts = [int(x) for x in section_counts.split(",")]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(
+                 f"cannot divide section of {size} steps into {section_count}"
+             )
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+
+ class SpacedDiffusion(GaussianDiffusion):
+     """
+     A diffusion process which can skip steps in a base diffusion process.
+
+     :param use_timesteps: a collection (sequence or set) of timesteps from the
+                           original diffusion process to retain.
+     :param kwargs: the kwargs to create the base diffusion process.
+     """
+
+     def __init__(self, use_timesteps, **kwargs):
+         self.use_timesteps = set(use_timesteps)
+         self.timestep_map = []
+         self.original_num_steps = len(kwargs["betas"])
+
+         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+         last_alpha_cumprod = 1.0
+         new_betas = []
+         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+             if i in self.use_timesteps:
+                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+                 self.timestep_map.append(i)
+         kwargs["betas"] = np.array(new_betas)
+         super().__init__(**kwargs)
+
+     def p_mean_variance(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+     def training_losses(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+     def condition_mean(self, cond_fn, *args, **kwargs):
+         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def condition_score(self, cond_fn, *args, **kwargs):
+         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def _wrap_model(self, model):
+         if isinstance(model, _WrappedModel):
+             return model
+         return _WrappedModel(
+             model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+         )
+
+     def _scale_timesteps(self, t):
+         # Scaling is done by the wrapped model.
+         return t
+
+
+ class _WrappedModel:
+     def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+         self.model = model
+         self.timestep_map = timestep_map
+         self.rescale_timesteps = rescale_timesteps
+         self.original_num_steps = original_num_steps
+
+     def __call__(self, x, ts, **kwargs):
+         map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+         new_ts = map_tensor[ts]
+         if self.rescale_timesteps:
+             new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+         return self.model(x, new_ts, **kwargs)
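Two sketches of `space_timesteps` using the numbers from its docstring; both results follow directly from the code above.

# space_timesteps examples mirroring the docstring.
from guided_diffusion.respace import space_timesteps

# DDIM-style striding: keep every 4th of 1000 steps to get exactly 250.
assert space_timesteps(1000, "ddim250") == set(range(0, 1000, 4))

# Three equal 100-step sections respaced to 10, 15, and 20 retained steps.
assert len(space_timesteps(300, "10,15,20")) == 45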
guided-diffusion/guided_diffusion/script_util.py ADDED
@@ -0,0 +1,614 @@
+ import argparse
+ import inspect
+
+ from . import gaussian_diffusion as gd
+ from .respace import SpacedDiffusion, space_timesteps
+ from .unet import SuperResModel, UNetModel, EncoderUNetModel, TFGModel
+
+ NUM_CLASSES = 1000
+
+
+ def diffusion_defaults():
+     """
+     Defaults for image and classifier training.
+     """
+     return dict(
+         learn_sigma=False,
+         diffusion_steps=1000,
+         noise_schedule="linear",
+         timestep_respacing="",
+         use_kl=False,
+         predict_xstart=False,
+         rescale_timesteps=False,
+         rescale_learned_sigmas=False,
+         loss_variation=0,  # added by soumik
+     )
+
+
+ def classifier_defaults():
+     """
+     Defaults for classifier models.
+     """
+     return dict(
+         image_size=64,
+         classifier_use_fp16=False,
+         classifier_width=128,
+         classifier_depth=2,
+         classifier_attention_resolutions="32,16,8",  # 16
+         classifier_use_scale_shift_norm=True,  # False
+         classifier_resblock_updown=True,  # False
+         classifier_pool="attention",
+     )
+
+
+ def model_and_diffusion_defaults():
+     """
+     Defaults for image training.
+     """
+     res = dict(
+         image_size=64,
+         num_channels=128,
+         num_res_blocks=2,
+         num_heads=4,
+         num_heads_upsample=-1,
+         num_head_channels=-1,
+         attention_resolutions="16,8",
+         channel_mult="",
+         dropout=0.0,
+         class_cond=False,
+         use_checkpoint=False,
+         use_scale_shift_norm=True,
+         resblock_updown=False,
+         use_fp16=False,
+         use_new_attention_order=False,
+     )
+     res.update(diffusion_defaults())
+     return res
+
+
+ def classifier_and_diffusion_defaults():
+     res = classifier_defaults()
+     res.update(diffusion_defaults())
+     return res
+
+
+ def create_model_and_diffusion(
+     image_size,
+     class_cond,
+     learn_sigma,
+     num_channels,
+     num_res_blocks,
+     channel_mult,
+     num_heads,
+     num_head_channels,
+     num_heads_upsample,
+     attention_resolutions,
+     dropout,
+     diffusion_steps,
+     noise_schedule,
+     timestep_respacing,
+     use_kl,
+     predict_xstart,
+     rescale_timesteps,
+     rescale_learned_sigmas,
+     use_checkpoint,
+     use_scale_shift_norm,
+     resblock_updown,
+     use_fp16,
+     use_new_attention_order,
+ ):
+     model = create_model(
+         image_size,
+         num_channels,
+         num_res_blocks,
+         channel_mult=channel_mult,
+         learn_sigma=learn_sigma,
+         class_cond=class_cond,
+         use_checkpoint=use_checkpoint,
+         attention_resolutions=attention_resolutions,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         dropout=dropout,
+         resblock_updown=resblock_updown,
+         use_fp16=use_fp16,
+         use_new_attention_order=use_new_attention_order,
+     )
+     diffusion = create_gaussian_diffusion(
+         steps=diffusion_steps,
+         learn_sigma=learn_sigma,
+         noise_schedule=noise_schedule,
+         use_kl=use_kl,
+         predict_xstart=predict_xstart,
+         rescale_timesteps=rescale_timesteps,
+         rescale_learned_sigmas=rescale_learned_sigmas,
+         timestep_respacing=timestep_respacing,
+     )
+     return model, diffusion
+
+
+ def create_model(
+     image_size,
+     num_channels,
+     num_res_blocks,
+     channel_mult="",
+     learn_sigma=False,
+     class_cond=False,
+     use_checkpoint=False,
+     attention_resolutions="16",
+     num_heads=1,
+     num_head_channels=-1,
+     num_heads_upsample=-1,
+     use_scale_shift_norm=False,
+     dropout=0,
+     resblock_updown=False,
+     use_fp16=False,
+     use_new_attention_order=False,
+ ):
+     if channel_mult == "":
+         if image_size == 512:
+             channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+         elif image_size == 256:
+             channel_mult = (1, 1, 2, 2, 4, 4)
+         elif image_size == 128:
+             channel_mult = (1, 1, 2, 3, 4)
+         elif image_size == 64:
+             channel_mult = (1, 2, 3, 4)
+         else:
+             raise ValueError(f"unsupported image size: {image_size}")
+     else:
+         channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
+
+     attention_ds = []
+     for res in attention_resolutions.split(","):
+         attention_ds.append(image_size // int(res))
+
+     return UNetModel(
+         image_size=image_size,
+         in_channels=3,
+         model_channels=num_channels,
+         out_channels=(3 if not learn_sigma else 6),
+         num_res_blocks=num_res_blocks,
+         attention_resolutions=tuple(attention_ds),
+         dropout=dropout,
+         channel_mult=channel_mult,
+         num_classes=(NUM_CLASSES if class_cond else None),
+         use_checkpoint=use_checkpoint,
+         use_fp16=use_fp16,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         resblock_updown=resblock_updown,
+         use_new_attention_order=use_new_attention_order,
+     )
+
+
+ def create_classifier_and_diffusion(
+     image_size,
+     classifier_use_fp16,
+     classifier_width,
+     classifier_depth,
+     classifier_attention_resolutions,
+     classifier_use_scale_shift_norm,
+     classifier_resblock_updown,
+     classifier_pool,
+     learn_sigma,
+     diffusion_steps,
+     noise_schedule,
+     timestep_respacing,
+     use_kl,
+     predict_xstart,
+     rescale_timesteps,
+     rescale_learned_sigmas,
+ ):
+     classifier = create_classifier(
+         image_size,
+         classifier_use_fp16,
+         classifier_width,
+         classifier_depth,
+         classifier_attention_resolutions,
+         classifier_use_scale_shift_norm,
+         classifier_resblock_updown,
+         classifier_pool,
+     )
+     diffusion = create_gaussian_diffusion(
+         steps=diffusion_steps,
+         learn_sigma=learn_sigma,
+         noise_schedule=noise_schedule,
+         use_kl=use_kl,
+         predict_xstart=predict_xstart,
+         rescale_timesteps=rescale_timesteps,
+         rescale_learned_sigmas=rescale_learned_sigmas,
+         timestep_respacing=timestep_respacing,
+     )
+     return classifier, diffusion
+
+
+ def create_classifier(
+     image_size,
+     classifier_use_fp16,
+     classifier_width,
+     classifier_depth,
+     classifier_attention_resolutions,
+     classifier_use_scale_shift_norm,
+     classifier_resblock_updown,
+     classifier_pool,
+ ):
+     if image_size == 512:
+         channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+     elif image_size == 256:
+         channel_mult = (1, 1, 2, 2, 4, 4)
+     elif image_size == 128:
+         channel_mult = (1, 1, 2, 3, 4)
+     elif image_size == 64:
+         channel_mult = (1, 2, 3, 4)
+     else:
+         raise ValueError(f"unsupported image size: {image_size}")
+
+     attention_ds = []
+     for res in classifier_attention_resolutions.split(","):
+         attention_ds.append(image_size // int(res))
+
+     return EncoderUNetModel(
+         image_size=image_size,
+         in_channels=3,
+         model_channels=classifier_width,
+         out_channels=1000,
+         num_res_blocks=classifier_depth,
+         attention_resolutions=tuple(attention_ds),
+         channel_mult=channel_mult,
+         use_fp16=classifier_use_fp16,
+         num_head_channels=64,
+         use_scale_shift_norm=classifier_use_scale_shift_norm,
+         resblock_updown=classifier_resblock_updown,
+         pool=classifier_pool,
+     )
+
+
+ def sr_model_and_diffusion_defaults():
+     res = model_and_diffusion_defaults()
+     res["large_size"] = 256
+     res["small_size"] = 64
+     arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
+     for k in res.copy().keys():
+         if k not in arg_names:
+             del res[k]
+     return res
+
+
+ def sr_create_model_and_diffusion(
+     large_size,
+     small_size,
+     class_cond,
+     learn_sigma,
+     num_channels,
+     num_res_blocks,
+     num_heads,
+     num_head_channels,
+     num_heads_upsample,
+     attention_resolutions,
+     dropout,
+     diffusion_steps,
+     noise_schedule,
+     timestep_respacing,
+     use_kl,
+     predict_xstart,
+     rescale_timesteps,
+     rescale_learned_sigmas,
+     use_checkpoint,
+     use_scale_shift_norm,
+     resblock_updown,
+     use_fp16,
+ ):
+     model = sr_create_model(
+         large_size,
+         small_size,
+         num_channels,
+         num_res_blocks,
+         learn_sigma=learn_sigma,
+         class_cond=class_cond,
+         use_checkpoint=use_checkpoint,
+         attention_resolutions=attention_resolutions,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         dropout=dropout,
+         resblock_updown=resblock_updown,
+         use_fp16=use_fp16,
+     )
+     diffusion = create_gaussian_diffusion(
+         steps=diffusion_steps,
+         learn_sigma=learn_sigma,
+         noise_schedule=noise_schedule,
+         use_kl=use_kl,
+         predict_xstart=predict_xstart,
+         rescale_timesteps=rescale_timesteps,
+         rescale_learned_sigmas=rescale_learned_sigmas,
+         timestep_respacing=timestep_respacing,
+     )
+     return model, diffusion
+
+
+ def sr_create_model(
+     large_size,
+     small_size,
+     num_channels,
+     num_res_blocks,
+     learn_sigma,
+     class_cond,
+     use_checkpoint,
+     attention_resolutions,
+     num_heads,
+     num_head_channels,
+     num_heads_upsample,
+     use_scale_shift_norm,
+     dropout,
+     resblock_updown,
+     use_fp16,
+ ):
+     _ = small_size  # hack to prevent unused variable
+
+     if large_size == 512:
+         channel_mult = (1, 1, 2, 2, 4, 4)
+     elif large_size == 256:
+         channel_mult = (1, 1, 2, 2, 4, 4)
+     elif large_size == 64:
+         channel_mult = (1, 2, 3, 4)
+     else:
+         raise ValueError(f"unsupported large size: {large_size}")
+
+     attention_ds = []
+     for res in attention_resolutions.split(","):
+         attention_ds.append(large_size // int(res))
+
+     return SuperResModel(
+         image_size=large_size,
+         in_channels=3,
+         model_channels=num_channels,
+         out_channels=(3 if not learn_sigma else 6),
+         num_res_blocks=num_res_blocks,
+         attention_resolutions=tuple(attention_ds),
+         dropout=dropout,
+         channel_mult=channel_mult,
+         num_classes=(NUM_CLASSES if class_cond else None),
+         use_checkpoint=use_checkpoint,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         resblock_updown=resblock_updown,
+         use_fp16=use_fp16,
+     )
+
+
+ def create_gaussian_diffusion(
+     *,
+     steps=1000,
+     learn_sigma=False,
+     sigma_small=False,
+     noise_schedule="linear",
+     use_kl=False,
+     predict_xstart=False,
+     rescale_timesteps=False,
+     rescale_learned_sigmas=False,
+     timestep_respacing="",
+     loss_variation=0,
+ ):
+     betas = gd.get_named_beta_schedule(noise_schedule, steps)
+     if use_kl:
+         loss_type = gd.LossType.RESCALED_KL
+     elif rescale_learned_sigmas:
+         loss_type = gd.LossType.RESCALED_MSE
+     else:
+         loss_type = gd.LossType.MSE
+     if not timestep_respacing:
+         timestep_respacing = [steps]
+     return SpacedDiffusion(
+         use_timesteps=space_timesteps(steps, timestep_respacing),
+         betas=betas,
+         model_mean_type=(
+             gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+         ),
+         model_var_type=(
+             (
+                 gd.ModelVarType.FIXED_LARGE
+                 if not sigma_small
+                 else gd.ModelVarType.FIXED_SMALL
+             )
+             if not learn_sigma
+             else gd.ModelVarType.LEARNED_RANGE
+         ),
+         loss_type=loss_type,
+         rescale_timesteps=rescale_timesteps,
+         loss_variation=loss_variation,  # added by soumik
+     )
+
+
+ def add_dict_to_argparser(parser, default_dict):
+     for k, v in default_dict.items():
+         v_type = type(v)
+         if v is None:
+             v_type = str
+         elif isinstance(v, bool):
+             v_type = str2bool
+         parser.add_argument(f"--{k}", default=v, type=v_type)
+
+
+ def args_to_dict(args, keys):
+     return {k: getattr(args, k) for k in keys}
+
+
+ def str2bool(v):
+     """
+     https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+     """
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("boolean value expected")
+
+
+ # ________________________________ tfg model ________________________________ #
+ def tfg_model_and_diffusion_defaults():
+     res = model_and_diffusion_defaults()
+     arg_names = inspect.getfullargspec(tfg_create_model_and_diffusion)[0]
+     for k in res.copy().keys():
+         if k not in arg_names:
+             del res[k]
+
+     # tfg args
+     res["use_ref"] = False
+     res["nframes"] = 1
+     res["nrefer"] = 0
+     res["use_audio"] = False
+     res["audio_encoder_kwargs"] = {}
+     res["audio_as_style"] = False
+     res["audio_as_style_encoder_mlp"] = False
+     return res
+
+
+ def tfg_create_model_and_diffusion(
+     image_size,
+     class_cond,
+     learn_sigma,
+     num_channels,
+     num_res_blocks,
+     num_heads,
+     num_head_channels,
+     num_heads_upsample,
+     attention_resolutions,
+     dropout,
+     diffusion_steps,
+     noise_schedule,
+     timestep_respacing,
+     use_kl,
+     predict_xstart,
+     rescale_timesteps,
+     rescale_learned_sigmas,
+     use_checkpoint,
+     use_scale_shift_norm,
+     resblock_updown,
+     use_fp16,
+     use_ref,
+     nframes,
+     nrefer,
+     use_audio,
+     audio_encoder_kwargs,
+     audio_as_style,
+     audio_as_style_encoder_mlp,
+     loss_variation,
+ ):
+     model = tfg_create_model(
+         image_size,
+         num_channels,
+         num_res_blocks,
+         learn_sigma=learn_sigma,
+         class_cond=class_cond,
+         use_checkpoint=use_checkpoint,
+         attention_resolutions=attention_resolutions,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         dropout=dropout,
+         resblock_updown=resblock_updown,
+         use_fp16=use_fp16,
+         use_ref=use_ref,
+         nframes=nframes,
+         nrefer=nrefer,
+         use_audio=use_audio,
+         audio_encoder_kwargs=audio_encoder_kwargs,
+         audio_as_style=audio_as_style,
+         audio_as_style_encoder_mlp=audio_as_style_encoder_mlp,
+     )
+
+     diffusion = create_gaussian_diffusion(
+         steps=diffusion_steps,
+         learn_sigma=learn_sigma,
+         noise_schedule=noise_schedule,
+         use_kl=use_kl,
+         predict_xstart=predict_xstart,
+         rescale_timesteps=rescale_timesteps,
+         rescale_learned_sigmas=rescale_learned_sigmas,
+         timestep_respacing=timestep_respacing,
+         loss_variation=loss_variation,
+     )
+     return model, diffusion
+
+
+ def tfg_create_model(
+     image_size,
+     num_channels,
+     num_res_blocks,
+     learn_sigma,
+     class_cond,
+     use_checkpoint,
+     attention_resolutions,
+     num_heads,
+     num_head_channels,
+     num_heads_upsample,
+     use_scale_shift_norm,
+     dropout,
+     resblock_updown,
+     use_fp16,
+     use_ref,
+     nframes,
+     nrefer,
+     use_audio,
+     audio_encoder_kwargs,
+     audio_as_style,
+     audio_as_style_encoder_mlp,
+ ):
+     if image_size == 512:
+         channel_mult = (1, 1, 2, 2, 4, 4)
+     elif image_size == 256:
+         channel_mult = (1, 1, 2, 3, 4, 4)
+     elif image_size == 128:
+         channel_mult = (1, 1, 2, 3, 4)
+     elif image_size == 64:
+         channel_mult = (1, 2, 3, 4)
+     else:
+         raise ValueError(f"unsupported image size: {image_size}")
+
+     attention_ds = []
+     if "-1" not in attention_resolutions:  # -1 = no attention
+         for res in attention_resolutions.split(","):
+             attention_ds.append(image_size // int(res))
+
+     return TFGModel(
+         image_size=image_size,
+         in_channels=3,
+         model_channels=num_channels,
+         out_channels=(3 if not learn_sigma else 6),
+         num_res_blocks=num_res_blocks,
+         attention_resolutions=tuple(attention_ds),
+         dropout=dropout,
+         channel_mult=channel_mult,
+         num_classes=(NUM_CLASSES if class_cond else None),
+         use_checkpoint=use_checkpoint,
+         use_fp16=use_fp16,
+         num_heads=num_heads,
+         num_head_channels=num_head_channels,
+         num_heads_upsample=num_heads_upsample,
+         use_scale_shift_norm=use_scale_shift_norm,
+         resblock_updown=resblock_updown,
+         use_ref=use_ref,
+         nframes=nframes,
+         nrefer=nrefer,
+         use_audio=use_audio,
+         audio_encoder_kwargs=audio_encoder_kwargs,
+         audio_as_style=audio_as_style,
+         audio_as_style_encoder_mlp=audio_as_style_encoder_mlp,
+     )
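A minimal sketch of how the defaults and argparse helpers compose (flag values are illustrative; note that `loss_variation` lives in the defaults dict but is not a parameter of `create_model_and_diffusion`, so it is dropped before the call):

# Illustrative wiring of the helper functions above; not the repo's entry point.
import argparse
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

defaults = model_and_diffusion_defaults()
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)          # one --flag per default entry
args = parser.parse_args(["--image_size", "64", "--timestep_respacing", "ddim250"])

kwargs = args_to_dict(args, defaults.keys())
kwargs.pop("loss_variation")                     # consumed by create_gaussian_diffusion only
model, diffusion = create_model_and_diffusion(**kwargs)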
guided-diffusion/guided_diffusion/tfg_data_util.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+
+ def normalise2(tensor):
+     '''[0, 1] -> [-1, 1]'''
+     return (tensor * 2 - 1.).clamp(-1, 1)
+
+ def tfg_data(dataloader, face_hide_percentage, use_ref, use_audio):  # , sampling_use_gt_for_ref=False, noise=None
+     def inf_gen(generator):
+         while True:
+             yield from generator
+     data = inf_gen(dataloader)
+     for batch in data:
+         img_batch, model_kwargs = tfg_process_batch(batch, face_hide_percentage, use_ref, use_audio)
+         yield img_batch, model_kwargs
+
+
+ def tfg_process_batch(batch, face_hide_percentage, use_ref=False, use_audio=False, sampling_use_gt_for_ref=False, noise=None):
+     model_kwargs = {}
+     B, F, C, H, W = batch["image"].shape
+     img_batch = normalise2(batch["image"].reshape(B * F, C, H, W).contiguous())
+     model_kwargs = tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise)
+     if use_ref:
+         model_kwargs = tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref)
+     if use_audio:
+         model_kwargs = tfg_add_audio(batch, model_kwargs)
+     return img_batch, model_kwargs
+
+ def tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref=False):
+     # assuming nrefer = 1
+     # [B, nframes, C, H, W] -> [B*nframes, C, H, W]
+     if sampling_use_gt_for_ref:
+         B, F, C, H, W = batch["image"].shape
+         img_batch = normalise2(batch["image"].reshape(B * F, C, H, W).contiguous())
+         model_kwargs["ref_img"] = img_batch
+     else:
+         _, _, C, H, W = batch["ref_img"].shape
+         ref_img = normalise2(batch["ref_img"].reshape(-1, C, H, W).contiguous())
+         model_kwargs["ref_img"] = ref_img
+     return model_kwargs
+
+ def tfg_add_audio(batch, model_kwargs):
+     # unet needs [BF, h, w] as input
+     B, F, _, h, w = batch["indiv_mels"].shape
+     indiv_mels = batch["indiv_mels"]  # [B, F, 1, h, w]
+     indiv_mels = indiv_mels.squeeze(dim=2).reshape(B * F, h, w)
+     model_kwargs["indiv_mels"] = indiv_mels
+     # syncloss needs [B, 1, 80, 16] as input
+     if "mel" in batch:
+         mel = batch["mel"]  # [B, 1, h, w]
+         model_kwargs["mel"] = mel
+     return model_kwargs
+
+ def tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise=None):
+     B, C, H, W = img_batch.shape
+     mask = torch.zeros(B, 1, H, W)
+     mask_start_idx = int(H * (1 - face_hide_percentage))
+     mask[:, :, mask_start_idx:, :] = 1.
+     if noise is None:
+         noise = torch.randn_like(img_batch)
+     assert noise.shape == img_batch.shape, "Noise shape != Image shape"
+     cond_img = img_batch * (1. - mask) + mask * noise
+
+     model_kwargs["cond_img"] = cond_img
+     model_kwargs["mask"] = mask
+     return model_kwargs
+
+
+ def get_n_params(model):
+     pp = 0
+     for p in list(model.parameters()):
+         nn = 1
+         for s in list(p.size()):
+             nn = nn * s
+         pp += nn
+     return pp
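A quick sketch of the masking convention in `tfg_add_cond_inputs` (random stand-in frames): the bottom `face_hide_percentage` of every frame is replaced by noise, and the returned mask is 1 over the hidden region.

# Illustrative check of tfg_add_cond_inputs with a 50% face-hide region.
import torch
from guided_diffusion.tfg_data_util import tfg_add_cond_inputs

img_batch = torch.rand(2, 3, 128, 128) * 2 - 1     # [B*F, C, H, W] in [-1, 1]
kwargs = tfg_add_cond_inputs(img_batch, {}, face_hide_percentage=0.5)

mask = kwargs["mask"]                              # [2, 1, 128, 128]
assert mask[:, :, :64, :].sum() == 0               # top half kept as-is
assert bool(mask[:, :, 64:, :].all())              # bottom half replaced with noise
assert kwargs["cond_img"].shape == img_batch.shape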
guided-diffusion/guided_diffusion/unet.py ADDED
@@ -0,0 +1,1275 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
11
+ from .nn import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+
21
+
22
+ class AttentionPool2d(nn.Module):
23
+ """
24
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ spacial_dim: int,
30
+ embed_dim: int,
31
+ num_heads_channels: int,
32
+ output_dim: int = None,
33
+ ):
34
+ super().__init__()
35
+ self.positional_embedding = nn.Parameter(
36
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
37
+ )
38
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
39
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
40
+ self.num_heads = embed_dim // num_heads_channels
41
+ self.attention = QKVAttention(self.num_heads)
42
+
43
+ def forward(self, x):
44
+ b, c, *_spatial = x.shape
45
+ x = x.reshape(b, c, -1) # NC(HW)
46
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
47
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
48
+ x = self.qkv_proj(x)
49
+ x = self.attention(x)
50
+ x = self.c_proj(x)
51
+ return x[:, :, 0]
52
+
53
+
54
+ class TimestepBlock(nn.Module):
55
+ """
56
+ Any module where forward() takes timestep embeddings as a second argument.
57
+ """
58
+
59
+ @abstractmethod
60
+ def forward(self, x, emb):
61
+ """
62
+ Apply the module to `x` given `emb` timestep embeddings.
63
+ """
64
+
65
+
66
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
67
+ """
68
+ A sequential module that passes timestep embeddings to the children that
69
+ support it as an extra input.
70
+ """
71
+
72
+ def forward(self, x, emb):
73
+ for layer in self:
74
+ if isinstance(layer, TimestepBlock):
75
+ x = layer(x, emb)
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class Upsample(nn.Module):
82
+ """
83
+ An upsampling layer with an optional convolution.
84
+
85
+ :param channels: channels in the inputs and outputs.
86
+ :param use_conv: a bool determining if a convolution is applied.
87
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
88
+ upsampling occurs in the inner-two dimensions.
89
+ """
90
+
91
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
92
+ super().__init__()
93
+ self.channels = channels
94
+ self.out_channels = out_channels or channels
95
+ self.use_conv = use_conv
96
+ self.dims = dims
97
+ if use_conv:
98
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
99
+
100
+ def forward(self, x):
101
+ assert x.shape[1] == self.channels
102
+ if self.dims == 3:
103
+ x = F.interpolate(
104
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
105
+ )
106
+ else:
107
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
108
+ if self.use_conv:
109
+ x = self.conv(x)
110
+ return x
111
+
112
+
113
+ class Downsample(nn.Module):
114
+ """
115
+ A downsampling layer with an optional convolution.
116
+
117
+ :param channels: channels in the inputs and outputs.
118
+ :param use_conv: a bool determining if a convolution is applied.
119
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
120
+ downsampling occurs in the inner-two dimensions.
121
+ """
122
+
123
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, stride=None):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+ self.use_conv = use_conv
128
+ self.dims = dims
129
+ if stride is None:
130
+ stride = 2 if dims != 3 else (1, 2, 2)
131
+
132
+ if use_conv:
133
+ self.op = conv_nd(
134
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
135
+ )
136
+ else:
137
+ assert self.channels == self.out_channels
138
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
139
+
140
+ def forward(self, x):
141
+ assert x.shape[1] == self.channels
142
+ return self.op(x)
143
+
144
+
145
+ class ResBlock(TimestepBlock):
146
+ """
147
+ A residual block that can optionally change the number of channels.
148
+
149
+ :param channels: the number of input channels.
150
+ :param emb_channels: the number of timestep embedding channels.
151
+ :param dropout: the rate of dropout.
152
+ :param out_channels: if specified, the number of out channels.
153
+ :param use_conv: if True and out_channels is specified, use a spatial
154
+ convolution instead of a smaller 1x1 convolution to change the
155
+ channels in the skip connection.
156
+ :param dims: determines if the signal is 1D, 2D, or 3D.
157
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
158
+ :param up: if True, use this block for upsampling.
159
+ :param down: if True, use this block for downsampling.
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ channels,
165
+ emb_channels,
166
+ dropout,
167
+ out_channels=None,
168
+ use_conv=False,
169
+ use_scale_shift_norm=False,
170
+ dims=2,
171
+ use_checkpoint=False,
172
+ up=False,
173
+ down=False,
174
+ down_stride = None,
175
+ ):
176
+ super().__init__()
177
+ self.channels = channels
178
+ self.emb_channels = emb_channels
179
+ self.dropout = dropout
180
+ self.out_channels = out_channels or channels
181
+ self.use_conv = use_conv
182
+ self.use_checkpoint = use_checkpoint
183
+ self.use_scale_shift_norm = use_scale_shift_norm
184
+ self.dims = dims
185
+
186
+ self.in_layers = nn.Sequential(
187
+ normalization(channels),
188
+ nn.SiLU(),
189
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
190
+ )
191
+
192
+ self.updown = up or down
193
+
194
+ if up:
195
+ self.h_upd = Upsample(channels, False, dims)
196
+ self.x_upd = Upsample(channels, False, dims)
197
+ elif down:
198
+ self.h_upd = Downsample(channels, False, dims, stride = down_stride)
199
+ self.x_upd = Downsample(channels, False, dims, stride = down_stride)
200
+ else:
201
+ self.h_upd = self.x_upd = nn.Identity()
202
+
203
+ self.emb_layers = nn.Sequential(
204
+ nn.SiLU(),
205
+ linear(
206
+ emb_channels,
207
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
208
+ ),
209
+ )
210
+ self.out_layers = nn.Sequential(
211
+ normalization(self.out_channels),
212
+ nn.SiLU(),
213
+ nn.Dropout(p=dropout),
214
+ zero_module(
215
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
216
+ ),
217
+ )
218
+
219
+ if self.out_channels == channels:
220
+ self.skip_connection = nn.Identity()
221
+ elif use_conv:
222
+ self.skip_connection = conv_nd(
223
+ dims, channels, self.out_channels, 3, padding=1
224
+ )
225
+ else:
226
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
227
+
228
+ def forward(self, x, emb):
229
+ """
230
+ Apply the block to a Tensor, conditioned on a timestep embedding.
231
+
232
+ :param x: an [N x C x ...] Tensor of features.
233
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
234
+ :return: an [N x C x ...] Tensor of outputs.
235
+ """
236
+ return checkpoint(
237
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
238
+ )
239
+
240
+ def _forward(self, x, emb):
241
+ if self.updown:
242
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
243
+ h = in_rest(x)
244
+ h = self.h_upd(h)
245
+ x = self.x_upd(x)
246
+ h = in_conv(h)
247
+ else:
248
+ h = self.in_layers(x)
249
+ emb_out = self.emb_layers(emb).type(h.dtype)
250
+ while len(emb_out.shape) < len(h.shape):
251
+ emb_out = emb_out[..., None]
252
+ if self.use_scale_shift_norm:
253
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
254
+ scale, shift = th.chunk(emb_out, 2, dim=1)
255
+ h = out_norm(h) * (1 + scale) + shift
256
+ h = out_rest(h)
257
+ else:
258
+ h = h + emb_out
259
+ h = self.out_layers(h)
260
+ return self.skip_connection(x) + h
261
+
262
+
263
+ class AttentionBlock(nn.Module):
264
+ """
265
+ An attention block that allows spatial positions to attend to each other.
266
+
267
+ Originally ported from here, but adapted to the N-d case.
268
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
269
+ """
270
+
271
+ def __init__(
272
+ self,
273
+ channels,
274
+ num_heads=1,
275
+ num_head_channels=-1,
276
+ use_checkpoint=False,
277
+ use_new_attention_order=False,
278
+ ):
279
+ super().__init__()
280
+ self.channels = channels
281
+ if num_head_channels == -1:
282
+ self.num_heads = num_heads
283
+ else:
284
+ assert (
285
+ channels % num_head_channels == 0
286
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
287
+ self.num_heads = channels // num_head_channels
288
+ self.use_checkpoint = use_checkpoint
289
+ self.norm = normalization(channels)
290
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
291
+ if use_new_attention_order:
292
+ # split qkv before split heads
293
+ self.attention = QKVAttention(self.num_heads)
294
+ else:
295
+ # split heads before split qkv
296
+ self.attention = QKVAttentionLegacy(self.num_heads)
297
+
298
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
299
+
300
+ def forward(self, x):
301
+ return checkpoint(self._forward, (x,), self.parameters(), True)
302
+
303
+ def _forward(self, x):
304
+ b, c, *spatial = x.shape
305
+ x = x.reshape(b, c, -1)
306
+ qkv = self.qkv(self.norm(x))
307
+ h = self.attention(qkv)
308
+ h = self.proj_out(h)
309
+ return (x + h).reshape(b, c, *spatial)
310
+
311
+
312
+ def count_flops_attn(model, _x, y):
313
+ """
314
+ A counter for the `thop` package to count the operations in an
315
+ attention operation.
316
+ Meant to be used like:
317
+ macs, params = thop.profile(
318
+ model,
319
+ inputs=(inputs, timestamps),
320
+ custom_ops={QKVAttention: QKVAttention.count_flops},
321
+ )
322
+ """
323
+ b, c, *spatial = y[0].shape
324
+ num_spatial = int(np.prod(spatial))
325
+ # We perform two matmuls with the same number of ops.
326
+ # The first computes the weight matrix, the second computes
327
+ # the combination of the value vectors.
328
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
329
+ model.total_ops += th.DoubleTensor([matmul_ops])
330
+
331
+
332
+ class QKVAttentionLegacy(nn.Module):
333
+ """
334
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
335
+ """
336
+
337
+ def __init__(self, n_heads):
338
+ super().__init__()
339
+ self.n_heads = n_heads
340
+
341
+ def forward(self, qkv):
342
+ """
343
+ Apply QKV attention.
344
+
345
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
346
+ :return: an [N x (H * C) x T] tensor after attention.
347
+ """
348
+ bs, width, length = qkv.shape
349
+ assert width % (3 * self.n_heads) == 0
350
+ ch = width // (3 * self.n_heads)
351
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
352
+ scale = 1 / math.sqrt(math.sqrt(ch))
353
+ weight = th.einsum(
354
+ "bct,bcs->bts", q * scale, k * scale
355
+ ) # More stable with f16 than dividing afterwards
356
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
357
+ a = th.einsum("bts,bcs->bct", weight, v)
358
+ return a.reshape(bs, -1, length)
359
+
360
+ @staticmethod
361
+ def count_flops(model, _x, y):
362
+ return count_flops_attn(model, _x, y)
363
+
364
+
365
+ class QKVAttention(nn.Module):
366
+ """
367
+ A module which performs QKV attention and splits in a different order.
368
+ """
369
+
370
+ def __init__(self, n_heads):
371
+ super().__init__()
372
+ self.n_heads = n_heads
373
+
374
+ def forward(self, qkv):
375
+ """
376
+ Apply QKV attention.
377
+
378
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
379
+ :return: an [N x (H * C) x T] tensor after attention.
380
+ """
381
+ bs, width, length = qkv.shape
382
+ assert width % (3 * self.n_heads) == 0
383
+ ch = width // (3 * self.n_heads)
384
+ q, k, v = qkv.chunk(3, dim=1)
385
+ scale = 1 / math.sqrt(math.sqrt(ch))
386
+ weight = th.einsum(
387
+ "bct,bcs->bts",
388
+ (q * scale).view(bs * self.n_heads, ch, length),
389
+ (k * scale).view(bs * self.n_heads, ch, length),
390
+ ) # More stable with f16 than dividing afterwards
391
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
392
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
393
+ return a.reshape(bs, -1, length)
394
+
395
+ @staticmethod
396
+ def count_flops(model, _x, y):
397
+ return count_flops_attn(model, _x, y)
398
+
399
+
400
+ class UNetModel(nn.Module):
401
+ """
402
+ The full UNet model with attention and timestep embedding.
403
+
404
+ :param in_channels: channels in the input Tensor.
405
+ :param model_channels: base channel count for the model.
406
+ :param out_channels: channels in the output Tensor.
407
+ :param num_res_blocks: number of residual blocks per downsample.
408
+ :param attention_resolutions: a collection of downsample rates at which
409
+ attention will take place. May be a set, list, or tuple.
410
+ For example, if this contains 4, then at 4x downsampling, attention
411
+ will be used.
412
+ :param dropout: the dropout probability.
413
+ :param channel_mult: channel multiplier for each level of the UNet.
414
+ :param conv_resample: if True, use learned convolutions for upsampling and
415
+ downsampling.
416
+ :param dims: determines if the signal is 1D, 2D, or 3D.
417
+ :param num_classes: if specified (as an int), then this model will be
418
+ class-conditional with `num_classes` classes.
419
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
420
+ :param num_heads: the number of attention heads in each attention layer.
421
+ :param num_heads_channels: if specified, ignore num_heads and instead use
422
+ a fixed channel width per attention head.
423
+ :param num_heads_upsample: works with num_heads to set a different number
424
+ of heads for upsampling. Deprecated.
425
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
426
+ :param resblock_updown: use residual blocks for up/downsampling.
427
+     :param use_new_attention_order: use a different attention pattern for potentially
+         increased efficiency.
+     """
+
+     def __init__(
+         self,
+         image_size,
+         in_channels,
+         model_channels,
+         out_channels,
+         num_res_blocks,
+         attention_resolutions,
+         dropout=0,
+         channel_mult=(1, 2, 4, 8),
+         conv_resample=True,
+         dims=2,
+         num_classes=None,
+         use_checkpoint=False,
+         use_fp16=False,
+         num_heads=1,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         use_new_attention_order=False,
+     ):
+         super().__init__()
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         self.image_size = image_size
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.attention_resolutions = attention_resolutions
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.dims = dims
+         self.num_classes = num_classes
+         self.use_checkpoint = use_checkpoint
+         self.use_fp16 = use_fp16
+         self.dtype = th.float16 if use_fp16 else th.float32
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+         self.use_scale_shift_norm = use_scale_shift_norm
+         self.resblock_updown = resblock_updown
+
+         time_embed_dim = model_channels * 4
+         self.time_embed_dim = time_embed_dim
+         self.time_embed = nn.Sequential(
+             linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             linear(time_embed_dim, time_embed_dim),
+         )
+
+         if self.num_classes is not None:
+             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+         ch = input_ch = int(channel_mult[0] * model_channels)
+         self.input_blocks = nn.ModuleList(
+             [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+         )
+         self._feature_size = ch
+         input_block_chans = [ch]
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for _ in range(num_res_blocks):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(mult * model_channels),
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = int(mult * model_channels)
+                 if ds in attention_resolutions:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads,
+                             num_head_channels=num_head_channels,
+                             use_new_attention_order=use_new_attention_order,
+                         )
+                     )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 ds *= 2
+                 self._feature_size += ch
+
+         self.middle_block = TimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+             AttentionBlock(
+                 ch,
+                 use_checkpoint=use_checkpoint,
+                 num_heads=num_heads,
+                 num_head_channels=num_head_channels,
+                 use_new_attention_order=use_new_attention_order,
+             ),
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+         )
+         self._feature_size += ch
+
+         self.output_blocks = nn.ModuleList([])
+         for level, mult in list(enumerate(channel_mult))[::-1]:
+             for i in range(num_res_blocks + 1):
+                 ich = input_block_chans.pop()
+                 layers = [
+                     ResBlock(
+                         ch + ich,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(model_channels * mult),
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = int(model_channels * mult)
+                 if ds in attention_resolutions:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads_upsample,
+                             num_head_channels=num_head_channels,
+                             use_new_attention_order=use_new_attention_order,
+                         )
+                     )
+                 if level and i == num_res_blocks:
+                     out_ch = ch
+                     layers.append(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             up=True,
+                         )
+                         if resblock_updown
+                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                     )
+                     ds //= 2
+                 self.output_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+
+         self.out = nn.Sequential(
+             normalization(ch),
+             nn.SiLU(),
+             zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+         )
+
+     def convert_to_fp16(self):
+         """
+         Convert the torso of the model to float16.
+         """
+         self.input_blocks.apply(convert_module_to_f16)
+         self.middle_block.apply(convert_module_to_f16)
+         self.output_blocks.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self):
+         """
+         Convert the torso of the model to float32.
+         """
+         self.input_blocks.apply(convert_module_to_f32)
+         self.middle_block.apply(convert_module_to_f32)
+         self.output_blocks.apply(convert_module_to_f32)
+
+     def forward(self, x, timesteps, y=None):
+         """
+         Apply the model to an input batch.
+
+         :param x: an [N x C x ...] Tensor of inputs.
+         :param timesteps: a 1-D batch of timesteps.
+         :param y: an [N] Tensor of labels, if class-conditional.
+         :return: an [N x C x ...] Tensor of outputs.
+         """
+         assert (y is not None) == (
+             self.num_classes is not None
+         ), "must specify y if and only if the model is class-conditional"
+
+         hs = []
+         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+         if self.num_classes is not None:
+             assert y.shape == (x.shape[0],)
+             emb = emb + self.label_emb(y)
+
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             h = module(h, emb)
+             hs.append(h)
+         h = self.middle_block(h, emb)
+         for module in self.output_blocks:
+             h = th.cat([h, hs.pop()], dim=1)
+             h = module(h, emb)
+         h = h.type(x.dtype)
+         return self.out(h)
+
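A minimal sketch of driving this UNetModel directly, assuming the flags from scripts/inference.sh below. The `channel_mult` value and the translation of `--attention_resolutions 32,16,8` into downsampling factors are assumptions; in this repo, script_util.py normally performs that conversion.

```python
import torch as th

# Hypothetical instantiation; learn_sigma=True doubles the output channels (3 + 3).
model = UNetModel(
    image_size=128,
    in_channels=3,
    model_channels=128,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(4, 8, 16),  # downsampling factors for 32/16/8-px feature maps
    channel_mult=(1, 1, 2, 3, 4),      # assumed default for 128-px inputs
    num_head_channels=64,
    resblock_updown=True,
)
x = th.randn(2, 3, 128, 128)   # a noisy batch
t = th.randint(0, 1000, (2,))  # one diffusion timestep per sample
out = model(x, t)              # [2, 6, 128, 128]: predicted eps plus variance channels
```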
+ class SuperResModel(UNetModel):
+     """
+     A UNetModel that performs super-resolution.
+
+     Expects an extra kwarg `low_res` to condition on a low-resolution image.
+     """
+
+     def __init__(self, image_size, in_channels, *args, **kwargs):
+         super().__init__(image_size, in_channels * 2, *args, **kwargs)
+
+     def forward(self, x, timesteps, low_res=None, **kwargs):
+         _, _, new_height, new_width = x.shape
+         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+         x = th.cat([x, upsampled], dim=1)
+         return super().forward(x, timesteps, **kwargs)
+
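A short usage sketch (shapes illustrative, not from the upload): forward() upsamples `low_res` to the input resolution and concatenates it along the channel axis, which is why the constructor doubles `in_channels`.

```python
import torch as th

sr = SuperResModel(
    image_size=128, in_channels=3, model_channels=128, out_channels=3,
    num_res_blocks=2, attention_resolutions=(8, 16),
)
x_t = th.randn(1, 3, 128, 128)  # noisy high-res sample
low = th.randn(1, 3, 32, 32)    # low-res conditioning image
t = th.tensor([500])
out = sr(x_t, t, low_res=low)   # [1, 3, 128, 128]
```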
+ class EncoderUNetModel(nn.Module):
+     """
+     The half UNet model with attention and timestep embedding.
+
+     For usage, see UNetModel.
+     """
+
+     def __init__(
+         self,
+         image_size,
+         in_channels,
+         model_channels,
+         out_channels,
+         num_res_blocks,
+         attention_resolutions,
+         dropout=0,
+         channel_mult=(1, 2, 4, 8),
+         conv_resample=True,
+         dims=2,
+         use_checkpoint=False,
+         use_fp16=False,
+         num_heads=1,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         use_new_attention_order=False,
+         pool="adaptive",
+     ):
+         super().__init__()
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.attention_resolutions = attention_resolutions
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.use_checkpoint = use_checkpoint
+         self.dtype = th.float16 if use_fp16 else th.float32
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             linear(time_embed_dim, time_embed_dim),
+         )
+
+         ch = int(channel_mult[0] * model_channels)
+         self.input_blocks = nn.ModuleList(
+             [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+         )
+         self._feature_size = ch
+         input_block_chans = [ch]
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for _ in range(num_res_blocks):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(mult * model_channels),
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = int(mult * model_channels)
+                 if ds in attention_resolutions:
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads,
+                             num_head_channels=num_head_channels,
+                             use_new_attention_order=use_new_attention_order,
+                         )
+                     )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 ds *= 2
+                 self._feature_size += ch
+
+         self.middle_block = TimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+             AttentionBlock(
+                 ch,
+                 use_checkpoint=use_checkpoint,
+                 num_heads=num_heads,
+                 num_head_channels=num_head_channels,
+                 use_new_attention_order=use_new_attention_order,
+             ),
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+         )
+         self._feature_size += ch
+         self.pool = pool
+         if pool == "adaptive":
+             self.out = nn.Sequential(
+                 normalization(ch),
+                 nn.SiLU(),
+                 nn.AdaptiveAvgPool2d((1, 1)),
+                 zero_module(conv_nd(dims, ch, out_channels, 1)),
+                 nn.Flatten(),
+             )
+         elif pool == "attention":
+             assert num_head_channels != -1
+             self.out = nn.Sequential(
+                 normalization(ch),
+                 nn.SiLU(),
+                 AttentionPool2d(
+                     (image_size // ds), ch, num_head_channels, out_channels
+                 ),
+             )
+         elif pool == "spatial":
+             self.out = nn.Sequential(
+                 nn.Linear(self._feature_size, 2048),
+                 nn.ReLU(),
+                 nn.Linear(2048, self.out_channels),
+             )
+         elif pool == "spatial_v2":
+             self.out = nn.Sequential(
+                 nn.Linear(self._feature_size, 2048),
+                 normalization(2048),
+                 nn.SiLU(),
+                 nn.Linear(2048, self.out_channels),
+             )
+         else:
+             raise NotImplementedError(f"Unexpected {pool} pooling")
+
+     def convert_to_fp16(self):
+         """
+         Convert the torso of the model to float16.
+         """
+         self.input_blocks.apply(convert_module_to_f16)
+         self.middle_block.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self):
+         """
+         Convert the torso of the model to float32.
+         """
+         self.input_blocks.apply(convert_module_to_f32)
+         self.middle_block.apply(convert_module_to_f32)
+
+     def forward(self, x, timesteps):
+         """
+         Apply the model to an input batch.
+
+         :param x: an [N x C x ...] Tensor of inputs.
+         :param timesteps: a 1-D batch of timesteps.
+         :return: an [N x K] Tensor of outputs.
+         """
+         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+         results = []
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             h = module(h, emb)
+             if self.pool.startswith("spatial"):
+                 results.append(h.type(x.dtype).mean(dim=(2, 3)))
+         h = self.middle_block(h, emb)
+         if self.pool.startswith("spatial"):
+             results.append(h.type(x.dtype).mean(dim=(2, 3)))
+             h = th.cat(results, axis=-1)
+             return self.out(h)
+         else:
+             h = h.type(x.dtype)
+             return self.out(h)
+
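A sketch of the encoder half-UNet used as a timestep-conditioned classifier (all hyperparameters here are illustrative). With the default "adaptive" pooling, the head global-average-pools the final feature map and projects it to `out_channels` logits:

```python
import torch as th

enc = EncoderUNetModel(
    image_size=128, in_channels=3, model_channels=128, out_channels=1000,
    num_res_blocks=2, attention_resolutions=(8, 16), pool="adaptive",
)
x = th.randn(4, 3, 128, 128)
t = th.randint(0, 1000, (4,))
logits = enc(x, t)  # [4, 1000]
```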
+ # ________________________________ tfg model ________________________________ #
+ class TFGModel(UNetModel):
+     """
+     Talking-face generation (TFG) UNet: a UNetModel conditioned on a masked
+     input frame, optional reference frames, and (optionally) audio.
+     """
+
+     def __init__(
+         self,
+         image_size,
+         in_channels,
+         model_channels,
+         out_channels,
+         *args,
+         use_ref=False,
+         nframes=1,
+         nrefer=0,
+         use_audio=False,
+         audio_encoder_kwargs=None,
+         audio_as_style=False,  # condition on audio as a style vector instead of concatenating in the middle
+         audio_as_style_encoder_mlp=False,  # use an MLP instead of the conv audio encoder
+         **kwargs
+     ):
+         # input channels: noisy frame + masked conditioning frame (+ nrefer reference frames)
+         if use_ref:
+             super().__init__(image_size, in_channels * (1 + 1 + nrefer), model_channels, out_channels, *args, **kwargs)
+         else:
+             super().__init__(image_size, in_channels * (1 + 1), model_channels, out_channels, *args, **kwargs)
+
+         self.use_ref = use_ref
+         self.nframes = nframes
+         self.nrefer = nrefer
+         self.use_audio = use_audio
+
+         if self.use_audio:
+             self.audio_encoder_kwargs = audio_encoder_kwargs if audio_encoder_kwargs is not None else {}
+
+             self.audio_as_style = audio_as_style
+             self.audio_as_style_encoder_mlp = audio_as_style_encoder_mlp
+
+             self.audio_encoder = TFGAudioEncoder(
+                 nframes=self.nframes,
+                 dropout=self.dropout,
+                 conv_resample=self.conv_resample,
+                 dims=self.dims,
+                 use_checkpoint=self.use_checkpoint,
+                 use_fp16=self.use_fp16,
+                 use_scale_shift_norm=self.use_scale_shift_norm,
+                 resblock_updown=self.resblock_updown,
+                 **self.audio_encoder_kwargs
+             )
+
+             if not self.audio_as_style:
+                 # concatenate the audio encoding to the video encoding in the middle block
+                 old_middle_block_head = self.middle_block[0]
+                 mid_img_ch = old_middle_block_head.channels
+                 mid_aud_ch = self.audio_encoder.out_channels
+                 self.middle_block[0] = ResBlock(
+                     mid_img_ch + mid_aud_ch,  # combined image and audio channels
+                     old_middle_block_head.emb_channels,
+                     old_middle_block_head.dropout,
+                     out_channels=old_middle_block_head.out_channels,
+                     dims=old_middle_block_head.dims,
+                     use_checkpoint=old_middle_block_head.use_checkpoint,
+                     use_scale_shift_norm=old_middle_block_head.use_scale_shift_norm,
+                 )
+             else:  # audio as style
+                 if self.audio_as_style_encoder_mlp:
+                     old_conv_encoder = self.audio_encoder
+                     audio_dim = old_conv_encoder.audio_dim
+                     audio_frames_per_video = old_conv_encoder.audio_frames_per_video
+                     self.audio_encoder = nn.Sequential(
+                         nn.Flatten(),
+                         # the flattened mel window has audio_dim * audio_frames_per_video features
+                         linear(audio_dim * audio_frames_per_video, self.time_embed_dim),
+                         normalization(self.time_embed_dim),
+                         nn.SiLU(),
+                         linear(self.time_embed_dim, self.time_embed_dim),
+                     )
+                 else:  # use the conv encoder plus an MLP head to get the style vector
+                     # similar to the classifier head defined above
+                     self.audio_encoder_to_style = nn.Sequential(
+                         normalization(self.audio_encoder.out_channels),
+                         nn.SiLU(),
+                         nn.AdaptiveAvgPool2d((1, 1)),
+                         zero_module(  # zero init, so the style path starts as a no-op
+                             conv_nd(self.dims, self.audio_encoder.out_channels, self.time_embed_dim, 1)
+                         ),
+                         nn.Flatten(),
+                     )
+
+     def convert_to_fp16(self):
+         """
+         Convert the torso of the model to float16.
+         """
+         self.input_blocks.apply(convert_module_to_f16)
+         self.middle_block.apply(convert_module_to_f16)
+         self.output_blocks.apply(convert_module_to_f16)
+         if self.use_audio and not self.audio_as_style_encoder_mlp:
+             # the MLP style encoder stays in fp32; only the conv encoder is converted
+             self.audio_encoder.apply(convert_module_to_f16)
+             if self.audio_as_style:
+                 self.audio_encoder_to_style.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self):
+         """
+         Convert the torso of the model to float32.
+         """
+         self.input_blocks.apply(convert_module_to_f32)
+         self.middle_block.apply(convert_module_to_f32)
+         self.output_blocks.apply(convert_module_to_f32)
+         if self.use_audio:
+             self.audio_encoder.apply(convert_module_to_f32)
+             if self.audio_as_style and not self.audio_as_style_encoder_mlp:
+                 self.audio_encoder_to_style.apply(convert_module_to_f32)
+
+     def forward(self, x, timesteps, cond_img=None, mask=None, ref_img=None, indiv_mels=None, **kwargs):
+         # preprocessing: keep x where mask == 1 and take the conditioning frame elsewhere
+         x = x * mask + (1.0 - mask) * cond_img
+         x = th.cat([x, cond_img], dim=1)
+         if self.use_ref:
+             x = th.cat([x, ref_img], dim=1)
+
+         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+         if self.use_audio:
+             if self.audio_as_style:
+                 # audio encoder
+                 if self.audio_as_style_encoder_mlp:  # the MLP path runs in fp32
+                     a = self.audio_encoder(indiv_mels)
+                     a = a.type(self.dtype)
+                 else:  # the conv path runs in the model dtype (fp16 if enabled)
+                     a = indiv_mels.type(self.dtype)
+                     a = self.audio_encoder(a)
+                     a = self.audio_encoder_to_style(a)
+                 # combine: add the audio style vector to the timestep embedding
+                 emb = emb + a
+                 # video encoder
+                 hs = []
+                 h = x.type(self.dtype)
+                 for module in self.input_blocks:
+                     h = module(h, emb)
+                     hs.append(h)
+             else:  # concatenate the audio features in the middle
+                 # audio encoder
+                 a = indiv_mels.type(self.dtype)
+                 a = self.audio_encoder(a)
+                 # video encoder
+                 hs = []
+                 h = x.type(self.dtype)
+                 for module in self.input_blocks:
+                     h = module(h, emb)
+                     hs.append(h)
+                 # combine
+                 h = th.cat([h, a], dim=1)
+         else:
+             # no audio conditioning: plain video encoder
+             hs = []
+             h = x.type(self.dtype)
+             for module in self.input_blocks:
+                 h = module(h, emb)
+                 hs.append(h)
+
+         # middle block
+         h = self.middle_block(h, emb)
+
+         # decoder
+         for module in self.output_blocks:
+             h = th.cat([h, hs.pop()], dim=1)
+             h = module(h, emb)
+         h = h.type(x.dtype)
+         return self.out(h)
+
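A forward-pass sketch for TFGModel with illustrative shapes. The batch size, mask layout, and mel window size are assumptions based on the defaults above and the flags in scripts/inference.sh; the real tensors come from the data pipeline:

```python
import torch as th

tfg = TFGModel(
    image_size=128, in_channels=3, model_channels=128, out_channels=6,
    num_res_blocks=2, attention_resolutions=(4, 8, 16),
    channel_mult=(1, 1, 2, 3, 4), num_head_channels=64, resblock_updown=True,
    use_ref=True, nrefer=1, use_audio=True, audio_as_style=True,
)
B = 2
x = th.randn(B, 3, 128, 128)     # noisy target frames
cond = th.randn(B, 3, 128, 128)  # conditioning frames
ref = th.randn(B, 3, 128, 128)   # reference frames (nrefer=1)
mask = th.zeros(B, 1, 128, 128)
mask[:, :, 64:, :] = 1           # assumed: 1 where the model must generate (lower face half)
mels = th.randn(B, 80, 16)       # per-frame mel window [B*F, audio_dim, frames]
t = th.randint(0, 1000, (B,))
out = tfg(x, t, cond_img=cond, mask=mask, ref_img=ref, indiv_mels=mels)  # [B, 6, 128, 128]
```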
+ class TFGAudioEncoder(nn.Module):
+     """
+     Audio encoder.
+
+     With audio_dim = 80, audio_frames_per_video = 16, init_spatial_dim = 64,
+     model_channels = 32 and channel_mult = (1, 2, 3, 4), the activations have
+     the following shapes:
+
+         init:                        [BF, 80, 16]
+         after in_block:              [BF, 64, 16]
+         adding new dim:              [BF, 1, 64, 16]
+         first encoder block:         [BF, 32, 64, 16]
+         level 0:
+             0_0: [BF, 32, 64, 16]
+             0_1: [BF, 32, 64, 16]
+             0_2: [BF, 32, 32, 16]
+         level 1:
+             1_0: [BF, 64, 32, 16]
+             1_1: [BF, 64, 32, 16]
+             1_2: [BF, 64, 16, 16]
+         level 2:
+             2_0: [BF, 96, 16, 16]
+             2_1: [BF, 96, 16, 16]
+             2_2: [BF, 96, 8, 8]
+         level 3:
+             3_0: [BF, 128, 8, 8]
+             3_1: [BF, 128, 8, 8]
+         middle block: [BF, 128, 8, 8]
+         out:          [BF, 128, 8, 8]
+     """
+     def __init__(
+         self,
+         audio_dim=80,
+         audio_frames_per_video=16,
+         nframes=1,
+         init_spatial_dim=64,
+         model_channels=32,
+         out_channels=-1,
+         num_res_blocks=2,
+         dropout=0,
+         channel_mult=(1, 2, 3, 4),  # (1, 1, 2, 4, 8)
+         conv_resample=True,
+         dims=2,
+         use_checkpoint=False,
+         use_fp16=False,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         **kwargs
+     ):
+         super().__init__()
+         self.audio_dim = audio_dim
+         self.audio_frames_per_video = audio_frames_per_video
+         self.nframes = nframes
+         self.model_channels = model_channels
+         self.out_channels = out_channels if out_channels > 0 else model_channels * channel_mult[-1]
+         self.num_res_blocks = num_res_blocks
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.use_checkpoint = use_checkpoint
+         self.dtype = th.float16 if use_fp16 else th.float32
+
+         time_embed_dim = model_channels
+         self.time_embed = nn.Sequential(
+             linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             linear(time_embed_dim, time_embed_dim),
+         )
+
+         ch = int(channel_mult[0] * model_channels)
+         # init_spatial_dim = 4 * (2 ** (len(channel_mult) - 1))
+
+         # convert the spectral dim 80 -> 64 with a Conv1d: [N*F, 80, 16] -> [N*F, 64, 16]
+         _conv_dim, _in_channels, _out_channels = 1, self.audio_dim, init_spatial_dim
+         self.input_block = TimestepEmbedSequential(
+             conv_nd(_conv_dim, _in_channels, _out_channels, 3, padding=1),
+             normalization(_out_channels),
+             nn.SiLU(),
+         )
+
+         # [N*F, 64, 16] is reshaped to [N*F, 1, 64, 16] manually in forward()
+
+         # [N*F, 1, 64, 16] -> [N*F, model_channels * channel_mult[0], 64, 16];
+         # a ResBlock can't be used here because GroupNorm needs 32-channel groups
+         self.encoder_blocks = nn.ModuleList(
+             [
+                 TimestepEmbedSequential(
+                     conv_nd(dims, 1, ch, 3, padding=1)
+                 )
+             ]
+         )
+
+         self._feature_size = ch
+         input_block_chans = [ch]
+
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for _ in range(num_res_blocks):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=int(mult * model_channels),
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = int(mult * model_channels)
+                 self.encoder_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.encoder_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                             down_stride=(2, 1) if (init_spatial_dim // ds) > self.audio_frames_per_video else (2, 2),
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch,
+                             down_stride=(2, 1) if (init_spatial_dim // ds) > self.audio_frames_per_video else (2, 2),
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 ds *= 2
+                 self._feature_size += ch
+
+         self.middle_block = TimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 out_channels=self.out_channels,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+         )
+         self._feature_size += ch
+
+         # self.out = Upsample(self.out_channels, False, dims)
+         self.out = nn.Identity()
+
+     def convert_to_fp16(self):
+         """
+         Convert the torso of the model to float16.
+         """
+         self.input_block.apply(convert_module_to_f16)
+         self.encoder_blocks.apply(convert_module_to_f16)
+         self.middle_block.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self):
+         """
+         Convert the torso of the model to float32.
+         """
+         self.input_block.apply(convert_module_to_f32)
+         self.encoder_blocks.apply(convert_module_to_f32)
+         self.middle_block.apply(convert_module_to_f32)
+
+     def forward(self, x):
+         h = x.type(self.dtype)
+         BF_in, H_in, W_in = h.shape
+
+         # fixed time embedding (so the shared timestep-aware modules can be reused)
+         t = th.zeros(BF_in, dtype=th.long, device=x.device)
+         emb = self.time_embed(timestep_embedding(t, self.model_channels))
+
+         # 80 -> 64 with the Conv1d input block
+         h = self.input_block(h, emb)
+         _, H, W = h.shape
+         h = h.reshape(BF_in, 1, H, W)  # [B*F, 64, 16] -> [B*F, 1, 64, 16]
+         # run the encoder blocks
+         for module in self.encoder_blocks:
+             h = module(h, emb)
+         h = self.middle_block(h, emb)
+         h = h.type(x.dtype)
+         return self.out(h)  # e.g. [B*F, 128, 8, 8] with the defaults above
+ # ____________________________________________________________________________ #
+
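A quick shape check for the walkthrough in the TFGAudioEncoder docstring; a sketch, assuming the defaults shown above and that the repo's custom Downsample/ResBlock accept the `down_stride` argument:

```python
import torch as th

aud = TFGAudioEncoder()          # audio_dim=80, audio_frames_per_video=16, ...
mels = th.randn(10, 80, 16)      # [B*F, audio_dim, audio_frames_per_video]
feat = aud(mels)
print(feat.shape)                # expected per the docstring: [10, 128, 8, 8]
```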
guided-diffusion/setup.py ADDED
@@ -0,0 +1,7 @@
+ from setuptools import setup
+
+ setup(
+     name="guided-diffusion",
+     py_modules=["guided_diffusion"],
+     install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
+ )
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ librosa==0.9.2
+ opencv-python==4.5.5.64
+ opencv-contrib-python==4.6.0.66
+ tensorboard==2.11.0
+ tqdm==4.64.1
+ mpi4py-mpich==3.1.2
+ av==9.2.0
+ torch --extra-index-url https://download.pytorch.org/whl/cu113
+ torchvision --extra-index-url https://download.pytorch.org/whl/cu113
+ torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
+ -e ./guided-diffusion
scripts/inference.sh ADDED
@@ -0,0 +1,40 @@
+ #!/bin/bash
+
+ # set paths and arguments
+ real_video_root='dataset/VoxCeleb2/vox2_test_mp4/mp4/'
+ model_path="checkpoints/checkpoint.pt"
+ sample_path="output_dir"
+ sample_mode="cross" # or "reconstruction"
+ NUM_GPUS=2
+
+ # cross vs reconstruction
+ filelist_recon='dataset/filelists/voxceleb2_test_n_5000_reconstruction_5k.txt'
+ filelist_cross='dataset/filelists/voxceleb2_test_n_5000_seed_797_cross_5K.txt'
+ if [ "$sample_mode" = "reconstruction" ]; then
+     sample_input_flags="--sampling_input_type=first_frame --sampling_ref_type=first_frame"
+     filelist=$filelist_recon
+ elif [ "$sample_mode" = "cross" ]; then
+     sample_input_flags="--sampling_input_type=gt --sampling_ref_type=gt"
+     filelist=$filelist_cross
+ else
+     echo "Error: sample_mode can only be \"cross\" or \"reconstruction\""
+     exit 1
+ fi
+ test_video_dir=$real_video_root
+ mkdir -p "$sample_path"
+ MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
+ DIFFUSION_FLAGS="--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
+ SAMPLE_FLAGS="--sampling_seed=7 $sample_input_flags --timestep_respacing ddim25 --use_ddim True --model_path=$model_path --sample_path=$sample_path"
+ DATA_FLAGS="--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
+ TFG_FLAGS="--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
+ GEN_FLAGS="--generate_from_filelist 1 --test_video_dir=$test_video_dir --filelist=$filelist --save_orig=False --face_det_batch_size 64 --pads 0,0,0,0"
+
+ if [ "$NUM_GPUS" -gt 1 ]; then
+     mpiexec -n $NUM_GPUS python generate_dist.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
+ else
+     python generate.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
+ fi
scripts/inference_single_video.sh ADDED
@@ -0,0 +1,35 @@
+ #!/bin/bash
+
+ # set paths and arguments
+ sample_mode="cross" # or "reconstruction"
+ NUM_GPUS=1
+ generate_from_filelist=0
+ video_path="path/to/video.mp4"
+ audio_path="path/to/audio.mp4"
+ out_path="path/to/output.mp4"
+ model_path="path/to/model.pt"
+
+ # cross vs reconstruction
+ if [ "$sample_mode" = "reconstruction" ]; then
+     sample_input_flags="--sampling_input_type=first_frame --sampling_ref_type=first_frame"
+ elif [ "$sample_mode" = "cross" ]; then
+     sample_input_flags="--sampling_input_type=gt --sampling_ref_type=gt"
+ else
+     echo "Error: sample_mode can only be \"cross\" or \"reconstruction\""
+     exit 1
+ fi
+
+ MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
+ DIFFUSION_FLAGS="--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
+ SAMPLE_FLAGS="--sampling_seed=7 $sample_input_flags --timestep_respacing ddim25 --use_ddim True --model_path=$model_path"
+ DATA_FLAGS="--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
+ TFG_FLAGS="--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
+ GEN_FLAGS="--generate_from_filelist $generate_from_filelist --video_path=$video_path --audio_path=$audio_path --out_path=$out_path --save_orig=False --face_det_batch_size 64 --pads 0,0,0,0 --is_voxceleb2=False"
+
+ if [ "$NUM_GPUS" -gt 1 ]; then
+     mpiexec -n $NUM_GPUS python generate_dist.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
+ else
+     python generate.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
+ fi