Upload 49 files
- .gitattributes +2 -0
- app.py +66 -0
- assets/id02548.0pAkJZmlFqc.00001_id04570.0YMGn6BI9rg.00001.gif +3 -0
- assets/website_gif_v2.gif +3 -0
- audio/__init__.py +0 -0
- audio/audio.py +136 -0
- audio/hparams.py +66 -0
- checkpoints/checkpoint.pt +3 -0
- dataset/LRW/lrw_fullpath.py +25 -0
- dataset/filelists/lrw_cross.txt +0 -0
- dataset/filelists/lrw_cross_relative_path.txt +0 -0
- dataset/filelists/lrw_reconstruction.txt +0 -0
- dataset/filelists/lrw_reconstruction_relative_path.txt +0 -0
- dataset/filelists/voxceleb2_test_n_5000_reconstruction_5k.txt +0 -0
- dataset/filelists/voxceleb2_test_n_5000_seed_797_cross_5K.txt +0 -0
- dataset/filelists/voxceleb2_test_n_500_reconstruction.txt +500 -0
- dataset/filelists/voxceleb2_test_n_500_seed_797_cross.txt +500 -0
- face_detection/README.md +1 -0
- face_detection/__init__.py +7 -0
- face_detection/api.py +98 -0
- face_detection/detection/__init__.py +1 -0
- face_detection/detection/core.py +130 -0
- face_detection/detection/sfd/__init__.py +1 -0
- face_detection/detection/sfd/bbox.py +129 -0
- face_detection/detection/sfd/detect.py +112 -0
- face_detection/detection/sfd/net_s3fd.py +129 -0
- face_detection/detection/sfd/sfd_detector.py +59 -0
- face_detection/models.py +261 -0
- face_detection/utils.py +313 -0
- generate.py +398 -0
- generate_dist.py +428 -0
- guided-diffusion/LICENSE +21 -0
- guided-diffusion/guided_diffusion/__init__.py +3 -0
- guided-diffusion/guided_diffusion/dist_util.py +94 -0
- guided-diffusion/guided_diffusion/fp16_util.py +237 -0
- guided-diffusion/guided_diffusion/gaussian_diffusion.py +843 -0
- guided-diffusion/guided_diffusion/image_datasets.py +167 -0
- guided-diffusion/guided_diffusion/logger.py +491 -0
- guided-diffusion/guided_diffusion/losses.py +77 -0
- guided-diffusion/guided_diffusion/lpips.py +20 -0
- guided-diffusion/guided_diffusion/nn.py +170 -0
- guided-diffusion/guided_diffusion/resample.py +154 -0
- guided-diffusion/guided_diffusion/respace.py +128 -0
- guided-diffusion/guided_diffusion/script_util.py +614 -0
- guided-diffusion/guided_diffusion/tfg_data_util.py +75 -0
- guided-diffusion/guided_diffusion/unet.py +1275 -0
- guided-diffusion/setup.py +7 -0
- requirements.txt +11 -0
- scripts/inference.sh +40 -0
- scripts/inference_single_video.sh +35 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/id02548.0pAkJZmlFqc.00001_id04570.0YMGn6BI9rg.00001.gif filter=lfs diff=lfs merge=lfs -text
+assets/website_gif_v2.gif filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
import gradio as gr
import subprocess
import os

def process_video(audio_file, video_file):
    # Gradio file inputs arrive as temporary files already saved to disk;
    # only their paths are needed here.
    audio_path = audio_file.name
    video_path = video_file.name
    out_path = "output_video.mp4"

    # Define command flags
    sample_mode = "cross"  # or "reconstruction"
    generate_from_filelist = 0
    model_path = "checkpoints/checkpoint.pt"
    pads = "0,0,0,0"

    if sample_mode == "reconstruction":
        sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
    elif sample_mode == "cross":
        sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
    else:
        return "Error: sample_mode can only be \"cross\" or \"reconstruction\""

    MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
    DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
    SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
    DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
    TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
    GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"

    # Combine all flags into one command
    command = f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"

    # Execute the command and return the generated video on success.
    # out_path is intentionally left on disk so Gradio can stream it back
    # to the user after this function returns.
    try:
        subprocess.run(command, shell=True, check=True)
        return out_path
    except subprocess.CalledProcessError as e:
        return f"Error processing video: {e}"
    finally:
        # Clean up the uploaded temporary files, whether or not generation succeeded.
        for path in (audio_path, video_path):
            if os.path.exists(path):
                os.remove(path)

# Create a Gradio interface
iface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.inputs.Audio(label="Input Audio", type="file"),
        gr.inputs.Video(label="Input Video", type="file")
    ],
    outputs=gr.outputs.Video(label="Processed Video"),
    title="Audio-Video Processing",
    description="Upload an audio file and a video file to process the video based on the audio input."
)

# Launch the interface
iface.launch()
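One design note on the command execution above: interpolating the Gradio temp-file paths into a single string run with shell=True means any unusual characters in those paths reach the shell unescaped. A minimal sketch of the list-based alternative, under the same placeholder script name (your_model_script.py) and with hypothetical paths standing in for the uploaded files:

import subprocess

# Build the command as an argument list so no shell parsing is involved;
# only a reduced subset of the flags is shown for illustration.
cmd = [
    "python", "your_model_script.py",
    "--model_path=checkpoints/checkpoint.pt",
    "--video_path=/tmp/input_video.mp4",   # hypothetical upload path
    "--audio_path=/tmp/input_audio.wav",   # hypothetical upload path
    "--out_path=output_video.mp4",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit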
assets/id02548.0pAkJZmlFqc.00001_id04570.0YMGn6BI9rg.00001.gif
ADDED
Git LFS Details

assets/website_gif_v2.gif
ADDED
Git LFS Details

audio/__init__.py
ADDED
File without changes
audio/audio.py
ADDED
import librosa
import librosa.filters
import numpy as np
# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from .hparams import hparams as hp

def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    # proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    librosa.output.write_wav(path, wav, sr=sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
        assert hp.frame_shift_ms is not None
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size

def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")

def _stft(y):
    if hp.use_lws:
        return _lws_processor().stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram"""
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding"""
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r
##########################################################
# Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None

def _linear_to_mel(spectogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectogram)

def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
                           -hp.max_abs_value, hp.max_abs_value)
        else:
            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)

    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
    if hp.symmetric_mels:
        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))

def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return (((np.clip(D, -hp.max_abs_value,
                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
                    + hp.min_level_db)
        else:
            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

    if hp.symmetric_mels:
        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
    else:
        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
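For orientation, a minimal usage sketch of this module as imported from the repository root (the input wav path is a hypothetical example):

from audio import audio as audio_lib  # package layout as uploaded: audio/audio.py

wav = audio_lib.load_wav("example.wav", sr=16000)  # hypothetical input file
mel = audio_lib.melspectrogram(wav)
# With the default hparams (num_mels=80, hop_size=200 at 16 kHz),
# mel has shape (80, n_frames), roughly 80 frames per second of audio.
print(mel.shape)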
audio/hparams.py
ADDED
from glob import glob
import os


class HParams:
    def __init__(self, **kwargs):
        self.data = {}

        for key, value in kwargs.items():
            self.data[key] = value

    def __getattr__(self, key):
        if key not in self.data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return self.data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


# Default hyperparameters
hparams = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!!
    use_lws=False,

    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000 Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000 Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000 Hz (corresponding to librispeech) (sox --i <filename>)

    frame_shift_ms=None,  # Can replace the hop_size parameter. (Recommended: 12.5)

    # Mel and Linear spectrograms normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to some predefined range (following below parameters)
    allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # faster and cleaner convergence)
    max_abs_value=4.,
    # Max absolute value of data. If symmetric, data will be [-max, max], else [0, max]. (Must not
    # be too big to avoid gradient explosion, not too small for fast convergence)
    # Contribution by @begeekmyfriend
    # Spectrogram pre-emphasis (lfilter: reduces spectrogram noise and helps model certitude
    # levels. Also allows for better G&L phase reconstruction)
    preemphasize=True,  # whether to apply the filter
    preemphasis=0.97,  # filter coefficient

    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male! If female, 95 should help take off noise. (To
    # test depending on dataset. Pitch info: male ~ [65, 260], female ~ [100, 525])
    fmax=7600,  # To be increased/reduced depending on data.
)
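The timing relationships encoded in these defaults can be checked directly; a small sketch, assuming the defaults above are left untouched:

from audio.hparams import hparams as hp

# A hop_size of 200 samples at 16 kHz gives a 12.5 ms frame shift,
# i.e. 80 mel frames per second of audio.
frame_shift_ms = hp.hop_size / hp.sample_rate * 1000
frames_per_second = hp.sample_rate / hp.hop_size
assert frame_shift_ms == 12.5 and frames_per_second == 80.0

# A win_size of 800 samples is a 50 ms analysis window, matching n_fft.
assert hp.win_size / hp.sample_rate == 0.05 and hp.win_size == hp.n_fft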
checkpoints/checkpoint.pt
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:8c71166482d2b893f2f77450563a1bb31d805f3048c7213b974fd9201e9aa4b3
size 406815527
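Once Git LFS has materialized the real checkpoint (e.g. after a git lfs pull), the download can be verified against the pointer's oid and size; a sketch using only the values from the pointer above:

import hashlib

# Expected values from the Git LFS pointer file above.
EXPECTED_SHA256 = "8c71166482d2b893f2f77450563a1bb31d805f3048c7213b974fd9201e9aa4b3"
EXPECTED_SIZE = 406815527

h = hashlib.sha256()
size = 0
with open("checkpoints/checkpoint.pt", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        h.update(chunk)
        size += len(chunk)

assert size == EXPECTED_SIZE, "size mismatch: file may still be an LFS pointer"
assert h.hexdigest() == EXPECTED_SHA256, "sha256 mismatch"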
dataset/LRW/lrw_fullpath.py
ADDED
'''Converts the LRW video names in filelists to LRW relative paths and dumps them into new filelists'''
import os

# Each line holds "<audio_name> <video_name>"; the LRW word is the token before
# the first "_", and the relative path is "<WORD>/test/<name>". The same rewrite
# is applied to both filelists.
for filelist in ("../filelists/lrw_cross.txt", "../filelists/lrw_reconstruction.txt"):
    filelist_split_path = filelist.replace(".txt", "_relative_path.txt")
    with open(filelist, 'r') as f:
        lines = f.readlines()
    with open(filelist_split_path, 'w') as f:
        for line in lines:
            audio_name, video_name = line.strip().split(' ')
            audio_word = audio_name.split('_')[0]
            video_word = video_name.split('_')[0]
            f.write(os.path.join(audio_word, 'test', audio_name) + ' '
                    + os.path.join(video_word, 'test', video_name) + '\n')
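As an illustration of the rewrite (the file names here are hypothetical; real LRW entries follow the WORD_xxxxx pattern):

import os

# Input line:  "ABOUT_00001.mp4 WORLD_00002.mp4"
# Output line: "ABOUT/test/ABOUT_00001.mp4 WORLD/test/WORLD_00002.mp4"
audio_name, video_name = "ABOUT_00001.mp4 WORLD_00002.mp4".split(' ')
print(os.path.join(audio_name.split('_')[0], 'test', audio_name),
      os.path.join(video_name.split('_')[0], 'test', video_name))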
dataset/filelists/lrw_cross.txt
ADDED
The diff for this file is too large to render. See raw diff.

dataset/filelists/lrw_cross_relative_path.txt
ADDED
The diff for this file is too large to render. See raw diff.

dataset/filelists/lrw_reconstruction.txt
ADDED
The diff for this file is too large to render. See raw diff.

dataset/filelists/lrw_reconstruction_relative_path.txt
ADDED
The diff for this file is too large to render. See raw diff.

dataset/filelists/voxceleb2_test_n_5000_reconstruction_5k.txt
ADDED
The diff for this file is too large to render. See raw diff.

dataset/filelists/voxceleb2_test_n_5000_seed_797_cross_5K.txt
ADDED
The diff for this file is too large to render. See raw diff.
dataset/filelists/voxceleb2_test_n_500_reconstruction.txt
ADDED
id09017/SjCgiXBHfNU/00111 id09017/SjCgiXBHfNU/00111
id05055/HobsYUHmgr0/00138 id05055/HobsYUHmgr0/00138
id01567/M47d5UckOV8/00099 id01567/M47d5UckOV8/00099
id01228/SH3eBLMsRwY/00211 id01228/SH3eBLMsRwY/00211
id07312/m1VY1sC_P_o/00093 id07312/m1VY1sC_P_o/00093
id08696/dJH9aBSs1nE/00370 id08696/dJH9aBSs1nE/00370
id07312/9PCY4xwxgcE/00006 id07312/9PCY4xwxgcE/00006
id07494/1A8ZDo11tzY/00006 id07494/1A8ZDo11tzY/00006
id00061/thHLZ8tDJ-M/00276 id00061/thHLZ8tDJ-M/00276
id03862/YQMdTzyG-P8/00297 id03862/YQMdTzyG-P8/00297
id04570/zsnG6eKzOGE/00406 id04570/zsnG6eKzOGE/00406
id07414/1m1C-CdhmZ0/00016 id07414/1m1C-CdhmZ0/00016
id01509/2uKpHd-euIo/00038 id01509/2uKpHd-euIo/00038
id04276/qHXwXqxL0mk/00401 id04276/qHXwXqxL0mk/00401
id04366/x-VQ6z2QC4w/00252 id04366/x-VQ6z2QC4w/00252
id07620/um6sY627GaE/00475 id07620/um6sY627GaE/00475
id01000/RvjbLfo3XDM/00052 id01000/RvjbLfo3XDM/00052
id07868/fnWDbUI_Zbg/00289 id07868/fnWDbUI_Zbg/00289
id01333/cymDCPEhalE/00351 id01333/cymDCPEhalE/00351
id02317/Mv16h1Bx7HE/00241 id02317/Mv16h1Bx7HE/00241
id02317/Vi4k3cuwfgc/00342 id02317/Vi4k3cuwfgc/00342
id01000/eGeGHhuOJJ0/00077 id01000/eGeGHhuOJJ0/00077
id03980/zaDLb12pDBQ/00130 id03980/zaDLb12pDBQ/00130
id05124/c-Pa7b81coQ/00354 id05124/c-Pa7b81coQ/00354
id04478/nhLuGj0vGb8/00234 id04478/nhLuGj0vGb8/00234
id01541/3su8tn9nwi4/00007 id01541/3su8tn9nwi4/00007
id06484/cmIyVotzXiE/00125 id06484/cmIyVotzXiE/00125
id06209/oxofNHGCj7s/00139 id06209/oxofNHGCj7s/00139
id02181/rxX3t2rzLbg/00146 id02181/rxX3t2rzLbg/00146
id02286/YL75-u9XYUM/00105 id02286/YL75-u9XYUM/00105
id04276/v9mSslwD0Kg/00470 id04276/v9mSslwD0Kg/00470
id07802/6qBSFfV_Mig/00042 id07802/6qBSFfV_Mig/00042
id04295/DtC2X1KG8TE/00057 id04295/DtC2X1KG8TE/00057
id00866/shG_183xFlw/00243 id00866/shG_183xFlw/00243
id03862/2nagLhV_Yvw/00012 id03862/2nagLhV_Yvw/00012
id04119/Yndoy1jgHWs/00042 id04119/Yndoy1jgHWs/00042
id04295/mTCDT_Fv5So/00203 id04295/mTCDT_Fv5So/00203
id08456/o-5hKwhGqac/00354 id08456/o-5hKwhGqac/00354
id07494/tv6GJkx_Wy4/00331 id07494/tv6GJkx_Wy4/00331
id04295/mClPHVzTCLI/00196 id04295/mClPHVzTCLI/00196
id04478/81Tb6kjlNIk/00019 id04478/81Tb6kjlNIk/00019
id00812/NeNXGI8mox8/00158 id00812/NeNXGI8mox8/00158
id04536/sfldoEPrFPI/00438 id04536/sfldoEPrFPI/00438
id07620/aJVbccKJwEw/00327 id07620/aJVbccKJwEw/00327
id02286/4LAIxvdvguc/00001 id02286/4LAIxvdvguc/00001
id07802/BfQUBDw7TiM/00080 id07802/BfQUBDw7TiM/00080
id01066/65k0p7fUBVI/00026 id01066/65k0p7fUBVI/00026
id03862/w97YzyPYm1k/00460 id03862/w97YzyPYm1k/00460
id05816/njFBkJSpUrY/00414 id05816/njFBkJSpUrY/00414
id05124/_Oxp6absIhY/00341 id05124/_Oxp6absIhY/00341
id07663/mUw-kxAavdM/00192 id07663/mUw-kxAavdM/00192
id05999/Ls440srvfR4/00127 id05999/Ls440srvfR4/00127
id02548/Hmlw5PIf64o/00098 id02548/Hmlw5PIf64o/00098
id04276/Pbo_nlcZ0Lc/00190 id04276/Pbo_nlcZ0Lc/00190
id07802/FhKML4dLE60/00115 id07802/FhKML4dLE60/00115
id07621/1L2IUy6gqaM/00012 id07621/1L2IUy6gqaM/00012
id05654/veGIQ7p2ZSk/00130 id05654/veGIQ7p2ZSk/00130
id04094/0z1JYPKGBI8/00007 id04094/0z1JYPKGBI8/00007
id02576/wWUREnOwYo0/00136 id02576/wWUREnOwYo0/00136
id09017/PLNK1g5w4FY/00099 id09017/PLNK1g5w4FY/00099
id06484/USbx34RUkVI/00096 id06484/USbx34RUkVI/00096
id03030/FXbzdRO7t98/00101 id03030/FXbzdRO7t98/00101
id02057/VCXnx-ozS8c/00263 id02057/VCXnx-ozS8c/00263
id02542/JUodrwt9ucI/00033 id02542/JUodrwt9ucI/00033
id03030/DM_Z5D2fkRA/00068 id03030/DM_Z5D2fkRA/00068
id08552/irj3SqKAe0c/00196 id08552/irj3SqKAe0c/00196
id03030/YxBoufnVIMw/00177 id03030/YxBoufnVIMw/00177
id07868/Eaf-dgA59Gs/00061 id07868/Eaf-dgA59Gs/00061
id08456/6xVSlQDr7-w/00031 id08456/6xVSlQDr7-w/00031
id06811/OYFkt_n18hg/00128 id06811/OYFkt_n18hg/00128
id00817/tCnW5E8cMow/00383 id00817/tCnW5E8cMow/00383
id02542/fXQbNcIbcek/00053 id02542/fXQbNcIbcek/00053
id01567/oi2g17EF55s/00377 id01567/oi2g17EF55s/00377
id04366/HsG3OGE22DY/00117 id04366/HsG3OGE22DY/00117
id01509/1y0aWmgYDtw/00006 id01509/1y0aWmgYDtw/00006
id04295/pYfyopS672Y/00213 id04295/pYfyopS672Y/00213
id01989/6JfW9CPAoGY/00006 id01989/6JfW9CPAoGY/00006
id04366/tbcKV-IjZdI/00243 id04366/tbcKV-IjZdI/00243
id01298/UY0fkYSUFrY/00208 id01298/UY0fkYSUFrY/00208
id00817/GAs8WnyFKJM/00120 id00817/GAs8WnyFKJM/00120
id06484/TCp2-XVatIE/00079 id06484/TCp2-XVatIE/00079
id08374/Kf9N5AWprG8/00150 id08374/Kf9N5AWprG8/00150
id01822/QDWgjZqOkvM/00065 id01822/QDWgjZqOkvM/00065
id03030/pTz652Dx_6w/00230 id03030/pTz652Dx_6w/00230
id01460/chrI43l2Nuw/00201 id01460/chrI43l2Nuw/00201
id08374/85f-qB_KJP8/00041 id08374/85f-qB_KJP8/00041
id07961/PoSkUxZ4ags/00172 id07961/PoSkUxZ4ags/00172
id01437/uFPYqotT7tU/00233 id01437/uFPYqotT7tU/00233
id07621/Aan8MoozxII/00095 id07621/Aan8MoozxII/00095
id08456/fWTULQWYVoA/00250 id08456/fWTULQWYVoA/00250
id05055/da7Z8oWhFPY/00351 id05055/da7Z8oWhFPY/00351
id02181/hIvctbfcBx8/00106 id02181/hIvctbfcBx8/00106
id01541/dEmuPb4A7do/00184 id01541/dEmuPb4A7do/00184
id00419/a3Y7pQzcn40/00305 id00419/a3Y7pQzcn40/00305
id07354/dsDxN33xvL0/00262 id07354/dsDxN33xvL0/00262
id04478/MZh3AEgJ9pc/00092 id04478/MZh3AEgJ9pc/00092
id05124/UBUFmICrT-I/00281 id05124/UBUFmICrT-I/00281
id03127/SmGJu-t24hY/00195 id03127/SmGJu-t24hY/00195
id02465/coOp_DnsmEI/00150 id02465/coOp_DnsmEI/00150
id01618/qrOl1aaXBH0/00187 id01618/qrOl1aaXBH0/00187
id03969/WZVnB-m0X9g/00038 id03969/WZVnB-m0X9g/00038
id05202/s0m_4-SCn44/00186 id05202/s0m_4-SCn44/00186
id04657/SYVkfHq-pro/00172 id04657/SYVkfHq-pro/00172
id05176/p2IOP5_s_LM/00093 id05176/p2IOP5_s_LM/00093
id04950/XJS6SLQuCNM/00169 id04950/XJS6SLQuCNM/00169
id02019/anSrwA_9RPE/00152 id02019/anSrwA_9RPE/00152
id04570/Q-faEy1VXxQ/00140 id04570/Q-faEy1VXxQ/00140
id07621/bMvG2mQMZZw/00303 id07621/bMvG2mQMZZw/00303
id06811/vC3yQiWuuOI/00354 id06811/vC3yQiWuuOI/00354
id03839/aWMP8xzq2BE/00292 id03839/aWMP8xzq2BE/00292
id04094/j1ajUkR6_Q4/00326 id04094/j1ajUkR6_Q4/00326
id08149/o0Zdr9Jla7U/00047 id08149/o0Zdr9Jla7U/00047
id00017/hcr4tT9y3xs/00117 id00017/hcr4tT9y3xs/00117
id04950/Cu4jGRmYa4c/00064 id04950/Cu4jGRmYa4c/00064
id01567/TMozlhoPMfI/00223 id01567/TMozlhoPMfI/00223
id08374/QltFme-lqeI/00226 id08374/QltFme-lqeI/00226
id06816/tHor4VN8090/00259 id06816/tHor4VN8090/00259
id07494/xQ0YMPe-9u8/00413 id07494/xQ0YMPe-9u8/00413
id08374/FwR1K1rL3QI/00110 id08374/FwR1K1rL3QI/00110
id06692/Hlahj5abifM/00257 id06692/Hlahj5abifM/00257
id00419/J2LscHjRX7Q/00154 id00419/J2LscHjRX7Q/00154
id02057/CI5-q_qTR5I/00112 id02057/CI5-q_qTR5I/00112
id03862/7IccaH4HXRs/00069 id03862/7IccaH4HXRs/00069
id04656/ar3rKrkbjqI/00257 id04656/ar3rKrkbjqI/00257
id07494/XMEIdqio6ic/00184 id07494/XMEIdqio6ic/00184
id04657/dn4XY5c6mEw/00265 id04657/dn4XY5c6mEw/00265
id04570/SFKt669qIqs/00156 id04570/SFKt669qIqs/00156
id01541/sMDYdAB0MPs/00306 id01541/sMDYdAB0MPs/00306
id08456/F2O-frqyr9c/00101 id08456/F2O-frqyr9c/00101
id08701/_Ysb9mVibbk/00253 id08701/_Ysb9mVibbk/00253
id01333/e4FoER8nqx0/00365 id01333/e4FoER8nqx0/00365
id05124/F0Xpd6OoiDY/00161 id05124/F0Xpd6OoiDY/00161
id01593/AVmZf6Kl1So/00071 id01593/AVmZf6Kl1So/00071
id01567/fOlxxDqdrgc/00299 id01567/fOlxxDqdrgc/00299
id06484/2KVWoftPf2o/00001 id06484/2KVWoftPf2o/00001
id01224/g4jVqkEm1Gs/00274 id01224/g4jVqkEm1Gs/00274
id02445/ZX_6RMrTEP0/00066 id02445/ZX_6RMrTEP0/00066
id04656/5TR-W77XgF4/00032 id04656/5TR-W77XgF4/00032
id01618/F_ExF9xDajc/00060 id01618/F_ExF9xDajc/00060
id08392/gPX4IC53KwI/00355 id08392/gPX4IC53KwI/00355
id00866/pNbDtfW1JW4/00221 id00866/pNbDtfW1JW4/00221
id00812/b3dBqOtzsx0/00276 id00812/b3dBqOtzsx0/00276
id08701/61Al05HARgA/00001 id08701/61Al05HARgA/00001
id07663/FFo4JwVXeUM/00119 id07663/FFo4JwVXeUM/00119
id02057/22zJ50ky7CQ/00013 id02057/22zJ50ky7CQ/00013
id05055/2onVoeSgouI/00028 id05055/2onVoeSgouI/00028
id04006/zvUZFL0NyhM/00260 id04006/zvUZFL0NyhM/00260
id04950/EpOnsaBin0A/00077 id04950/EpOnsaBin0A/00077
id05015/RhBpC9Fc7a4/00154 id05015/RhBpC9Fc7a4/00154
id04656/Z_JFBDW9eZE/00251 id04656/Z_JFBDW9eZE/00251
id01509/2sb83ZBlbJg/00034 id01509/2sb83ZBlbJg/00034
id04030/JbcD0P6KGe0/00036 id04030/JbcD0P6KGe0/00036
id02542/cwgUjse_REU/00040 id02542/cwgUjse_REU/00040
id07620/xFc9X6EXtRM/00478 id07620/xFc9X6EXtRM/00478
id07354/Qrg89rvtZ1k/00217 id07354/Qrg89rvtZ1k/00217
id03839/wSQMEZMxxx4/00461 id03839/wSQMEZMxxx4/00461
id03127/iWeklsXc0H8/00268 id03127/iWeklsXc0H8/00268
id07663/54qlJ2HZ08s/00096 id07663/54qlJ2HZ08s/00096
id07961/Orp8s5aHYc8/00158 id07961/Orp8s5aHYc8/00158
id03347/y_F4aAkN0d8/00417 id03347/y_F4aAkN0d8/00417
id06913/KNDyf594xQg/00056 id06913/KNDyf594xQg/00056
id04366/DIgAc22fq9c/00080 id04366/DIgAc22fq9c/00080
id07396/uJPtbxlXi2c/00187 id07396/uJPtbxlXi2c/00187
id07868/gVspdH-U2XE/00290 id07868/gVspdH-U2XE/00290
id05594/u7qCFBP1nH4/00184 id05594/u7qCFBP1nH4/00184
id01541/mDoT5mpo_2c/00241 id01541/mDoT5mpo_2c/00241
id07354/0y9b8qlM170/00011 id07354/0y9b8qlM170/00011
id01460/DnnphhTlRPE/00075 id01460/DnnphhTlRPE/00075
id02548/1CNhmMmirfA/00009 id02548/1CNhmMmirfA/00009
id03127/k8z6DxdyF9w/00291 id03127/k8z6DxdyF9w/00291
id01437/zLRJ_8_M5Wg/00263 id01437/zLRJ_8_M5Wg/00263
id02576/WnbNQuJzErQ/00086 id02576/WnbNQuJzErQ/00086
id01333/M0UD9g1x18c/00128 id01333/M0UD9g1x18c/00128
id04295/1fSjOItVYVg/00001 id04295/1fSjOItVYVg/00001
id08456/8tt1LbCoU0E/00054 id08456/8tt1LbCoU0E/00054
id07494/r-ToqH_EJNs/00318 id07494/r-ToqH_EJNs/00318
id06816/XBKj9XWlZCw/00123 id06816/XBKj9XWlZCw/00123
id03030/haoNit7a4W0/00201 id03030/haoNit7a4W0/00201
id03839/aeObhOJLQzQ/00293 id03839/aeObhOJLQzQ/00293
id07868/COb1gFHXsBQ/00059 id07868/COb1gFHXsBQ/00059
id01224/eYWcMCsgkLY/00255 id01224/eYWcMCsgkLY/00255
id04006/K5ueXBlS6rc/00049 id04006/K5ueXBlS6rc/00049
id07620/G5-1CUbaz0c/00107 id07620/G5-1CUbaz0c/00107
id06104/cj0TAnwndoc/00230 id06104/cj0TAnwndoc/00230
id00061/STX1ycPt8fU/00076 id00061/STX1ycPt8fU/00076
id04478/wMbobxEQ7j8/00336 id04478/wMbobxEQ7j8/00336
id01106/7X_xtnJhEc0/00031 id01106/7X_xtnJhEc0/00031
id08374/zaYzRbE_2C8/00494 id08374/zaYzRbE_2C8/00494
id04276/MgOqCfwKE70/00173 id04276/MgOqCfwKE70/00173
id03127/Lgd5qn2-kMo/00079 id03127/Lgd5qn2-kMo/00079
id00154/xH3Pp_5yxOk/00153 id00154/xH3Pp_5yxOk/00153
id04030/7mXUMuo5_NE/00001 id04030/7mXUMuo5_NE/00001
id02542/p7bvjcLbZm4/00097 id02542/p7bvjcLbZm4/00097
id04232/T7dROCqmwNQ/00235 id04232/T7dROCqmwNQ/00235
id02548/KrXU-_jrtxY/00147 id02548/KrXU-_jrtxY/00147
id01567/SZyTC5dxJOY/00219 id01567/SZyTC5dxJOY/00219
id03524/2DD4Np7SaWw/00007 id03524/2DD4Np7SaWw/00007
id04094/DRq5F2261Ko/00072 id04094/DRq5F2261Ko/00072
id07802/HrpJg06dowY/00152 id07802/HrpJg06dowY/00152
id06816/pBt-DxsTFc8/00231 id06816/pBt-DxsTFc8/00231
id00154/2pSNL5YdcoQ/00002 id00154/2pSNL5YdcoQ/00002
id01541/C29fUBtimOE/00038 id01541/C29fUBtimOE/00038
id06310/b6qPjJ0isPI/00155 id06310/b6qPjJ0isPI/00155
id05714/wFGNufaMbDY/00025 id05714/wFGNufaMbDY/00025
id03980/m-8Ffv2RqYs/00092 id03980/m-8Ffv2RqYs/00092
id01437/uXAe0vbNWeo/00238 id01437/uXAe0vbNWeo/00238
id04232/tPZ-zVT67gs/00479 id04232/tPZ-zVT67gs/00479
id06811/ImzUwwYU6SQ/00067 id06811/ImzUwwYU6SQ/00067
id05459/wq3Z0I944wU/00436 id05459/wq3Z0I944wU/00436
id03969/Evoldg-U2_c/00024 id03969/Evoldg-U2_c/00024
id08548/BSChFozahbU/00019 id08548/BSChFozahbU/00019
id04950/PQEAck-3wcA/00134 id04950/PQEAck-3wcA/00134
id04295/G4YnExZSzlM/00066 id04295/G4YnExZSzlM/00066
id05176/mc7rFp2B1j0/00092 id05176/mc7rFp2B1j0/00092
id00812/1Xfgvdu7oDo/00001 id00812/1Xfgvdu7oDo/00001
id05459/UPSPGawaVsg/00233 id05459/UPSPGawaVsg/00233
id04656/7nG3rOv0oBw/00050 id04656/7nG3rOv0oBw/00050
id02548/nvYBpt14BrQ/00309 id02548/nvYBpt14BrQ/00309
id02317/A3AvljK8Upk/00102 id02317/A3AvljK8Upk/00102
id04478/qLNvRwMkhik/00242 id04478/qLNvRwMkhik/00242
id01228/lCDMC8JvKyU/00295 id01228/lCDMC8JvKyU/00295
id03041/5CfnYwQCW48/00001 id03041/5CfnYwQCW48/00001
id04950/LnsriCjCIV4/00116 id04950/LnsriCjCIV4/00116
id04094/plxNYSFgDTM/00384 id04094/plxNYSFgDTM/00384
id01460/30_QmGw7lmE/00030 id01460/30_QmGw7lmE/00030
id04366/6rX7hCNSjaw/00056 id04366/6rX7hCNSjaw/00056
id01041/m-xolqIq8p4/00370 id01041/m-xolqIq8p4/00370
id04950/BG4CCg2RiuQ/00052 id04950/BG4CCg2RiuQ/00052
id01989/7g0A7pF94r0/00018 id01989/7g0A7pF94r0/00018
id03382/b_NJ2Xz3G4Y/00030 id03382/b_NJ2Xz3G4Y/00030
id00812/IteHRVKyzaE/00138 id00812/IteHRVKyzaE/00138
id00061/bdkqfVtDZVY/00121 id00061/bdkqfVtDZVY/00121
id03839/YkYIh4cYwwg/00275 id03839/YkYIh4cYwwg/00275
id07354/wyTuCRGjUIQ/00477 id07354/wyTuCRGjUIQ/00477
id02057/TddnW2TaXrc/00246 id02057/TddnW2TaXrc/00246
id01989/gHVHtKTQBsw/00128 id01989/gHVHtKTQBsw/00128
id08374/bXlUHb5hxxA/00266 id08374/bXlUHb5hxxA/00266
id03862/TE2zQc8_W-g/00252 id03862/TE2zQc8_W-g/00252
id08696/86-k8TuowAE/00033 id08696/86-k8TuowAE/00033
id05176/K8yZYHg_4ro/00050 id05176/K8yZYHg_4ro/00050
id04253/SKsPkHMGHYY/00240 id04253/SKsPkHMGHYY/00240
id07874/2KK4ozkjaEE/00002 id07874/2KK4ozkjaEE/00002
id08392/g-SJYYaaLgE/00352 id08392/g-SJYYaaLgE/00352
id02542/glhCf1hwJhE/00065 id02542/glhCf1hwJhE/00065
id00817/FsL-bTbDTyw/00112 id00817/FsL-bTbDTyw/00112
id04862/IuXPj9VhUVA/00100 id04862/IuXPj9VhUVA/00100
id06811/f9-8d3lNNcw/00237 id06811/f9-8d3lNNcw/00237
id04094/JUYMzfVp8zI/00113 id04094/JUYMzfVp8zI/00113
id03347/r-xJUB0A4ok/00346 id03347/r-xJUB0A4ok/00346
id07868/MNibTv_ODQ8/00148 id07868/MNibTv_ODQ8/00148
id08392/3e5zvNaT-eU/00020 id08392/3e5zvNaT-eU/00020
id04295/bKMKvAr440A/00141 id04295/bKMKvAr440A/00141
id04295/l62YPD0ZkZI/00185 id04295/l62YPD0ZkZI/00185
id07312/RO9DsspwXiE/00047 id07312/RO9DsspwXiE/00047
id03030/rmFsUV5ICKk/00267 id03030/rmFsUV5ICKk/00267
id03677/nVWTTopGQdU/00181 id03677/nVWTTopGQdU/00181
id00866/xQ1Yy0kjvjA/00256 id00866/xQ1Yy0kjvjA/00256
id01333/fRnqtJR0rws/00371 id01333/fRnqtJR0rws/00371
id05055/AZoIKG33E8s/00115 id05055/AZoIKG33E8s/00115
id01822/_CkfCmQXII8/00098 id01822/_CkfCmQXII8/00098
id01593/_gyaAyVi6SA/00344 id01593/_gyaAyVi6SA/00344
id04295/DS3RDwf2xI8/00049 id04295/DS3RDwf2xI8/00049
id00812/EjO-VORTv_o/00098 id00812/EjO-VORTv_o/00098
id04657/WdJ_DuU0ack/00236 id04657/WdJ_DuU0ack/00236
id04232/AB9fk1MH2rA/00035 id04232/AB9fk1MH2rA/00035
id00419/chfgCUm9-Mg/00364 id00419/chfgCUm9-Mg/00364
id02577/Az0BGrX_TwI/00021 id02577/Az0BGrX_TwI/00021
id01437/hyj4OYm0cvA/00195 id01437/hyj4OYm0cvA/00195
id01593/tLFWX-IdAwI/00431 id01593/tLFWX-IdAwI/00431
id04536/MNDmkEXRS7s/00312 id04536/MNDmkEXRS7s/00312
id03789/7qhkM8qY3Fw/00077 id03789/7qhkM8qY3Fw/00077
id01593/neAk6K8BvTA/00397 id01593/neAk6K8BvTA/00397
id06484/jTHSVo6NvS4/00151 id06484/jTHSVo6NvS4/00151
id07414/cAudd_5Yv2I/00256 id07414/cAudd_5Yv2I/00256
id00866/ADzqaRZtJNA/00087 id00866/ADzqaRZtJNA/00087
id06484/ZySpn0Aj09k/00108 id06484/ZySpn0Aj09k/00108
id07312/ZHBjHQENqW8/00053 id07312/ZHBjHQENqW8/00053
id04656/LDuq2UPHKoA/00157 id04656/LDuq2UPHKoA/00157
id01509/UZL8Obdt--8/00181 id01509/UZL8Obdt--8/00181
id05816/7jt8zGB27QQ/00017 id05816/7jt8zGB27QQ/00017
id08456/7PKsuBS5LQI/00050 id08456/7PKsuBS5LQI/00050
id06913/Tx0vAZhSPuE/00077 id06913/Tx0vAZhSPuE/00077
id02465/UEmI4r5G-5Y/00117 id02465/UEmI4r5G-5Y/00117
id01460/9sefvU9y4Kw/00046 id01460/9sefvU9y4Kw/00046
id01567/uYDx0vIVy_A/00429 id01567/uYDx0vIVy_A/00429
id07961/qott7SmhA-A/00351 id07961/qott7SmhA-A/00351
id00866/Awi1Q0yib1s/00092 id00866/Awi1Q0yib1s/00092
id02086/CqJKcn8m_Xo/00152 id02086/CqJKcn8m_Xo/00152
id05015/Obbv73CqtmQ/00137 id05015/Obbv73CqtmQ/00137
id01041/1UYZqPpavtk/00001 id01041/1UYZqPpavtk/00001
id01593/GiLxqKSI68o/00188 id01593/GiLxqKSI68o/00188
id02317/IR0psXbOjdc/00176 id02317/IR0psXbOjdc/00176
id01066/X33aJxc3Kt0/00112 id01066/X33aJxc3Kt0/00112
id08456/VU3fkD-QqPw/00206 id08456/VU3fkD-QqPw/00206
id04536/wat5sbCSs0k/00470 id04536/wat5sbCSs0k/00470
id01066/4KOSmyAMipc/00020 id01066/4KOSmyAMipc/00020
id02445/f5u3ktNPHAk/00074 id02445/f5u3ktNPHAk/00074
id03041/NJUcU7j30JI/00011 id03041/NJUcU7j30JI/00011
id00817/vUezvJDh_tA/00394 id00817/vUezvJDh_tA/00394
id04478/sw50KQMY8vw/00298 id04478/sw50KQMY8vw/00298
id04657/hMrgeYf5ToQ/00267 id04657/hMrgeYf5ToQ/00267
id02548/VdjlKRtLD_w/00206 id02548/VdjlKRtLD_w/00206
id06310/4oJF1NW2bIg/00006 id06310/4oJF1NW2bIg/00006
id01509/jqbtAt91alI/00329 id01509/jqbtAt91alI/00329
id07414/oXx9CvIeFFY/00407 id07414/oXx9CvIeFFY/00407
id04570/mwhiZtTZYX0/00271 id04570/mwhiZtTZYX0/00271
id00812/AzDjo0Uyk4Y/00061 id00812/AzDjo0Uyk4Y/00061
id05999/MJwLq17VoMA/00146 id05999/MJwLq17VoMA/00146
id07414/dsqrI97WQHE/00319 id07414/dsqrI97WQHE/00319
id05015/C3KsCD-pUgs/00046 id05015/C3KsCD-pUgs/00046
id06484/Gh6H7Md_L2k/00053 id06484/Gh6H7Md_L2k/00053
id00081/xlwJqdrzeMA/00291 id00081/xlwJqdrzeMA/00291
id05055/RLN5nKfza4A/00219 id05055/RLN5nKfza4A/00219
id05055/OKw_hph-hK8/00197 id05055/OKw_hph-hK8/00197
id03839/xtBkY9xYpjA/00464 id03839/xtBkY9xYpjA/00464
id07620/HEX00yF8LTs/00117 id07620/HEX00yF8LTs/00117
id05816/hjrZgsKuvpw/00349 id05816/hjrZgsKuvpw/00349
id02548/6LPbT49zy38/00050 id02548/6LPbT49zy38/00050
id01000/7eYakM6qrTs/00006 id01000/7eYakM6qrTs/00006
id02181/cNCj0pLxR24/00084 id02181/cNCj0pLxR24/00084
id02086/sSliWvu6Ufs/00453 id02086/sSliWvu6Ufs/00453
id03178/KHelFt1Jyyg/00057 id03178/KHelFt1Jyyg/00057
id05594/8dYcSoUAQO8/00014 id05594/8dYcSoUAQO8/00014
id05015/JmvJemqIeS0/00102 id05015/JmvJemqIeS0/00102
id00081/EvCyt2keqW4/00065 id00081/EvCyt2keqW4/00065
id07663/QWe7IIGrv5s/00146 id07663/QWe7IIGrv5s/00146
id01618/kzxW2WAFWLI/00126 id01618/kzxW2WAFWLI/00126
id00562/X7FJ3M3bz3c/00124 id00562/X7FJ3M3bz3c/00124
id07961/bvPOvzukTE4/00224 id07961/bvPOvzukTE4/00224
id03789/nv8sQplhvX0/00357 id03789/nv8sQplhvX0/00357
id04295/VUHarbuO_eE/00125 id04295/VUHarbuO_eE/00125
id01822/IaBziWYcwK4/00037 id01822/IaBziWYcwK4/00037
id05015/X1opVctkTE8/00170 id05015/X1opVctkTE8/00170
id01041/MMXznNig_iU/00248 id01041/MMXznNig_iU/00248
id02465/EZ_F0hUZdS4/00054 id02465/EZ_F0hUZdS4/00054
id04656/Bi7kCsbg5L0/00061 id04656/Bi7kCsbg5L0/00061
id07494/K4ndWNAHgdU/00093 id07494/K4ndWNAHgdU/00093
id07354/TKTT7fArInQ/00218 id07354/TKTT7fArInQ/00218
id05714/Lu4PPvWXGn8/00014 id05714/Lu4PPvWXGn8/00014
id05654/07pANazoyJg/00001 id05654/07pANazoyJg/00001
id01066/FDp-ZLCWrIc/00054 id01066/FDp-ZLCWrIc/00054
id05999/ZQJVmCJFjNs/00182 id05999/ZQJVmCJFjNs/00182
id04570/5Fg6CLuRntk/00041 id04570/5Fg6CLuRntk/00041
id08696/vqLNqYW4TQA/00476 id08696/vqLNqYW4TQA/00476
id04862/2uYHadPvHRU/00016 id04862/2uYHadPvHRU/00016
id03980/7MRUusImkno/00001 id03980/7MRUusImkno/00001
id02542/QJKFnt1lHeE/00035 id02542/QJKFnt1lHeE/00035
id04536/OYH-6uGB6jI/00322 id04536/OYH-6uGB6jI/00322
id06484/dOTMnYZcY9Q/00126 id06484/dOTMnYZcY9Q/00126
id04478/GZQGZOmFU5U/00063 id04478/GZQGZOmFU5U/00063
id01224/tELp6C7FELU/00421 id01224/tELp6C7FELU/00421
id03862/5m5iPZNJS6c/00022 id03862/5m5iPZNJS6c/00022
id05124/lcDhSnyeN5E/00381 id05124/lcDhSnyeN5E/00381
id08149/3V9V5sDAWTc/00001 id08149/3V9V5sDAWTc/00001
id02181/iEF0MWApQms/00108 id02181/iEF0MWApQms/00108
id04536/xrsxSF2qey8/00471 id04536/xrsxSF2qey8/00471
id03178/9AJzTUwGbRk/00005 id03178/9AJzTUwGbRk/00005
id01041/Izmh75CZNW0/00207 id01041/Izmh75CZNW0/00207
id03041/g5YLpUZBNKc/00018 id03041/g5YLpUZBNKc/00018
id03347/nSAKXYdEOOM/00297 id03347/nSAKXYdEOOM/00297
id03347/pPWGEPixOoM/00337 id03347/pPWGEPixOoM/00337
id07312/XBBpLMEjfUo/00048 id07312/XBBpLMEjfUo/00048
id08456/6QFe7cYnZk4/00023 id08456/6QFe7cYnZk4/00023
id05176/5Hk_hj0oXN8/00004 id05176/5Hk_hj0oXN8/00004
id07426/DBBfi7aKLx4/00038 id07426/DBBfi7aKLx4/00038
id07494/uhPKcTLLwcM/00347 id07494/uhPKcTLLwcM/00347
id02576/agxjz_O2Wfs/00088 id02576/agxjz_O2Wfs/00088
id01541/SvTz_Pn15Vk/00119 id01541/SvTz_Pn15Vk/00119
id07414/Uxggn91FBog/00214 id07414/Uxggn91FBog/00214
id04253/1HOlzefgLu8/00001 id04253/1HOlzefgLu8/00001
id01567/RPUd0ua7RR0/00216 id01567/RPUd0ua7RR0/00216
id04657/5DzZTPLgwTM/00044 id04657/5DzZTPLgwTM/00044
id04006/zSMWS35kYdQ/00253 id04006/zSMWS35kYdQ/00253
id03347/KT7B07WFWyM/00104 id03347/KT7B07WFWyM/00104
id02445/z5u4yO1EsZo/00109 id02445/z5u4yO1EsZo/00109
id00154/z1dLArSg5PQ/00190 id00154/z1dLArSg5PQ/00190
id07414/Cn6Ws4oK1jg/00095 id07414/Cn6Ws4oK1jg/00095
id02286/WHS1n7XUt_8/00103 id02286/WHS1n7XUt_8/00103
id01509/Zmmnr4iRsCM/00230 id01509/Zmmnr4iRsCM/00230
id04276/tGOA4fVnSgw/00448 id04276/tGOA4fVnSgw/00448
id00419/nu9cRW2J4Dk/00420 id00419/nu9cRW2J4Dk/00420
id07868/6RQX9l98N-g/00002 id07868/6RQX9l98N-g/00002
id03839/1lh57VnuaKE/00004 id03839/1lh57VnuaKE/00004
id03178/LT-BNQKA9NU/00075 id03178/LT-BNQKA9NU/00075
id01460/Es6CkRmkIBY/00080 id01460/Es6CkRmkIBY/00080
id06692/T2Xk7MO6m2g/00297 id06692/T2Xk7MO6m2g/00297
id01892/d8b9y_CRE3M/00102 id01892/d8b9y_CRE3M/00102
id07426/K_25cVSB-JU/00063 id07426/K_25cVSB-JU/00063
id01333/LI6eLfuTn6I/00127 id01333/LI6eLfuTn6I/00127
id00081/hIBFutPzn8s/00158 id00081/hIBFutPzn8s/00158
id04536/2j8I_WX5mhY/00009 id04536/2j8I_WX5mhY/00009
id04232/UElg0R7fmlk/00253 id04232/UElg0R7fmlk/00253
id01460/eZR__GGkVw4/00221 id01460/eZR__GGkVw4/00221
id01041/GymfYtTsKEU/00119 id01041/GymfYtTsKEU/00119
id07396/xK1gClL60tY/00191 id07396/xK1gClL60tY/00191
id05459/81o3ictaOnU/00075 id05459/81o3ictaOnU/00075
id02685/yN8ilDTW-o4/00114 id02685/yN8ilDTW-o4/00114
id02286/c8LjgwDQAkw/00137 id02286/c8LjgwDQAkw/00137
id01541/SWcGs-DbV9Q/00100 id01541/SWcGs-DbV9Q/00100
id01822/x4Fr2ceg_f8/00231 id01822/x4Fr2ceg_f8/00231
id03347/FKY5V8wmX5k/00043 id03347/FKY5V8wmX5k/00043
id00817/0GmSijZelGY/00001 id00817/0GmSijZelGY/00001
id06209/ahL3F1x5sE4/00091 id06209/ahL3F1x5sE4/00091
id06692/4k3Eo5s1Rwo/00057 id06692/4k3Eo5s1Rwo/00057
id09017/sduESYpj2-I/00297 id09017/sduESYpj2-I/00297
id07354/grg37qaxKjI/00329 id07354/grg37qaxKjI/00329
id07802/X8I5FN64_Oc/00199 id07802/X8I5FN64_Oc/00199
id07494/JV5S_SUcHmI/00088 id07494/JV5S_SUcHmI/00088
id03524/eHrI5bD8hSs/00282 id03524/eHrI5bD8hSs/00282
id01460/HNjuGz9ayBk/00109 id01460/HNjuGz9ayBk/00109
id04570/961AefP1-is/00056 id04570/961AefP1-is/00056
id00419/749eTxP4Us8/00061 id00419/749eTxP4Us8/00061
id00017/OLguY5ofUrY/00039 id00017/OLguY5ofUrY/00039
id08392/RogKVSjaAH0/00293 id08392/RogKVSjaAH0/00293
id01066/lI1wGa1UhEM/00205 id01066/lI1wGa1UhEM/00205
id07621/zSdriAuJUKo/00485 id07621/zSdriAuJUKo/00485
id03862/JBkaiUNeMmk/00166 id03862/JBkaiUNeMmk/00166
id00017/E6aqL_Nc410/00027 id00017/E6aqL_Nc410/00027
id03839/fi-g--cBwnU/00348 id03839/fi-g--cBwnU/00348
id05654/eLztZmvnk-k/00095 id05654/eLztZmvnk-k/00095
id02548/wF5HfFXZCBI/00349 id02548/wF5HfFXZCBI/00349
id02576/LAipS5WJ29s/00075 id02576/LAipS5WJ29s/00075
id06692/SEPs17_AkTI/00295 id06692/SEPs17_AkTI/00295
id05459/kkaYxtBZnNo/00348 id05459/kkaYxtBZnNo/00348
id04232/MEGVEqgGCME/00167 id04232/MEGVEqgGCME/00167
id01989/8CUktsB_2bA/00031 id01989/8CUktsB_2bA/00031
id01066/kqP_NZ1FRlM/00176 id01066/kqP_NZ1FRlM/00176
id03382/ockh8KdXJP8/00059 id03382/ockh8KdXJP8/00059
id01593/pO180haP_vo/00410 id01593/pO180haP_vo/00410
id07396/nTQDZrnGXXY/00179 id07396/nTQDZrnGXXY/00179
id03030/rg-VUeksKaU/00257 id03030/rg-VUeksKaU/00257
id08911/IddDkZwRflE/00053 id08911/IddDkZwRflE/00053
id02317/K2GT02zavxo/00193 id02317/K2GT02zavxo/00193
id01298/5P4ldDRuo5c/00065 id01298/5P4ldDRuo5c/00065
id01989/Evbf6fMJNmk/00060 id01989/Evbf6fMJNmk/00060
id05124/fNJI2A0v8yI/00357 id05124/fNJI2A0v8yI/00357
id02465/RLi2ItGherA/00098 id02465/RLi2ItGherA/00098
id07868/qMNfMcG6sh0/00346 id07868/qMNfMcG6sh0/00346
id04366/tmoYV4kPOGU/00246 id04366/tmoYV4kPOGU/00246
id06484/_ZkoebnFkVA/00110 id06484/_ZkoebnFkVA/00110
id04276/I9gCyrZWFn0/00097 id04276/I9gCyrZWFn0/00097
id03978/IMn6f0iDOtE/00032 id03978/IMn6f0iDOtE/00032
id00419/w_0sK8WuSsg/00472 id00419/w_0sK8WuSsg/00472
id04478/RwcHXQ3MvsQ/00109 id04478/RwcHXQ3MvsQ/00109
id08696/cUmyIjpOYlY/00360 id08696/cUmyIjpOYlY/00360
id04366/DqBQx6AZ1Nk/00083 id04366/DqBQx6AZ1Nk/00083
id05459/RhOon49C3g8/00201 id05459/RhOon49C3g8/00201
id04656/OzgjshkHUiA/00166 id04656/OzgjshkHUiA/00166
id03969/x38Sqv819yE/00110 id03969/x38Sqv819yE/00110
id00061/0G9G9oyFHI8/00001 id00061/0G9G9oyFHI8/00001
id06913/IreNhnVfTkQ/00043 id06913/IreNhnVfTkQ/00043
id01618/NqYUgbuImpk/00096 id01618/NqYUgbuImpk/00096
id08552/y05_B9NXizo/00237 id08552/y05_B9NXizo/00237
id01460/zcTt06bjKuA/00365 id01460/zcTt06bjKuA/00365
id00866/nI-zVYcQX40/00220 id00866/nI-zVYcQX40/00220
id08374/9eMfNJiKBPQ/00056 id08374/9eMfNJiKBPQ/00056
id03524/nKxz0LxKZ58/00344 id03524/nKxz0LxKZ58/00344
id09017/A3CAugN2cjk/00021 id09017/A3CAugN2cjk/00021
id02685/NtHmnSLaGCA/00036 id02685/NtHmnSLaGCA/00036
id01224/atjwjz0vAk8/00213 id01224/atjwjz0vAk8/00213
id07961/gvLf2DggTu0/00271 id07961/gvLf2DggTu0/00271
id01567/CCs8rZLCdVw/00043 id01567/CCs8rZLCdVw/00043
id03347/nbmPriSE9NY/00316 id03347/nbmPriSE9NY/00316
id06104/snzG1OymFgs/00273 id06104/snzG1OymFgs/00273
id02019/xsXm-MSuD-E/00290 id02019/xsXm-MSuD-E/00290
id00061/VugwXDj1ka4/00088 id00061/VugwXDj1ka4/00088
id01224/4z68GFZuYKU/00028 id01224/4z68GFZuYKU/00028
id03839/ajkGXKUvTWY/00296 id03839/ajkGXKUvTWY/00296
id07874/N7fMpS_yaF4/00047 id07874/N7fMpS_yaF4/00047
id05124/fRhAX7v_R6A/00365 id05124/fRhAX7v_R6A/00365
id02181/ci_22Oqhwtc/00088 id02181/ci_22Oqhwtc/00088
id07414/njxmqS9ncTA/00399 id07414/njxmqS9ncTA/00399
id05176/yEMRxKA0vSw/00101 id05176/yEMRxKA0vSw/00101
id03862/VVaxYHNmtA8/00269 id03862/VVaxYHNmtA8/00269
id07396/X6KkvYh6rPA/00148 id07396/X6KkvYh6rPA/00148
id06310/TkxTnoic67U/00130 id06310/TkxTnoic67U/00130
id08374/Yh9O9ETuF_0/00250 id08374/Yh9O9ETuF_0/00250
id02317/5moKZXlJTEs/00058 id02317/5moKZXlJTEs/00058
id04536/EDCwhtRFARA/00172 id04536/EDCwhtRFARA/00172
id03789/pz1jGMsPY9M/00381 id03789/pz1jGMsPY9M/00381
id03127/wzS06bKAZ48/00354 id03127/wzS06bKAZ48/00354
id08911/wedpC4fN4YY/00096 id08911/wedpC4fN4YY/00096
id01106/6SFpvp42pMA/00014 id01106/6SFpvp42pMA/00014
id02465/6jp5YsZYtHI/00021 id02465/6jp5YsZYtHI/00021
id01618/Ay_BKx5-JOc/00046 id01618/Ay_BKx5-JOc/00046
id04478/x07vvSVm2Yo/00363 id04478/x07vvSVm2Yo/00363
id01593/u5AgUWl3fFU/00437 id01593/u5AgUWl3fFU/00437
id03030/IpwcoJajjJI/00124 id03030/IpwcoJajjJI/00124
id01593/t9TUbyp3xfs/00423 id01593/t9TUbyp3xfs/00423
id07414/hUxcsEMKssA/00320 id07414/hUxcsEMKssA/00320
id04366/L-56A5RNeWg/00124 id04366/L-56A5RNeWg/00124
id07961/3EPjXGhfst4/00001 id07961/3EPjXGhfst4/00001
id00061/mMOd25Ag7XY/00239 id00061/mMOd25Ag7XY/00239
id01567/RQMG0K5AchU/00218 id01567/RQMG0K5AchU/00218
id08552/PL5vk3XeKRM/00114 id08552/PL5vk3XeKRM/00114
id04862/eX3wAZ0yr7w/00260 id04862/eX3wAZ0yr7w/00260
id02086/CBNOvx4Phxw/00146 id02086/CBNOvx4Phxw/00146
id01228/3wAkCYQR3fQ/00011 id01228/3wAkCYQR3fQ/00011
id06484/MXwPpo1Dg7U/00073 id06484/MXwPpo1Dg7U/00073
id01460/9fJy9zGdESI/00045 id01460/9fJy9zGdESI/00045
dataset/filelists/voxceleb2_test_n_500_seed_797_cross.txt
ADDED
|
@@ -0,0 +1,500 @@
id05459/18XmQEiGLnQ/00001 id07961/3EPjXGhfst4/00001
id03980/7MRUusImkno/00001 id08696/0H1PxInJCK0/00001
id05654/07pANazoyJg/00001 id04570/0YMGn6BI9rg/00001
id00817/0GmSijZelGY/00001 id07354/0NjekFZqaY0/00001
id05202/2gnLcAbAoSc/00001 id00817/0GmSijZelGY/00001
id03041/5CfnYwQCW48/00001 id07354/0NjekFZqaY0/00001
id03980/7MRUusImkno/00001 id07621/0CiFdFegqZM/00001
id05850/B8kp8ed48JE/00001 id04253/1HOlzefgLu8/00001
id01298/2K5F6xG-Rbs/00001 id05816/1dyCBbJ94iw/00001
id07494/0P1wPmgz0Bk/00001 id07621/0CiFdFegqZM/00001
id06913/4Ug7aJemzpg/00001 id04030/7mXUMuo5_NE/00001
id02286/4LAIxvdvguc/00001 id05850/B8kp8ed48JE/00001
id02548/0pAkJZmlFqc/00001 id05459/18XmQEiGLnQ/00001
id08456/29EhSZDqzas/00001 id04295/1fSjOItVYVg/00001
id04295/1fSjOItVYVg/00001 id02685/4JDRxqYC0a4/00001
id04276/5M8NmCwTHZ0/00001 id05654/07pANazoyJg/00001
id03030/5wOxV1wAgqA/00001 id03041/5CfnYwQCW48/00001
id04656/1tZYt8jey54/00001 id07961/3EPjXGhfst4/00001
id03980/7MRUusImkno/00001 id04536/0f_Yi_1CoeM/00001
id03980/7MRUusImkno/00001 id05202/2gnLcAbAoSc/00001
id01298/2K5F6xG-Rbs/00001 id04862/0zJh2FMTaDE/00001
id02548/0pAkJZmlFqc/00001 id04478/2grMtwdG93I/00001
id02685/4JDRxqYC0a4/00001 id01892/3vKPgjwFjbo/00001
id07494/0P1wPmgz0Bk/00001 id04656/1tZYt8jey54/00001
id00812/1Xfgvdu7oDo/00001 id00926/2Nd7f1yNQzE/00001
id07426/1KNFfOFEhyI/00001 id03030/5wOxV1wAgqA/00001
id00866/03SSllwNkGk/00001 id00812/1Xfgvdu7oDo/00001
id04570/0YMGn6BI9rg/00001 id01892/3vKPgjwFjbo/00001
id03041/5CfnYwQCW48/00001 id04030/7mXUMuo5_NE/00001
id07494/0P1wPmgz0Bk/00001 id00081/2xYrsnvtUWc/00001
id08392/0fwuibKviJU/00001 id05015/0Cu3AvWWOFI/00001
id06692/2ptBBNIZXtI/00001 id04536/0f_Yi_1CoeM/00001
id04253/1HOlzefgLu8/00001 id06104/02L1L9RFAgI/00001
id02725/37kUrf6RJdw/00001 id02685/4JDRxqYC0a4/00001
id04006/113VkmVVz1Q/00001 id04119/1uH67UruKlE/00001
id01567/1Lx_ZqrK1bM/00001 id04030/7mXUMuo5_NE/00001
id02445/3Rnk8eja3TU/00001 id05816/1dyCBbJ94iw/00001
id03041/5CfnYwQCW48/00001 id03347/4xXZ75_TeSM/00001
id04570/0YMGn6BI9rg/00001 id02317/0q4X8kPTlEY/00001
id07426/1KNFfOFEhyI/00001 id01822/0QcHowaLAF0/00001
id02577/0euHS_r5JH4/00001 id02725/37kUrf6RJdw/00001
id07354/0NjekFZqaY0/00001 id05459/18XmQEiGLnQ/00001
id06692/2ptBBNIZXtI/00001 id05850/B8kp8ed48JE/00001
id01822/0QcHowaLAF0/00001 id07961/3EPjXGhfst4/00001
id04366/0iG2Ub9zETM/00001 id03347/4xXZ75_TeSM/00001
id03030/5wOxV1wAgqA/00001 id03789/0kdVSujPa9g/00001
id04366/0iG2Ub9zETM/00001 id02286/4LAIxvdvguc/00001
id00926/2Nd7f1yNQzE/00001 id02548/0pAkJZmlFqc/00001
id03030/5wOxV1wAgqA/00001 id01298/2K5F6xG-Rbs/00001
id01892/3vKPgjwFjbo/00001 id02317/0q4X8kPTlEY/00001
id05202/2gnLcAbAoSc/00001 id04253/1HOlzefgLu8/00001
id05714/2gvpaZcvAY4/00001 id06692/2ptBBNIZXtI/00001
id07621/0CiFdFegqZM/00001 id07802/0RUpqvi3sPU/00001
id03030/5wOxV1wAgqA/00001 id02317/0q4X8kPTlEY/00001
id01822/0QcHowaLAF0/00001 id02445/3Rnk8eja3TU/00001
id07961/3EPjXGhfst4/00001 id05459/18XmQEiGLnQ/00001
id04950/2n4sGPqU9M8/00001 id07426/1KNFfOFEhyI/00001
id04862/0zJh2FMTaDE/00001 id02465/0Ocu8l1eAng/00001
id06104/02L1L9RFAgI/00001 id07312/0LWllHGohPY/00001
id07414/110UMQovTR0/00001 id06692/2ptBBNIZXtI/00001
id05015/0Cu3AvWWOFI/00001 id08696/0H1PxInJCK0/00001
id02181/02gIO4WrZLY/00001 id00812/1Xfgvdu7oDo/00001
id08392/0fwuibKviJU/00001 id01041/1UYZqPpavtk/00001
id03347/4xXZ75_TeSM/00001 id04950/2n4sGPqU9M8/00001
id07312/0LWllHGohPY/00001 id04950/2n4sGPqU9M8/00001
id05202/2gnLcAbAoSc/00001 id05654/07pANazoyJg/00001
id01041/1UYZqPpavtk/00001 id02317/0q4X8kPTlEY/00001
id02057/0xZU7Oi9nvM/00001 id03178/2CT-6fnBC_o/00001
id04006/113VkmVVz1Q/00001 id00817/0GmSijZelGY/00001
id05850/B8kp8ed48JE/00001 id01892/3vKPgjwFjbo/00001
id08696/0H1PxInJCK0/00001 id06692/2ptBBNIZXtI/00001
id02057/0xZU7Oi9nvM/00001 id01541/2P7hzPq5iDw/00001
id04006/113VkmVVz1Q/00001 id02057/0xZU7Oi9nvM/00001
id04276/5M8NmCwTHZ0/00001 id04570/0YMGn6BI9rg/00001
id07868/5YYJq3fSbH8/00001 id03030/5wOxV1wAgqA/00001
id00812/1Xfgvdu7oDo/00001 id00154/0hjW3eTGAy8/00001
id06692/2ptBBNIZXtI/00001 id05594/0ohBiepcHWI/00001
id04536/0f_Yi_1CoeM/00001 id05202/2gnLcAbAoSc/00001
id06310/1IAgr_CRnuE/00001 id05816/1dyCBbJ94iw/00001
id01541/2P7hzPq5iDw/00001 id00419/1zffAxBod_c/00001
id07354/0NjekFZqaY0/00001 id00866/03SSllwNkGk/00001
id03347/4xXZ75_TeSM/00001 id02577/0euHS_r5JH4/00001
id04119/1uH67UruKlE/00001 id04006/113VkmVVz1Q/00001
id05714/2gvpaZcvAY4/00001 id07961/3EPjXGhfst4/00001
id06104/02L1L9RFAgI/00001 id03178/2CT-6fnBC_o/00001
id07354/0NjekFZqaY0/00001 id02445/3Rnk8eja3TU/00001
id04030/7mXUMuo5_NE/00001 id03030/5wOxV1wAgqA/00001
id07312/0LWllHGohPY/00001 id04536/0f_Yi_1CoeM/00001
id03839/1jWHvl2qCq0/00001 id07802/0RUpqvi3sPU/00001
id07621/0CiFdFegqZM/00001 id05816/1dyCBbJ94iw/00001
id03839/1jWHvl2qCq0/00001 id03980/7MRUusImkno/00001
id03030/5wOxV1wAgqA/00001 id02445/3Rnk8eja3TU/00001
id03862/0w8W8jp7MJk/00001 id04253/1HOlzefgLu8/00001
id05714/2gvpaZcvAY4/00001 id04119/1uH67UruKlE/00001
id08392/0fwuibKviJU/00001 id07868/5YYJq3fSbH8/00001
id01298/2K5F6xG-Rbs/00001 id03030/5wOxV1wAgqA/00001
id05459/18XmQEiGLnQ/00001 id00817/0GmSijZelGY/00001
id05850/B8kp8ed48JE/00001 id06692/2ptBBNIZXtI/00001
id04295/1fSjOItVYVg/00001 id08456/29EhSZDqzas/00001
id04570/0YMGn6BI9rg/00001 id02057/0xZU7Oi9nvM/00001
id01541/2P7hzPq5iDw/00001 id00817/0GmSijZelGY/00001
id07426/1KNFfOFEhyI/00001 id07354/0NjekFZqaY0/00001
id04253/1HOlzefgLu8/00001 id06209/2zM9EAPsZZQ/00001
id05850/B8kp8ed48JE/00001 id08392/0fwuibKviJU/00001
id07802/0RUpqvi3sPU/00001 id02465/0Ocu8l1eAng/00001
id04119/1uH67UruKlE/00001 id04862/0zJh2FMTaDE/00001
id01541/2P7hzPq5iDw/00001 id08696/0H1PxInJCK0/00001
id08696/0H1PxInJCK0/00001 id07802/0RUpqvi3sPU/00001
id01228/2TIFacjgehY/00001 id07621/0CiFdFegqZM/00001
id03178/2CT-6fnBC_o/00001 id07868/5YYJq3fSbH8/00001
id05654/07pANazoyJg/00001 id01298/2K5F6xG-Rbs/00001
id01822/0QcHowaLAF0/00001 id02548/0pAkJZmlFqc/00001
id01618/0iFlmfmWVlY/00001 id08696/0H1PxInJCK0/00001
id00812/1Xfgvdu7oDo/00001 id08456/29EhSZDqzas/00001
id05594/0ohBiepcHWI/00001 id07312/0LWllHGohPY/00001
id05714/2gvpaZcvAY4/00001 id06104/02L1L9RFAgI/00001
id02445/3Rnk8eja3TU/00001 id07426/1KNFfOFEhyI/00001
id05714/2gvpaZcvAY4/00001 id00817/0GmSijZelGY/00001
id08696/0H1PxInJCK0/00001 id02317/0q4X8kPTlEY/00001
id04950/2n4sGPqU9M8/00001 id04478/2grMtwdG93I/00001
id01228/2TIFacjgehY/00001 id07414/110UMQovTR0/00001
id00926/2Nd7f1yNQzE/00001 id01541/2P7hzPq5iDw/00001
id05714/2gvpaZcvAY4/00001 id06913/4Ug7aJemzpg/00001
id01228/2TIFacjgehY/00001 id03862/0w8W8jp7MJk/00001
id03030/5wOxV1wAgqA/00001 id05015/0Cu3AvWWOFI/00001
id02548/0pAkJZmlFqc/00001 id06692/2ptBBNIZXtI/00001
id05202/2gnLcAbAoSc/00001 id04119/1uH67UruKlE/00001
id04656/1tZYt8jey54/00001 id07426/1KNFfOFEhyI/00001
id07312/0LWllHGohPY/00001 id03980/7MRUusImkno/00001
id04366/0iG2Ub9zETM/00001 id00817/0GmSijZelGY/00001
id07961/3EPjXGhfst4/00001 id01228/2TIFacjgehY/00001
id00154/0hjW3eTGAy8/00001 id04295/1fSjOItVYVg/00001
id04478/2grMtwdG93I/00001 id00154/0hjW3eTGAy8/00001
id04570/0YMGn6BI9rg/00001 id05202/2gnLcAbAoSc/00001
id04478/2grMtwdG93I/00001 id06913/4Ug7aJemzpg/00001
id06104/02L1L9RFAgI/00001 id04295/1fSjOItVYVg/00001
id05816/1dyCBbJ94iw/00001 id08392/0fwuibKviJU/00001
id00926/2Nd7f1yNQzE/00001 id04536/0f_Yi_1CoeM/00001
id00926/2Nd7f1yNQzE/00001 id02181/02gIO4WrZLY/00001
id05459/18XmQEiGLnQ/00001 id02317/0q4X8kPTlEY/00001
id05594/0ohBiepcHWI/00001 id01228/2TIFacjgehY/00001
id02181/02gIO4WrZLY/00001 id07312/0LWllHGohPY/00001
id00154/0hjW3eTGAy8/00001 id03839/1jWHvl2qCq0/00001
id04030/7mXUMuo5_NE/00001 id02725/37kUrf6RJdw/00001
id04295/1fSjOItVYVg/00001 id05714/2gvpaZcvAY4/00001
id02548/0pAkJZmlFqc/00001 id04570/0YMGn6BI9rg/00001
id04478/2grMtwdG93I/00001 id00866/03SSllwNkGk/00001
id03030/5wOxV1wAgqA/00001 id04366/0iG2Ub9zETM/00001
id02685/4JDRxqYC0a4/00001 id07426/1KNFfOFEhyI/00001
id07802/0RUpqvi3sPU/00001 id07312/0LWllHGohPY/00001
id02317/0q4X8kPTlEY/00001 id01892/3vKPgjwFjbo/00001
id00154/0hjW3eTGAy8/00001 id00866/03SSllwNkGk/00001
id02181/02gIO4WrZLY/00001 id02685/4JDRxqYC0a4/00001
id03178/2CT-6fnBC_o/00001 id05459/18XmQEiGLnQ/00001
id00926/2Nd7f1yNQzE/00001 id05202/2gnLcAbAoSc/00001
id03041/5CfnYwQCW48/00001 id03178/2CT-6fnBC_o/00001
id05850/B8kp8ed48JE/00001 id04006/113VkmVVz1Q/00001
id01822/0QcHowaLAF0/00001 id04570/0YMGn6BI9rg/00001
id04478/2grMtwdG93I/00001 id03839/1jWHvl2qCq0/00001
id01298/2K5F6xG-Rbs/00001 id01228/2TIFacjgehY/00001
id06310/1IAgr_CRnuE/00001 id04006/113VkmVVz1Q/00001
id00154/0hjW3eTGAy8/00001 id04006/113VkmVVz1Q/00001
id05816/1dyCBbJ94iw/00001 id01041/1UYZqPpavtk/00001
id04570/0YMGn6BI9rg/00001 id04862/0zJh2FMTaDE/00001
id06913/4Ug7aJemzpg/00001 id04862/0zJh2FMTaDE/00001
id03862/0w8W8jp7MJk/00001 id02465/0Ocu8l1eAng/00001
id04253/1HOlzefgLu8/00001 id01567/1Lx_ZqrK1bM/00001
id06209/2zM9EAPsZZQ/00001 id01298/2K5F6xG-Rbs/00001
id01822/0QcHowaLAF0/00001 id01541/2P7hzPq5iDw/00001
id07312/0LWllHGohPY/00001 id02317/0q4X8kPTlEY/00001
id06692/2ptBBNIZXtI/00001 id02445/3Rnk8eja3TU/00001
id07414/110UMQovTR0/00001 id00154/0hjW3eTGAy8/00001
id04478/2grMtwdG93I/00001 id03347/4xXZ75_TeSM/00001
id04656/1tZYt8jey54/00001 id07802/0RUpqvi3sPU/00001
id03839/1jWHvl2qCq0/00001 id06310/1IAgr_CRnuE/00001
id02057/0xZU7Oi9nvM/00001 id01228/2TIFacjgehY/00001
id00081/2xYrsnvtUWc/00001 id02057/0xZU7Oi9nvM/00001
id03862/0w8W8jp7MJk/00001 id01892/3vKPgjwFjbo/00001
id04570/0YMGn6BI9rg/00001 id06913/4Ug7aJemzpg/00001
id08392/0fwuibKviJU/00001 id01567/1Lx_ZqrK1bM/00001
id00081/2xYrsnvtUWc/00001 id07494/0P1wPmgz0Bk/00001
id04536/0f_Yi_1CoeM/00001 id00081/2xYrsnvtUWc/00001
id03839/1jWHvl2qCq0/00001 id05850/B8kp8ed48JE/00001
id07621/0CiFdFegqZM/00001 id08456/29EhSZDqzas/00001
id01822/0QcHowaLAF0/00001 id07868/5YYJq3fSbH8/00001
id05202/2gnLcAbAoSc/00001 id03178/2CT-6fnBC_o/00001
id06692/2ptBBNIZXtI/00001 id06913/4Ug7aJemzpg/00001
id01041/1UYZqPpavtk/00001 id03030/5wOxV1wAgqA/00001
id07426/1KNFfOFEhyI/00001 id08456/29EhSZDqzas/00001
id04478/2grMtwdG93I/00001 id02548/0pAkJZmlFqc/00001
id08392/0fwuibKviJU/00001 id01298/2K5F6xG-Rbs/00001
id03041/5CfnYwQCW48/00001 id08696/0H1PxInJCK0/00001
id04366/0iG2Ub9zETM/00001 id07426/1KNFfOFEhyI/00001
id04950/2n4sGPqU9M8/00001 id07494/0P1wPmgz0Bk/00001
id01822/0QcHowaLAF0/00001 id08392/0fwuibKviJU/00001
id02577/0euHS_r5JH4/00001 id06692/2ptBBNIZXtI/00001
id04570/0YMGn6BI9rg/00001 id00866/03SSllwNkGk/00001
id05850/B8kp8ed48JE/00001 id08456/29EhSZDqzas/00001
id01618/0iFlmfmWVlY/00001 id01041/1UYZqPpavtk/00001
id07414/110UMQovTR0/00001 id04536/0f_Yi_1CoeM/00001
id02057/0xZU7Oi9nvM/00001 id06913/4Ug7aJemzpg/00001
id04536/0f_Yi_1CoeM/00001 id01041/1UYZqPpavtk/00001
id04030/7mXUMuo5_NE/00001 id05850/B8kp8ed48JE/00001
id04656/1tZYt8jey54/00001 id05459/18XmQEiGLnQ/00001
id03789/0kdVSujPa9g/00001 id02057/0xZU7Oi9nvM/00001
id01041/1UYZqPpavtk/00001 id05594/0ohBiepcHWI/00001
id07494/0P1wPmgz0Bk/00001 id04006/113VkmVVz1Q/00001
id00812/1Xfgvdu7oDo/00001 id04295/1fSjOItVYVg/00001
id01541/2P7hzPq5iDw/00001 id02465/0Ocu8l1eAng/00001
id04862/0zJh2FMTaDE/00001 id05594/0ohBiepcHWI/00001
id05714/2gvpaZcvAY4/00001 id02286/4LAIxvdvguc/00001
id06209/2zM9EAPsZZQ/00001 id05816/1dyCBbJ94iw/00001
id05850/B8kp8ed48JE/00001 id00866/03SSllwNkGk/00001
id07494/0P1wPmgz0Bk/00001 id07312/0LWllHGohPY/00001
id04366/0iG2Ub9zETM/00001 id04570/0YMGn6BI9rg/00001
id00866/03SSllwNkGk/00001 id03347/4xXZ75_TeSM/00001
id02445/3Rnk8eja3TU/00001 id07802/0RUpqvi3sPU/00001
id08696/0H1PxInJCK0/00001 id06209/2zM9EAPsZZQ/00001
id02445/3Rnk8eja3TU/00001 id07621/0CiFdFegqZM/00001
id08392/0fwuibKviJU/00001 id05850/B8kp8ed48JE/00001
id00419/1zffAxBod_c/00001 id01228/2TIFacjgehY/00001
id07354/0NjekFZqaY0/00001 id01041/1UYZqPpavtk/00001
id04570/0YMGn6BI9rg/00001 id03347/4xXZ75_TeSM/00001
id01892/3vKPgjwFjbo/00001 id02445/3Rnk8eja3TU/00001
id00081/2xYrsnvtUWc/00001 id05459/18XmQEiGLnQ/00001
id06104/02L1L9RFAgI/00001 id04570/0YMGn6BI9rg/00001
id07961/3EPjXGhfst4/00001 id05654/07pANazoyJg/00001
id00926/2Nd7f1yNQzE/00001 id03839/1jWHvl2qCq0/00001
id02181/02gIO4WrZLY/00001 id08696/0H1PxInJCK0/00001
id07426/1KNFfOFEhyI/00001 id05459/18XmQEiGLnQ/00001
id03041/5CfnYwQCW48/00001 id06104/02L1L9RFAgI/00001
id01298/2K5F6xG-Rbs/00001 id01541/2P7hzPq5iDw/00001
id04570/0YMGn6BI9rg/00001 id01618/0iFlmfmWVlY/00001
id02685/4JDRxqYC0a4/00001 id02548/0pAkJZmlFqc/00001
id01822/0QcHowaLAF0/00001 id07426/1KNFfOFEhyI/00001
id07868/5YYJq3fSbH8/00001 id07494/0P1wPmgz0Bk/00001
id07802/0RUpqvi3sPU/00001 id03041/5CfnYwQCW48/00001
id04656/1tZYt8jey54/00001 id01541/2P7hzPq5iDw/00001
id03347/4xXZ75_TeSM/00001 id02445/3Rnk8eja3TU/00001
id02548/0pAkJZmlFqc/00001 id01298/2K5F6xG-Rbs/00001
id07354/0NjekFZqaY0/00001 id07426/1KNFfOFEhyI/00001
id03862/0w8W8jp7MJk/00001 id01298/2K5F6xG-Rbs/00001
id04536/0f_Yi_1CoeM/00001 id02465/0Ocu8l1eAng/00001
id00081/2xYrsnvtUWc/00001 id04366/0iG2Ub9zETM/00001
id04950/2n4sGPqU9M8/00001 id01822/0QcHowaLAF0/00001
id06692/2ptBBNIZXtI/00001 id03030/5wOxV1wAgqA/00001
id07312/0LWllHGohPY/00001 id04478/2grMtwdG93I/00001
id03862/0w8W8jp7MJk/00001 id03030/5wOxV1wAgqA/00001
id00081/2xYrsnvtUWc/00001 id08392/0fwuibKviJU/00001
id02317/0q4X8kPTlEY/00001 id00154/0hjW3eTGAy8/00001
id05594/0ohBiepcHWI/00001 id04536/0f_Yi_1CoeM/00001
id07868/5YYJq3fSbH8/00001 id03839/1jWHvl2qCq0/00001
id02577/0euHS_r5JH4/00001 id06913/4Ug7aJemzpg/00001
id08456/29EhSZDqzas/00001 id01541/2P7hzPq5iDw/00001
id01567/1Lx_ZqrK1bM/00001 id04119/1uH67UruKlE/00001
id04253/1HOlzefgLu8/00001 id01228/2TIFacjgehY/00001
id02445/3Rnk8eja3TU/00001 id02685/4JDRxqYC0a4/00001
id05015/0Cu3AvWWOFI/00001 id02465/0Ocu8l1eAng/00001
id07494/0P1wPmgz0Bk/00001 id05714/2gvpaZcvAY4/00001
id02548/0pAkJZmlFqc/00001 id04006/113VkmVVz1Q/00001
id00866/03SSllwNkGk/00001 id02317/0q4X8kPTlEY/00001
id07354/0NjekFZqaY0/00001 id04253/1HOlzefgLu8/00001
id00812/1Xfgvdu7oDo/00001 id03030/5wOxV1wAgqA/00001
id02465/0Ocu8l1eAng/00001 id07354/0NjekFZqaY0/00001
id04276/5M8NmCwTHZ0/00001 id03862/0w8W8jp7MJk/00001
id01567/1Lx_ZqrK1bM/00001 id04253/1HOlzefgLu8/00001
id01618/0iFlmfmWVlY/00001 id06913/4Ug7aJemzpg/00001
id03862/0w8W8jp7MJk/00001 id08392/0fwuibKviJU/00001
id07961/3EPjXGhfst4/00001 id00154/0hjW3eTGAy8/00001
id02577/0euHS_r5JH4/00001 id01228/2TIFacjgehY/00001
id05654/07pANazoyJg/00001 id03041/5CfnYwQCW48/00001
id03980/7MRUusImkno/00001 id08392/0fwuibKviJU/00001
id03178/2CT-6fnBC_o/00001 id04295/1fSjOItVYVg/00001
id02317/0q4X8kPTlEY/00001 id03347/4xXZ75_TeSM/00001
id02548/0pAkJZmlFqc/00001 id07426/1KNFfOFEhyI/00001
id03839/1jWHvl2qCq0/00001 id05654/07pANazoyJg/00001
id02548/0pAkJZmlFqc/00001 id07868/5YYJq3fSbH8/00001
id04570/0YMGn6BI9rg/00001 id01041/1UYZqPpavtk/00001
id07414/110UMQovTR0/00001 id00419/1zffAxBod_c/00001
id00154/0hjW3eTGAy8/00001 id01618/0iFlmfmWVlY/00001
id07494/0P1wPmgz0Bk/00001 id05654/07pANazoyJg/00001
id01822/0QcHowaLAF0/00001 id06310/1IAgr_CRnuE/00001
id05015/0Cu3AvWWOFI/00001 id05459/18XmQEiGLnQ/00001
id05816/1dyCBbJ94iw/00001 id02317/0q4X8kPTlEY/00001
id01541/2P7hzPq5iDw/00001 id05816/1dyCBbJ94iw/00001
id06104/02L1L9RFAgI/00001 id01892/3vKPgjwFjbo/00001
id04862/0zJh2FMTaDE/00001 id05850/B8kp8ed48JE/00001
id05202/2gnLcAbAoSc/00001 id04366/0iG2Ub9zETM/00001
id02286/4LAIxvdvguc/00001 id02725/37kUrf6RJdw/00001
id04276/5M8NmCwTHZ0/00001 id01541/2P7hzPq5iDw/00001
id02057/0xZU7Oi9nvM/00001 id03862/0w8W8jp7MJk/00001
id06104/02L1L9RFAgI/00001 id00419/1zffAxBod_c/00001
id04950/2n4sGPqU9M8/00001 id02181/02gIO4WrZLY/00001
id04478/2grMtwdG93I/00001 id02685/4JDRxqYC0a4/00001
id04006/113VkmVVz1Q/00001 id00081/2xYrsnvtUWc/00001
id06692/2ptBBNIZXtI/00001 id03347/4xXZ75_TeSM/00001
id03030/5wOxV1wAgqA/00001 id02465/0Ocu8l1eAng/00001
id07312/0LWllHGohPY/00001 id03839/1jWHvl2qCq0/00001
id04950/2n4sGPqU9M8/00001 id05654/07pANazoyJg/00001
id02465/0Ocu8l1eAng/00001 id01618/0iFlmfmWVlY/00001
id00419/1zffAxBod_c/00001 id02181/02gIO4WrZLY/00001
id07426/1KNFfOFEhyI/00001 id05202/2gnLcAbAoSc/00001
id07621/0CiFdFegqZM/00001 id08696/0H1PxInJCK0/00001
id04006/113VkmVVz1Q/00001 id08392/0fwuibKviJU/00001
id04478/2grMtwdG93I/00001 id02445/3Rnk8eja3TU/00001
id03347/4xXZ75_TeSM/00001 id00154/0hjW3eTGAy8/00001
id07312/0LWllHGohPY/00001 id02181/02gIO4WrZLY/00001
id06310/1IAgr_CRnuE/00001 id02057/0xZU7Oi9nvM/00001
id04366/0iG2Ub9zETM/00001 id05654/07pANazoyJg/00001
id00419/1zffAxBod_c/00001 id04570/0YMGn6BI9rg/00001
id04862/0zJh2FMTaDE/00001 id03862/0w8W8jp7MJk/00001
id04366/0iG2Ub9zETM/00001 id00154/0hjW3eTGAy8/00001
id00866/03SSllwNkGk/00001 id00081/2xYrsnvtUWc/00001
id01618/0iFlmfmWVlY/00001 id02725/37kUrf6RJdw/00001
id01892/3vKPgjwFjbo/00001 id07621/0CiFdFegqZM/00001
id05015/0Cu3AvWWOFI/00001 id00926/2Nd7f1yNQzE/00001
id06913/4Ug7aJemzpg/00001 id03839/1jWHvl2qCq0/00001
id07312/0LWllHGohPY/00001 id07802/0RUpqvi3sPU/00001
id06104/02L1L9RFAgI/00001 id02465/0Ocu8l1eAng/00001
id04295/1fSjOItVYVg/00001 id01298/2K5F6xG-Rbs/00001
id00866/03SSllwNkGk/00001 id05714/2gvpaZcvAY4/00001
id06104/02L1L9RFAgI/00001 id01541/2P7hzPq5iDw/00001
id02445/3Rnk8eja3TU/00001 id03789/0kdVSujPa9g/00001
id00081/2xYrsnvtUWc/00001 id05816/1dyCBbJ94iw/00001
id02548/0pAkJZmlFqc/00001 id03030/5wOxV1wAgqA/00001
id04276/5M8NmCwTHZ0/00001 id01041/1UYZqPpavtk/00001
id06913/4Ug7aJemzpg/00001 id07868/5YYJq3fSbH8/00001
id04656/1tZYt8jey54/00001 id06692/2ptBBNIZXtI/00001
id07494/0P1wPmgz0Bk/00001 id08696/0H1PxInJCK0/00001
id04119/1uH67UruKlE/00001 id02317/0q4X8kPTlEY/00001
id00419/1zffAxBod_c/00001 id04862/0zJh2FMTaDE/00001
id03862/0w8W8jp7MJk/00001 id02445/3Rnk8eja3TU/00001
id01892/3vKPgjwFjbo/00001 id04862/0zJh2FMTaDE/00001
id04950/2n4sGPqU9M8/00001 id01618/0iFlmfmWVlY/00001
id01228/2TIFacjgehY/00001 id01298/2K5F6xG-Rbs/00001
id01041/1UYZqPpavtk/00001 id07961/3EPjXGhfst4/00001
id07802/0RUpqvi3sPU/00001 id06913/4Ug7aJemzpg/00001
id04276/5M8NmCwTHZ0/00001 id03030/5wOxV1wAgqA/00001
id01567/1Lx_ZqrK1bM/00001 id05459/18XmQEiGLnQ/00001
id02465/0Ocu8l1eAng/00001 id02725/37kUrf6RJdw/00001
id05816/1dyCBbJ94iw/00001 id02181/02gIO4WrZLY/00001
id06913/4Ug7aJemzpg/00001 id04950/2n4sGPqU9M8/00001
id04276/5M8NmCwTHZ0/00001 id04253/1HOlzefgLu8/00001
id07414/110UMQovTR0/00001 id06209/2zM9EAPsZZQ/00001
id06310/1IAgr_CRnuE/00001 id03839/1jWHvl2qCq0/00001
id03347/4xXZ75_TeSM/00001 id04006/113VkmVVz1Q/00001
id01541/2P7hzPq5iDw/00001 id04253/1HOlzefgLu8/00001
id08456/29EhSZDqzas/00001 id07494/0P1wPmgz0Bk/00001
id07621/0CiFdFegqZM/00001 id05594/0ohBiepcHWI/00001
id02685/4JDRxqYC0a4/00001 id04536/0f_Yi_1CoeM/00001
id02317/0q4X8kPTlEY/00001 id08696/0H1PxInJCK0/00001
id04253/1HOlzefgLu8/00001 id01041/1UYZqPpavtk/00001
id01041/1UYZqPpavtk/00001 id03178/2CT-6fnBC_o/00001
id05654/07pANazoyJg/00001 id01892/3vKPgjwFjbo/00001
id04862/0zJh2FMTaDE/00001 id06310/1IAgr_CRnuE/00001
id01541/2P7hzPq5iDw/00001 id04478/2grMtwdG93I/00001
id02445/3Rnk8eja3TU/00001 id02057/0xZU7Oi9nvM/00001
id08392/0fwuibKviJU/00001 id04570/0YMGn6BI9rg/00001
id06692/2ptBBNIZXtI/00001 id02057/0xZU7Oi9nvM/00001
id04950/2n4sGPqU9M8/00001 id04862/0zJh2FMTaDE/00001
id03862/0w8W8jp7MJk/00001 id07621/0CiFdFegqZM/00001
id07312/0LWllHGohPY/00001 id04656/1tZYt8jey54/00001
id02577/0euHS_r5JH4/00001 id00866/03SSllwNkGk/00001
id01228/2TIFacjgehY/00001 id02685/4JDRxqYC0a4/00001
id00081/2xYrsnvtUWc/00001 id00419/1zffAxBod_c/00001
id00154/0hjW3eTGAy8/00001 id04656/1tZYt8jey54/00001
id03839/1jWHvl2qCq0/00001 id01618/0iFlmfmWVlY/00001
id03862/0w8W8jp7MJk/00001 id02286/4LAIxvdvguc/00001
id06310/1IAgr_CRnuE/00001 id08456/29EhSZDqzas/00001
id02317/0q4X8kPTlEY/00001 id04276/5M8NmCwTHZ0/00001
id06913/4Ug7aJemzpg/00001 id04366/0iG2Ub9zETM/00001
id06310/1IAgr_CRnuE/00001 id00926/2Nd7f1yNQzE/00001
id01228/2TIFacjgehY/00001 id02181/02gIO4WrZLY/00001
id07414/110UMQovTR0/00001 id05594/0ohBiepcHWI/00001
id03980/7MRUusImkno/00001 id03178/2CT-6fnBC_o/00001
id03347/4xXZ75_TeSM/00001 id04478/2grMtwdG93I/00001
id06692/2ptBBNIZXtI/00001 id05459/18XmQEiGLnQ/00001
id00154/0hjW3eTGAy8/00001 id02725/37kUrf6RJdw/00001
id01228/2TIFacjgehY/00001 id04006/113VkmVVz1Q/00001
id00866/03SSllwNkGk/00001 id00926/2Nd7f1yNQzE/00001
id05594/0ohBiepcHWI/00001 id04006/113VkmVVz1Q/00001
id04656/1tZYt8jey54/00001 id01822/0QcHowaLAF0/00001
id07354/0NjekFZqaY0/00001 id04536/0f_Yi_1CoeM/00001
id07354/0NjekFZqaY0/00001 id04656/1tZYt8jey54/00001
id04366/0iG2Ub9zETM/00001 id02057/0xZU7Oi9nvM/00001
id03789/0kdVSujPa9g/00001 id01822/0QcHowaLAF0/00001
id07621/0CiFdFegqZM/00001 id03347/4xXZ75_TeSM/00001
id04030/7mXUMuo5_NE/00001 id04366/0iG2Ub9zETM/00001
id00812/1Xfgvdu7oDo/00001 id07354/0NjekFZqaY0/00001
id04536/0f_Yi_1CoeM/00001 id07494/0P1wPmgz0Bk/00001
id04536/0f_Yi_1CoeM/00001 id05816/1dyCBbJ94iw/00001
id03862/0w8W8jp7MJk/00001 id07868/5YYJq3fSbH8/00001
id02685/4JDRxqYC0a4/00001 id05459/18XmQEiGLnQ/00001
id06209/2zM9EAPsZZQ/00001 id07426/1KNFfOFEhyI/00001
id07426/1KNFfOFEhyI/00001 id02317/0q4X8kPTlEY/00001
id00926/2Nd7f1yNQzE/00001 id05594/0ohBiepcHWI/00001
id00154/0hjW3eTGAy8/00001 id04950/2n4sGPqU9M8/00001
id03041/5CfnYwQCW48/00001 id01892/3vKPgjwFjbo/00001
id00419/1zffAxBod_c/00001 id00866/03SSllwNkGk/00001
id02725/37kUrf6RJdw/00001 id05202/2gnLcAbAoSc/00001
id04656/1tZYt8jey54/00001 id06913/4Ug7aJemzpg/00001
id03862/0w8W8jp7MJk/00001 id04006/113VkmVVz1Q/00001
id00419/1zffAxBod_c/00001 id04030/7mXUMuo5_NE/00001
id06692/2ptBBNIZXtI/00001 id01541/2P7hzPq5iDw/00001
id07354/0NjekFZqaY0/00001 id03041/5CfnYwQCW48/00001
id03347/4xXZ75_TeSM/00001 id07802/0RUpqvi3sPU/00001
id07354/0NjekFZqaY0/00001 id01298/2K5F6xG-Rbs/00001
id02725/37kUrf6RJdw/00001 id03980/7MRUusImkno/00001
id01618/0iFlmfmWVlY/00001 id02445/3Rnk8eja3TU/00001
id05816/1dyCBbJ94iw/00001 id00081/2xYrsnvtUWc/00001
id07354/0NjekFZqaY0/00001 id04478/2grMtwdG93I/00001
id03980/7MRUusImkno/00001 id04295/1fSjOItVYVg/00001
id02548/0pAkJZmlFqc/00001 id00081/2xYrsnvtUWc/00001
id05459/18XmQEiGLnQ/00001 id03347/4xXZ75_TeSM/00001
id04570/0YMGn6BI9rg/00001 id04006/113VkmVVz1Q/00001
id06209/2zM9EAPsZZQ/00001 id01041/1UYZqPpavtk/00001
id01228/2TIFacjgehY/00001 id02317/0q4X8kPTlEY/00001
id07802/0RUpqvi3sPU/00001 id01541/2P7hzPq5iDw/00001
id04862/0zJh2FMTaDE/00001 id01892/3vKPgjwFjbo/00001
id04253/1HOlzefgLu8/00001 id07802/0RUpqvi3sPU/00001
id06692/2ptBBNIZXtI/00001 id02286/4LAIxvdvguc/00001
id01228/2TIFacjgehY/00001 id07961/3EPjXGhfst4/00001
id05714/2gvpaZcvAY4/00001 id00812/1Xfgvdu7oDo/00001
id03789/0kdVSujPa9g/00001 id03862/0w8W8jp7MJk/00001
id04295/1fSjOItVYVg/00001 id07868/5YYJq3fSbH8/00001
id04276/5M8NmCwTHZ0/00001 id02057/0xZU7Oi9nvM/00001
id02286/4LAIxvdvguc/00001 id03862/0w8W8jp7MJk/00001
id04478/2grMtwdG93I/00001 id05816/1dyCBbJ94iw/00001
id08456/29EhSZDqzas/00001 id02725/37kUrf6RJdw/00001
id02577/0euHS_r5JH4/00001 id07961/3EPjXGhfst4/00001
id01618/0iFlmfmWVlY/00001 id00812/1Xfgvdu7oDo/00001
id07312/0LWllHGohPY/00001 id03789/0kdVSujPa9g/00001
id02685/4JDRxqYC0a4/00001 id03839/1jWHvl2qCq0/00001
id04030/7mXUMuo5_NE/00001 id07802/0RUpqvi3sPU/00001
id01567/1Lx_ZqrK1bM/00001 id04478/2grMtwdG93I/00001
id02577/0euHS_r5JH4/00001 id02548/0pAkJZmlFqc/00001
id04536/0f_Yi_1CoeM/00001 id03030/5wOxV1wAgqA/00001
id03347/4xXZ75_TeSM/00001 id00081/2xYrsnvtUWc/00001
id03980/7MRUusImkno/00001 id06209/2zM9EAPsZZQ/00001
id01567/1Lx_ZqrK1bM/00001 id00154/0hjW3eTGAy8/00001
id06104/02L1L9RFAgI/00001 id02057/0xZU7Oi9nvM/00001
id04570/0YMGn6BI9rg/00001 id03980/7MRUusImkno/00001
id08456/29EhSZDqzas/00001 id02286/4LAIxvdvguc/00001
id07312/0LWllHGohPY/00001 id04366/0iG2Ub9zETM/00001
id05654/07pANazoyJg/00001 id07426/1KNFfOFEhyI/00001
id03839/1jWHvl2qCq0/00001 id03347/4xXZ75_TeSM/00001
id04536/0f_Yi_1CoeM/00001 id04478/2grMtwdG93I/00001
id05816/1dyCBbJ94iw/00001 id04862/0zJh2FMTaDE/00001
id04950/2n4sGPqU9M8/00001 id00817/0GmSijZelGY/00001
id07426/1KNFfOFEhyI/00001 id04862/0zJh2FMTaDE/00001
id05459/18XmQEiGLnQ/00001 id00812/1Xfgvdu7oDo/00001
id00154/0hjW3eTGAy8/00001 id03178/2CT-6fnBC_o/00001
id04295/1fSjOItVYVg/00001 id07312/0LWllHGohPY/00001
id05594/0ohBiepcHWI/00001 id04862/0zJh2FMTaDE/00001
id03347/4xXZ75_TeSM/00001 id01541/2P7hzPq5iDw/00001
id04536/0f_Yi_1CoeM/00001 id02445/3Rnk8eja3TU/00001
id03862/0w8W8jp7MJk/00001 id04030/7mXUMuo5_NE/00001
id00154/0hjW3eTGAy8/00001 id01541/2P7hzPq5iDw/00001
id06913/4Ug7aJemzpg/00001 id03347/4xXZ75_TeSM/00001
id08696/0H1PxInJCK0/00001 id04478/2grMtwdG93I/00001
id04366/0iG2Ub9zETM/00001 id02445/3Rnk8eja3TU/00001
id07354/0NjekFZqaY0/00001 id01567/1Lx_ZqrK1bM/00001
id06913/4Ug7aJemzpg/00001 id05202/2gnLcAbAoSc/00001
id04862/0zJh2FMTaDE/00001 id08696/0H1PxInJCK0/00001
id03178/2CT-6fnBC_o/00001 id02685/4JDRxqYC0a4/00001
id01822/0QcHowaLAF0/00001 id04950/2n4sGPqU9M8/00001
id00081/2xYrsnvtUWc/00001 id06913/4Ug7aJemzpg/00001
id07868/5YYJq3fSbH8/00001 id02465/0Ocu8l1eAng/00001
id02181/02gIO4WrZLY/00001 id03862/0w8W8jp7MJk/00001
id07868/5YYJq3fSbH8/00001 id05202/2gnLcAbAoSc/00001
id02286/4LAIxvdvguc/00001 id03178/2CT-6fnBC_o/00001
id01298/2K5F6xG-Rbs/00001 id01618/0iFlmfmWVlY/00001
id03980/7MRUusImkno/00001 id04006/113VkmVVz1Q/00001
id03862/0w8W8jp7MJk/00001 id08456/29EhSZDqzas/00001
id01567/1Lx_ZqrK1bM/00001 id03041/5CfnYwQCW48/00001
id02465/0Ocu8l1eAng/00001 id00419/1zffAxBod_c/00001
id04570/0YMGn6BI9rg/00001 id04295/1fSjOItVYVg/00001
id03862/0w8W8jp7MJk/00001 id04295/1fSjOItVYVg/00001
id03789/0kdVSujPa9g/00001 id00866/03SSllwNkGk/00001
id05654/07pANazoyJg/00001 id00926/2Nd7f1yNQzE/00001
id05850/B8kp8ed48JE/00001 id02685/4JDRxqYC0a4/00001
id03347/4xXZ75_TeSM/00001 id08392/0fwuibKviJU/00001
id00926/2Nd7f1yNQzE/00001 id07312/0LWllHGohPY/00001
id05850/B8kp8ed48JE/00001 id01041/1UYZqPpavtk/00001
id03030/5wOxV1wAgqA/00001 id06913/4Ug7aJemzpg/00001
id02057/0xZU7Oi9nvM/00001 id01041/1UYZqPpavtk/00001
id03030/5wOxV1wAgqA/00001 id01041/1UYZqPpavtk/00001
id01618/0iFlmfmWVlY/00001 id04366/0iG2Ub9zETM/00001
id06310/1IAgr_CRnuE/00001 id04119/1uH67UruKlE/00001
id05594/0ohBiepcHWI/00001 id02317/0q4X8kPTlEY/00001
id01228/2TIFacjgehY/00001 id04119/1uH67UruKlE/00001
id02286/4LAIxvdvguc/00001 id02445/3Rnk8eja3TU/00001
id04030/7mXUMuo5_NE/00001 id00419/1zffAxBod_c/00001
id01298/2K5F6xG-Rbs/00001 id02445/3Rnk8eja3TU/00001
id07802/0RUpqvi3sPU/00001 id04862/0zJh2FMTaDE/00001
id04006/113VkmVVz1Q/00001 id03347/4xXZ75_TeSM/00001
id02317/0q4X8kPTlEY/00001 id05850/B8kp8ed48JE/00001
id08456/29EhSZDqzas/00001 id04656/1tZYt8jey54/00001
id04656/1tZYt8jey54/00001 id05816/1dyCBbJ94iw/00001
id05202/2gnLcAbAoSc/00001 id06209/2zM9EAPsZZQ/00001
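
Each line in these filelists holds two whitespace-separated VoxCeleb2-style relative paths (idXXXXX/<video-id>/<clip-index>). In the reconstruction lists both columns are identical, while in the cross lists they differ; presumably one column supplies the face video and the other the driving audio (an assumption from the naming — generate.py defines the actual roles). A minimal parsing sketch:

# Minimal sketch for reading these filelists; which column is the video
# source and which the audio source is an assumption (see generate.py).
def load_pairs(filelist_path):
    pairs = []
    with open(filelist_path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 2:
                pairs.append(tuple(parts))
    return pairs

pairs = load_pairs('dataset/filelists/voxceleb2_test_n_500_seed_797_cross.txt')
print(len(pairs), pairs[0])  # 500 ('id05459/18XmQEiGLnQ/00001', 'id07961/3EPjXGhfst4/00001')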
face_detection/README.md
ADDED
@@ -0,0 +1 @@
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
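
A hedged usage sketch of the batched interface this README describes, using the class names from face_detection/api.py. The (N, H, W, 3) uint8 BGR input layout is an assumption inferred from the channel flip in get_detections_for_batch, and the SFD weights must be available wherever sfd_detector.py expects them:

# Sketch only: batched face detection with this modified face_alignment code.
import numpy as np
import face_detection

detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device='cuda')

# frames: assumed (N, H, W, 3) uint8 BGR batch, e.g. stacked cv2.VideoCapture reads.
frames = np.zeros((8, 256, 256, 3), dtype=np.uint8)
boxes = detector.get_detections_for_batch(frames)
for box in boxes:  # each entry is (x1, y1, x2, y2) or None when no face is found
    if box is not None:
        x1, y1, x2, y2 = box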
face_detection/__init__.py
ADDED
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

__author__ = """Adrian Bulat"""
__email__ = 'adrian.bulat@nottingham.ac.uk'
__version__ = '1.0.1'

from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py
ADDED
@@ -0,0 +1,98 @@
from __future__ import print_function
import os
import torch
from torch.utils.model_zoo import load_url
from enum import Enum
import numpy as np
import cv2
try:
    import urllib.request as request_file
except BaseException:
    import urllib as request_file

from .models import FAN, ResNetDepth
from .utils import *


class LandmarksType(Enum):
    """Enum class defining the type of landmarks to detect.

    ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
    ``_2halfD`` - these points represent the projection of the 3D points into a 2D space
    ``_3D`` - detect the points ``(x,y,z)`` in a 3D space

    """
    _2D = 1
    _2halfD = 2
    _3D = 3


class NetworkSize(Enum):
    # TINY = 1
    # SMALL = 2
    # MEDIUM = 3
    LARGE = 4

    def __new__(cls, value):
        member = object.__new__(cls)
        member._value_ = value
        return member

    def __int__(self):
        return self.value

ROOT = os.path.dirname(os.path.abspath(__file__))

class FaceAlignment:
    def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
                 device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        # Get the face detector
        face_detector_module = __import__('face_detection.detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)
        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)

    def get_detections_for_batch(self, images):
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []

        for i, d in enumerate(detected_faces):
            # print("Inside face detection:", i, len(d))
            if len(d) == 0:
                results.append(None)
                continue
            d = d[0]
            d = np.clip(d, 0, None)

            x1, y1, x2, y2 = map(int, d[:-1])
            results.append((x1, y1, x2, y2))

        return results

    def get_all_detections_for_batch(self, images):
        # for multi-face detection
        images = images[..., ::-1]
        detected_faces = self.face_detector.detect_from_batch(images.copy())
        results = []

        for i, d in enumerate(detected_faces):
            # print("Inside face detection:", i, len(d))
            if len(d) == 0:
                results.append(None)
                continue
            d = [np.clip(dd, 0, None) for dd in d]
            # d = [map(int, dd[:-1]) for dd in d]
            d = [[int(ddd) for ddd in dd[:-1]] for dd in d]
            results.append(d)

        return results
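
The two batch methods above differ only in how many boxes they keep per frame: get_detections_for_batch returns a single (x1, y1, x2, y2) tuple, get_all_detections_for_batch a list of such boxes. A small sketch of the relationship; that the single-face result equals the first multi-face box follows from the code above, while the box ordering itself comes from the detector's NMS:

# Sketch: contrast of the two batch APIs. Both return one entry per frame,
# None when no face is found. Setup mirrors the earlier README sketch.
import numpy as np
import face_detection

detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, device='cuda')
frames = np.zeros((8, 256, 256, 3), dtype=np.uint8)  # assumed BGR uint8 batch

single = detector.get_detections_for_batch(frames)     # [(x1, y1, x2, y2) | None, ...]
multi = detector.get_all_detections_for_batch(frames)  # [[[x1, y1, x2, y2], ...] | None, ...]

for one, many in zip(single, multi):
    if one is not None:
        # The single-face API keeps exactly the first box of the multi-face list.
        assert list(one) == many[0]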
face_detection/detection/__init__.py
ADDED
@@ -0,0 +1 @@
from .core import FaceDetector
face_detection/detection/core.py
ADDED
@@ -0,0 +1,130 @@
import logging
import glob
from tqdm import tqdm
import numpy as np
import torch
import cv2


class FaceDetector(object):
    """An abstract class representing a face detector.

    Any other face detection implementation must subclass it. All subclasses
    must implement ``detect_from_image``, which returns a list of detected
    bounding boxes. Optionally, for speed, implementing detection directly
    from a path is recommended.
    """

    def __init__(self, device, verbose):
        self.device = device
        self.verbose = verbose
        # Create the logger up front so the invalid-device branch below can
        # use it too (it was previously only defined for the CPU warning).
        logger = logging.getLogger(__name__)

        if verbose and 'cpu' in device:
            logger.warning("Detection running on CPU, this may be potentially slow.")

        if 'cpu' not in device and 'cuda' not in device:
            if verbose:
                logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
            raise ValueError

    def detect_from_image(self, tensor_or_path):
        """Detects faces in a given image.

        This function detects the faces present in a provided BGR (usually)
        image. The input can be either the image itself or the path to it.

        Arguments:
            tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
            to an image or the image itself.

        Example::

            >>> path_to_image = 'data/image_01.jpg'
            ... detected_faces = detect_from_image(path_to_image)
            [A list of bounding boxes (x1, y1, x2, y2)]
            >>> image = cv2.imread(path_to_image)
            ... detected_faces = detect_from_image(image)
            [A list of bounding boxes (x1, y1, x2, y2)]

        """
        raise NotImplementedError

    def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
        """Detects faces from all the images present in a given directory.

        Arguments:
            path {string} -- a string containing a path that points to the folder containing the images

        Keyword Arguments:
            extensions {list} -- list of strings containing the extensions to be
            considered in the following format: ``.extension_name`` (default:
            {['.jpg', '.png']}) recursive {bool} -- whether to scan the
            folder recursively (default: {False}) show_progress_bar {bool} --
            display a progress bar (default: {True})

        Example:
            >>> directory = 'data'
            ... detected_faces = detect_from_directory(directory)
            {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}

        """
        if self.verbose:
            logger = logging.getLogger(__name__)

        if len(extensions) == 0:
            if self.verbose:
                logger.error("Expected at least one extension, but none was received.")
            raise ValueError

        if self.verbose:
            logger.info("Constructing the list of images.")
        additional_pattern = '/**/*' if recursive else '/*'
        files = []
        for extension in extensions:
            files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))

        if self.verbose:
            logger.info("Finished searching for images. %s images found", len(files))
            logger.info("Preparing to run the detection.")

        predictions = {}
        for image_path in tqdm(files, disable=not show_progress_bar):
            if self.verbose:
                logger.info("Running the face detector on image: %s", image_path)
            predictions[image_path] = self.detect_from_image(image_path)

        if self.verbose:
            logger.info("The detector was successfully run on all %s images", len(files))

        return predictions

    @property
    def reference_scale(self):
        raise NotImplementedError

    @property
    def reference_x_shift(self):
        raise NotImplementedError

    @property
    def reference_y_shift(self):
        raise NotImplementedError

    @staticmethod
    def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
        """Convert a path (represented as a string) or torch.tensor to a numpy.ndarray

        Arguments:
            tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
        """
        if isinstance(tensor_or_path, str):
            return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
        elif torch.is_tensor(tensor_or_path):
            # Call cpu() in case the tensor is coming from CUDA
            return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
        elif isinstance(tensor_or_path, np.ndarray):
            return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
        else:
            raise TypeError
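
Since FaceDetector is abstract, a concrete detector only needs detect_from_image (plus detect_from_batch, which api.py calls on it). A toy subclass sketch with a fabricated constant box in place of a real network, just to show the required surface and the (x1, y1, x2, y2, score) box layout:

# Toy sketch: minimal FaceDetector subclass. The fixed centered box is fake;
# a real subclass (see sfd_detector.py below) runs a detection network here.
import numpy as np
from face_detection.detection.core import FaceDetector

class DummyDetector(FaceDetector):
    def detect_from_image(self, tensor_or_path):
        image = self.tensor_or_path_to_ndarray(tensor_or_path)
        h, w = image.shape[:2]
        # One centered box in the (x1, y1, x2, y2, score) layout used downstream.
        return [np.array([0.25 * w, 0.25 * h, 0.75 * w, 0.75 * h, 1.0])]

    def detect_from_batch(self, images):
        # api.py feeds an (N, H, W, 3) array; iterate per frame.
        return [self.detect_from_image(img) for img in images]

det = DummyDetector(device='cpu', verbose=False)
print(det.detect_from_image(np.zeros((64, 64, 3), dtype=np.uint8)))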
face_detection/detection/sfd/__init__.py
ADDED
@@ -0,0 +1 @@
from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py
ADDED
@@ -0,0 +1,129 @@
from __future__ import print_function
import os
import sys
import cv2
import random
import datetime
import time
import math
import argparse
import numpy as np
import torch

try:
    from iou import IOU
except BaseException:
    # IOU cython speedup 10x
    def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
        sa = abs((ax2 - ax1) * (ay2 - ay1))
        sb = abs((bx2 - bx1) * (by2 - by1))
        x1, y1 = max(ax1, bx1), max(ay1, by1)
        x2, y2 = min(ax2, bx2), min(ay2, by2)
        w = x2 - x1
        h = y2 - y1
        if w < 0 or h < 0:
            return 0.0
        else:
            return 1.0 * w * h / (sa + sb - w * h)


def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
    xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
    dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
    dw, dh = math.log(ww / aww), math.log(hh / ahh)
    return dx, dy, dw, dh


def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
    xc, yc = dx * aww + axc, dy * ahh + ayc
    ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
    x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
    return x1, y1, x2, y2


def nms(dets, thresh):
    if 0 == len(dets):
        return []
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
        xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])

        w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
        ovr = w * h / (areas[i] + areas[order[1:]] - w * h)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep


def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """

    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors, 4]


def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def batch_decode(loc, priors, variances):
    """Batched variant of ``decode``: undoes the offset-regression encoding
    for a whole batch at once.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [batch, num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [batch, num_priors, 4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
        priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
    boxes[:, :, 2:] += boxes[:, :, :2]
    return boxes
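
A worked round trip through encode/decode makes the variance convention concrete: with the [0.1, 0.2] variances that detect.py uses, encoding a corner-form box against a center-offset prior and decoding it back recovers the original corners. All numbers below are invented for illustration:

# Sketch: encode/decode round trip with the variances used in detect.py.
import torch
from face_detection.detection.sfd.bbox import encode, decode

variances = [0.1, 0.2]
matched = torch.tensor([[30.0, 40.0, 90.0, 120.0]])  # ground truth, corner form (x1, y1, x2, y2)
priors = torch.tensor([[64.0, 64.0, 64.0, 64.0]])    # prior, center-offset form (cx, cy, w, h)

offsets = encode(matched, priors, variances)
recovered = decode(offsets, priors, variances)
print(torch.allclose(recovered, matched, atol=1e-4))  # True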
face_detection/detection/sfd/detect.py
ADDED
@@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F

import os
import sys
import cv2
import random
import datetime
import math
import argparse
import numpy as np

import scipy.io as sio
import zipfile
from .net_s3fd import s3fd
from .bbox import *


def detect(net, img, device):
    img = img - np.array([104, 117, 123])
    img = img.transpose(2, 0, 1)
    img = img.reshape((1,) + img.shape)

    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True

    img = torch.from_numpy(img).float().to(device)
    BB, CC, HH, WW = img.size()
    with torch.no_grad():
        olist = net(img)

    bboxlist = []
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)
    olist = [oelem.data.cpu() for oelem in olist]
    for i in range(len(olist) // 2):
        ocls, oreg = olist[i * 2], olist[i * 2 + 1]
        FB, FC, FH, FW = ocls.size()  # feature map size
        stride = 2**(i + 2)  # 4,8,16,32,64,128
        anchor = stride * 4
        poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
        for Iindex, hindex, windex in poss:
            axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
            score = ocls[0, 1, hindex, windex]
            loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
            priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
            variances = [0.1, 0.2]
            box = decode(loc, priors, variances)
            x1, y1, x2, y2 = box[0] * 1.0
            # cv2.rectangle(imgshow, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
            bboxlist.append([x1, y1, x2, y2, score])
    bboxlist = np.array(bboxlist)
    if 0 == len(bboxlist):
        bboxlist = np.zeros((1, 5))

    return bboxlist


def batch_detect(net, imgs, device):
    imgs = imgs - np.array([104, 117, 123])
    imgs = imgs.transpose(0, 3, 1, 2)

    if 'cuda' in device:
        torch.backends.cudnn.benchmark = True

    imgs = torch.from_numpy(imgs).float().to(device)
    BB, CC, HH, WW = imgs.size()
    with torch.no_grad():
        olist = net(imgs)

    bboxlist = []
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)
    olist = [oelem.data.cpu() for oelem in olist]
    for i in range(len(olist) // 2):
        ocls, oreg = olist[i * 2], olist[i * 2 + 1]
        FB, FC, FH, FW = ocls.size()  # feature map size
        stride = 2**(i + 2)  # 4,8,16,32,64,128
        anchor = stride * 4
        poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
        for Iindex, hindex, windex in poss:
            axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
            score = ocls[:, 1, hindex, windex]
            loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
            priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
            variances = [0.1, 0.2]
            box = batch_decode(loc, priors, variances)
            box = box[:, 0] * 1.0
            # cv2.rectangle(imgshow, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
            bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
    bboxlist = np.array(bboxlist)
    if 0 == len(bboxlist):
        bboxlist = np.zeros((1, BB, 5))

    return bboxlist


def flip_detect(net, img, device):
    img = cv2.flip(img, 1)
    b = detect(net, img, device)

    bboxlist = np.zeros(b.shape)
    bboxlist[:, 0] = img.shape[1] - b[:, 2]
    bboxlist[:, 1] = b[:, 1]
    bboxlist[:, 2] = img.shape[1] - b[:, 0]
    bboxlist[:, 3] = b[:, 3]
    bboxlist[:, 4] = b[:, 4]
    return bboxlist


def pts_to_bb(pts):
    min_x, min_y = np.min(pts, axis=0)
    max_x, max_y = np.max(pts, axis=0)
    return np.array([min_x, min_y, max_x, max_y])
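
Both detect and batch_detect start by subtracting the fixed per-channel means [104, 117, 123]; these look like the usual Caffe-era BGR means, which suggests the network expects BGR input in the 0-255 range (an inference from the constants, not something the file documents). A sketch of that preprocessing in isolation:

# Sketch: the preprocessing that batch_detect applies before inference.
import numpy as np

frames = np.random.randint(0, 256, (4, 128, 128, 3)).astype(np.float64)  # NHWC, BGR assumed
x = frames - np.array([104, 117, 123])  # per-channel mean subtraction (BGR order assumed)
x = x.transpose(0, 3, 1, 2)             # NHWC -> NCHW, as the network expects
print(x.shape)                           # (4, 3, 128, 128)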
face_detection/detection/sfd/net_s3fd.py
ADDED
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class L2Norm(nn.Module):
+    def __init__(self, n_channels, scale=1.0):
+        super(L2Norm, self).__init__()
+        self.n_channels = n_channels
+        self.scale = scale
+        self.eps = 1e-10
+        self.weight = nn.Parameter(torch.Tensor(self.n_channels))
+        self.weight.data *= 0.0
+        self.weight.data += self.scale
+
+    def forward(self, x):
+        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
+        x = x / norm * self.weight.view(1, -1, 1, 1)
+        return x
+
+
+class s3fd(nn.Module):
+    def __init__(self):
+        super(s3fd, self).__init__()
+        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+
+        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
+
+        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+
+        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+        self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
+        self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
+
+        self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
+        self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
+
+        self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
+        self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+        self.conv3_3_norm = L2Norm(256, scale=10)
+        self.conv4_3_norm = L2Norm(512, scale=8)
+        self.conv5_3_norm = L2Norm(512, scale=5)
+
+        self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+        self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+        self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+        self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+        self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+        self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+
+        self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
+        self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
+        self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+        self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+        self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
+        self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        h = F.relu(self.conv1_1(x))
+        h = F.relu(self.conv1_2(h))
+        h = F.max_pool2d(h, 2, 2)
+
+        h = F.relu(self.conv2_1(h))
+        h = F.relu(self.conv2_2(h))
+        h = F.max_pool2d(h, 2, 2)
+
+        h = F.relu(self.conv3_1(h))
+        h = F.relu(self.conv3_2(h))
+        h = F.relu(self.conv3_3(h))
+        f3_3 = h
+        h = F.max_pool2d(h, 2, 2)
+
+        h = F.relu(self.conv4_1(h))
+        h = F.relu(self.conv4_2(h))
+        h = F.relu(self.conv4_3(h))
+        f4_3 = h
+        h = F.max_pool2d(h, 2, 2)
+
+        h = F.relu(self.conv5_1(h))
+        h = F.relu(self.conv5_2(h))
+        h = F.relu(self.conv5_3(h))
+        f5_3 = h
+        h = F.max_pool2d(h, 2, 2)
+
+        h = F.relu(self.fc6(h))
+        h = F.relu(self.fc7(h))
+        ffc7 = h
+        h = F.relu(self.conv6_1(h))
+        h = F.relu(self.conv6_2(h))
+        f6_2 = h
+        h = F.relu(self.conv7_1(h))
+        h = F.relu(self.conv7_2(h))
+        f7_2 = h
+
+        f3_3 = self.conv3_3_norm(f3_3)
+        f4_3 = self.conv4_3_norm(f4_3)
+        f5_3 = self.conv5_3_norm(f5_3)
+
+        cls1 = self.conv3_3_norm_mbox_conf(f3_3)
+        reg1 = self.conv3_3_norm_mbox_loc(f3_3)
+        cls2 = self.conv4_3_norm_mbox_conf(f4_3)
+        reg2 = self.conv4_3_norm_mbox_loc(f4_3)
+        cls3 = self.conv5_3_norm_mbox_conf(f5_3)
+        reg3 = self.conv5_3_norm_mbox_loc(f5_3)
+        cls4 = self.fc7_mbox_conf(ffc7)
+        reg4 = self.fc7_mbox_loc(ffc7)
+        cls5 = self.conv6_2_mbox_conf(f6_2)
+        reg5 = self.conv6_2_mbox_loc(f6_2)
+        cls6 = self.conv7_2_mbox_conf(f7_2)
+        reg6 = self.conv7_2_mbox_loc(f7_2)
+
+        # max-out background label
+        chunk = torch.chunk(cls1, 4, 1)
+        bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
+        cls1 = torch.cat([bmax, chunk[3]], dim=1)
+
+        return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
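The "max-out background label" step at the end of s3fd.forward predicts four confidence channels on the conv3_3 branch (three background scores and one face score) and keeps only the strongest background score, the S3FD trick for suppressing false positives at the smallest anchor scale. A minimal sketch of just that channel arithmetic on a dummy tensor:

import torch

cls1 = torch.randn(1, 4, 8, 8)                           # 3 background channels + 1 face channel
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])  # max over the background channels
out = torch.cat([bmax, chunk[3]], dim=1)
print(out.shape)                                          # torch.Size([1, 2, 8, 8])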
face_detection/detection/sfd/sfd_detector.py
ADDED
@@ -0,0 +1,59 @@
+import os
+import cv2
+from torch.utils.model_zoo import load_url
+
+from ..core import FaceDetector
+
+from .net_s3fd import s3fd
+from .bbox import *
+from .detect import *
+
+models_urls = {
+    's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
+}
+
+
+class SFDDetector(FaceDetector):
+    def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
+        super(SFDDetector, self).__init__(device, verbose)
+
+        # Initialise the face detector
+        if not os.path.isfile(path_to_detector):
+            model_weights = load_url(models_urls['s3fd'])
+        else:
+            model_weights = torch.load(path_to_detector)
+
+        self.face_detector = s3fd()
+        self.face_detector.load_state_dict(model_weights)
+        self.face_detector.to(device)
+        self.face_detector.eval()
+
+    def detect_from_image(self, tensor_or_path):
+        image = self.tensor_or_path_to_ndarray(tensor_or_path)
+
+        bboxlist = detect(self.face_detector, image, device=self.device)
+        keep = nms(bboxlist, 0.3)
+        bboxlist = bboxlist[keep, :]
+        bboxlist = [x for x in bboxlist if x[-1] > 0.5]
+
+        return bboxlist
+
+    def detect_from_batch(self, images):
+        bboxlists = batch_detect(self.face_detector, images, device=self.device)
+        keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
+        bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
+        bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
+
+        return bboxlists
+
+    @property
+    def reference_scale(self):
+        return 195
+
+    @property
+    def reference_x_shift(self):
+        return 0
+
+    @property
+    def reference_y_shift(self):
+        return 0
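SFDDetector wraps the s3fd network behind the FaceDetector interface from ..core, with NMS at IoU 0.3 and a 0.5 score cutoff baked in. A minimal usage sketch, assuming the face_detection package is importable and the s3fd.pth weights are either present locally or fetched from models_urls on first use (the image path is hypothetical):

import cv2
import torch
from face_detection.detection.sfd.sfd_detector import SFDDetector

device = 'cuda' if torch.cuda.is_available() else 'cpu'
detector = SFDDetector(device)
image = cv2.imread('some_frame.jpg')       # hypothetical input frame
boxes = detector.detect_from_image(image)  # list of [x1, y1, x2, y2, score] with score > 0.5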
face_detection/models.py
ADDED
@@ -0,0 +1,261 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
+    "3x3 convolution with padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                     stride=strd, padding=padding, bias=bias)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_planes, out_planes):
+        super(ConvBlock, self).__init__()
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = conv3x3(in_planes, int(out_planes / 2))
+        self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
+        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
+        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
+        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
+
+        if in_planes != out_planes:
+            self.downsample = nn.Sequential(
+                nn.BatchNorm2d(in_planes),
+                nn.ReLU(True),
+                nn.Conv2d(in_planes, out_planes,
+                          kernel_size=1, stride=1, bias=False),
+            )
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        residual = x
+
+        out1 = self.bn1(x)
+        out1 = F.relu(out1, True)
+        out1 = self.conv1(out1)
+
+        out2 = self.bn2(out1)
+        out2 = F.relu(out2, True)
+        out2 = self.conv2(out2)
+
+        out3 = self.bn3(out2)
+        out3 = F.relu(out3, True)
+        out3 = self.conv3(out3)
+
+        out3 = torch.cat((out1, out2, out3), 1)
+
+        if self.downsample is not None:
+            residual = self.downsample(residual)
+
+        out3 += residual
+
+        return out3
+
+
+class Bottleneck(nn.Module):
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class HourGlass(nn.Module):
+    def __init__(self, num_modules, depth, num_features):
+        super(HourGlass, self).__init__()
+        self.num_modules = num_modules
+        self.depth = depth
+        self.features = num_features
+
+        self._generate_network(self.depth)
+
+    def _generate_network(self, level):
+        self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
+
+        self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
+
+        if level > 1:
+            self._generate_network(level - 1)
+        else:
+            self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
+
+        self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
+
+    def _forward(self, level, inp):
+        # Upper branch
+        up1 = inp
+        up1 = self._modules['b1_' + str(level)](up1)
+
+        # Lower branch
+        low1 = F.avg_pool2d(inp, 2, stride=2)
+        low1 = self._modules['b2_' + str(level)](low1)
+
+        if level > 1:
+            low2 = self._forward(level - 1, low1)
+        else:
+            low2 = low1
+            low2 = self._modules['b2_plus_' + str(level)](low2)
+
+        low3 = low2
+        low3 = self._modules['b3_' + str(level)](low3)
+
+        up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
+
+        return up1 + up2
+
+    def forward(self, x):
+        return self._forward(self.depth, x)
+
+
+class FAN(nn.Module):
+
+    def __init__(self, num_modules=1):
+        super(FAN, self).__init__()
+        self.num_modules = num_modules
+
+        # Base part
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.conv2 = ConvBlock(64, 128)
+        self.conv3 = ConvBlock(128, 128)
+        self.conv4 = ConvBlock(128, 256)
+
+        # Stacking part
+        for hg_module in range(self.num_modules):
+            self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
+            self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
+            self.add_module('conv_last' + str(hg_module),
+                            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+            self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
+            self.add_module('l' + str(hg_module), nn.Conv2d(256,
+                                                            68, kernel_size=1, stride=1, padding=0))
+
+            if hg_module < self.num_modules - 1:
+                self.add_module(
+                    'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+                self.add_module('al' + str(hg_module), nn.Conv2d(68,
+                                                                 256, kernel_size=1, stride=1, padding=0))
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)), True)
+        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
+        x = self.conv3(x)
+        x = self.conv4(x)
+
+        previous = x
+
+        outputs = []
+        for i in range(self.num_modules):
+            hg = self._modules['m' + str(i)](previous)
+
+            ll = hg
+            ll = self._modules['top_m_' + str(i)](ll)
+
+            ll = F.relu(self._modules['bn_end' + str(i)]
+                        (self._modules['conv_last' + str(i)](ll)), True)
+
+            # Predict heatmaps
+            tmp_out = self._modules['l' + str(i)](ll)
+            outputs.append(tmp_out)
+
+            if i < self.num_modules - 1:
+                ll = self._modules['bl' + str(i)](ll)
+                tmp_out_ = self._modules['al' + str(i)](tmp_out)
+                previous = previous + ll + tmp_out_
+
+        return outputs
+
+
+class ResNetDepth(nn.Module):
+
+    def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
+        self.inplanes = 64
+        super(ResNetDepth, self).__init__()
+        self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AvgPool2d(7)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
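With a single stacked module, FAN turns a 256x256 face crop into a list containing one 68-channel heatmap tensor at quarter resolution: the stride-2 stem halves 256 to 128 and the average pool halves it again to 64, which the hourglass preserves. A minimal shape-check sketch, assuming models.py is importable as face_detection.models:

import torch
from face_detection.models import FAN

net = FAN(num_modules=1).eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
print(len(out), out[0].shape)  # 1 torch.Size([1, 68, 64, 64])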
face_detection/utils.py
ADDED
@@ -0,0 +1,313 @@
+from __future__ import print_function
+import os
+import sys
+import time
+import torch
+import math
+import numpy as np
+import cv2
+
+
+def _gaussian(
+        size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
+        height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
+        mean_vert=0.5):
+    # handle some defaults
+    if width is None:
+        width = size
+    if height is None:
+        height = size
+    if sigma_horz is None:
+        sigma_horz = sigma
+    if sigma_vert is None:
+        sigma_vert = sigma
+    center_x = mean_horz * width + 0.5
+    center_y = mean_vert * height + 0.5
+    gauss = np.empty((height, width), dtype=np.float32)
+    # generate kernel
+    for i in range(height):
+        for j in range(width):
+            gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
+                sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
+    if normalize:
+        gauss = gauss / np.sum(gauss)
+    return gauss
+
+
+def draw_gaussian(image, point, sigma):
+    # Check if the gaussian is inside
+    ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
+    br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
+    if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
+        return image
+    size = 6 * sigma + 1
+    g = _gaussian(size)
+    g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
+    g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
+    img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
+    img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
+    assert (g_x[0] > 0 and g_y[1] > 0)
+    image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
+          ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
+    image[image > 1] = 1
+    return image
+
+
+def transform(point, center, scale, resolution, invert=False):
+    """Generate an affine transformation matrix.
+
+    Given a set of points, a center, a scale and a target resolution, the
+    function generates an affine transformation matrix. If invert is ``True``
+    it will produce the inverse transformation.
+
+    Arguments:
+        point {torch.tensor} -- the input 2D point
+        center {torch.tensor or numpy.array} -- the center around which to perform the transformations
+        scale {float} -- the scale of the face/object
+        resolution {float} -- the output resolution
+
+    Keyword Arguments:
+        invert {bool} -- define whether the function should produce the direct or the
+        inverse transformation matrix (default: {False})
+    """
+    _pt = torch.ones(3)
+    _pt[0] = point[0]
+    _pt[1] = point[1]
+
+    h = 200.0 * scale
+    t = torch.eye(3)
+    t[0, 0] = resolution / h
+    t[1, 1] = resolution / h
+    t[0, 2] = resolution * (-center[0] / h + 0.5)
+    t[1, 2] = resolution * (-center[1] / h + 0.5)
+
+    if invert:
+        t = torch.inverse(t)
+
+    new_point = (torch.matmul(t, _pt))[0:2]
+
+    return new_point.int()
+
+
+def crop(image, center, scale, resolution=256.0):
+    """Crop an image or set of heatmaps around the center point.
+    Input is expected to be an np.ndarray.
+
+    Arguments:
+        image {numpy.array} -- an rgb image
+        center {numpy.array} -- the center of the object, usually the same as of the bounding box
+        scale {float} -- scale of the face
+
+    Keyword Arguments:
+        resolution {float} -- the size of the output cropped image (default: {256.0})
+
+    Returns:
+        numpy.array -- the cropped and resized image
+    """
+    ul = transform([1, 1], center, scale, resolution, True)
+    br = transform([resolution, resolution], center, scale, resolution, True)
+    # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
+    if image.ndim > 2:
+        newDim = np.array([br[1] - ul[1], br[0] - ul[0],
+                           image.shape[2]], dtype=np.int32)
+        newImg = np.zeros(newDim, dtype=np.uint8)
+    else:
+        newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
+        newImg = np.zeros(newDim, dtype=np.uint8)
+    ht = image.shape[0]
+    wd = image.shape[1]
+    newX = np.array(
+        [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
+    newY = np.array(
+        [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
+    oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
+    oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
+    newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
+           ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
+    newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
+                        interpolation=cv2.INTER_LINEAR)
+    return newImg
+
+
+def get_preds_fromhm(hm, center=None, scale=None):
+    """Obtain (x,y) coordinates given a set of N heatmaps. If the center
+    and the scale are provided, the function will return the points also in
+    the original coordinate frame.
+
+    Arguments:
+        hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+    Keyword Arguments:
+        center {torch.tensor} -- the center of the bounding box (default: {None})
+        scale {float} -- face scale (default: {None})
+    """
+    max, idx = torch.max(
+        hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+    idx += 1
+    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+    preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+    preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
+
+    for i in range(preds.size(0)):
+        for j in range(preds.size(1)):
+            hm_ = hm[i, j, :]
+            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+                diff = torch.FloatTensor(
+                    [hm_[pY, pX + 1] - hm_[pY, pX - 1],
+                     hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+                preds[i, j].add_(diff.sign_().mul_(.25))
+
+    preds.add_(-.5)
+
+    preds_orig = torch.zeros(preds.size())
+    if center is not None and scale is not None:
+        for i in range(hm.size(0)):
+            for j in range(hm.size(1)):
+                preds_orig[i, j] = transform(
+                    preds[i, j], center, scale, hm.size(2), True)
+
+    return preds, preds_orig
+
+def get_preds_fromhm_batch(hm, centers=None, scales=None):
+    """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
+    and the scales are provided, the function will return the points also in
+    the original coordinate frame.
+
+    Arguments:
+        hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+    Keyword Arguments:
+        centers {torch.tensor} -- the centers of the bounding box (default: {None})
+        scales {float} -- face scales (default: {None})
+    """
+    max, idx = torch.max(
+        hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+    idx += 1
+    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+    preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+    preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
+
+    for i in range(preds.size(0)):
+        for j in range(preds.size(1)):
+            hm_ = hm[i, j, :]
+            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+                diff = torch.FloatTensor(
+                    [hm_[pY, pX + 1] - hm_[pY, pX - 1],
+                     hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+                preds[i, j].add_(diff.sign_().mul_(.25))
+
+    preds.add_(-.5)
+
+    preds_orig = torch.zeros(preds.size())
+    if centers is not None and scales is not None:
+        for i in range(hm.size(0)):
+            for j in range(hm.size(1)):
+                preds_orig[i, j] = transform(
+                    preds[i, j], centers[i], scales[i], hm.size(2), True)
+
+    return preds, preds_orig
+
+def shuffle_lr(parts, pairs=None):
+    """Shuffle the points left-right according to the axis of symmetry
+    of the object.
+
+    Arguments:
+        parts {torch.tensor} -- a 3D or 4D object containing the
+        heatmaps.
+
+    Keyword Arguments:
+        pairs {list of integers} -- [order of the flipped points] (default: {None})
+    """
+    if pairs is None:
+        pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+                 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
+                 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
+                 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
+                 62, 61, 60, 67, 66, 65]
+    if parts.ndimension() == 3:
+        parts = parts[pairs, ...]
+    else:
+        parts = parts[:, pairs, ...]
+
+    return parts
+
+
+def flip(tensor, is_label=False):
+    """Flip an image or a set of heatmaps left-right
+
+    Arguments:
+        tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
+
+    Keyword Arguments:
+        is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
+    """
+    if not torch.is_tensor(tensor):
+        tensor = torch.from_numpy(tensor)
+
+    if is_label:
+        tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
+    else:
+        tensor = tensor.flip(tensor.ndimension() - 1)
+
+    return tensor
+
+# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
+
+
+def appdata_dir(appname=None, roaming=False):
+    """ appdata_dir(appname=None, roaming=False)
+
+    Get the path to the application directory, where applications are allowed
+    to write user specific files (e.g. configurations). For non-user specific
+    data, consider using common_appdata_dir().
+    If appname is given, a subdir is appended (and created if necessary).
+    If roaming is True, will prefer a roaming directory (Windows Vista/7).
+    """
+
+    # Define default user directory
+    userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
+    if userDir is None:
+        userDir = os.path.expanduser('~')
+        if not os.path.isdir(userDir):  # pragma: no cover
+            userDir = '/var/tmp'  # issue #54
+
+    # Get system app data dir
+    path = None
+    if sys.platform.startswith('win'):
+        path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
+        path = (path2 or path1) if roaming else (path1 or path2)
+    elif sys.platform.startswith('darwin'):
+        path = os.path.join(userDir, 'Library', 'Application Support')
+    # On Linux and as fallback
+    if not (path and os.path.isdir(path)):
+        path = userDir
+
+    # Maybe we should store things local to the executable (in case of a
+    # portable distro or a frozen application that wants to be portable)
+    prefix = sys.prefix
+    if getattr(sys, 'frozen', None):
+        prefix = os.path.abspath(os.path.dirname(sys.executable))
+    for reldir in ('settings', '../settings'):
+        localpath = os.path.abspath(os.path.join(prefix, reldir))
+        if os.path.isdir(localpath):  # pragma: no cover
+            try:
+                open(os.path.join(localpath, 'test.write'), 'wb').close()
+                os.remove(os.path.join(localpath, 'test.write'))
+            except IOError:
+                pass  # We cannot write in this directory
+            else:
+                path = localpath
+                break
+
+    # Get path specific for this app
+    if appname:
+        if path == userDir:
+            appname = '.' + appname.lstrip('.')  # Make it a hidden directory
+        path = os.path.join(path, appname)
+        if not os.path.isdir(path):  # pragma: no cover
+            os.mkdir(path)
+
+    # Done
+    return path
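get_preds_fromhm recovers 1-indexed (x, y) peaks from each heatmap, nudges them a quarter pixel along the local gradient sign, and shifts everything by 0.5. A minimal sketch with a synthetic 64x64 heatmap whose peak sits at column 20, row 30:

import torch
from face_detection.utils import get_preds_fromhm

hm = torch.zeros(1, 1, 64, 64)
hm[0, 0, 30, 20] = 1.0           # row index is y, column index is x
preds, _ = get_preds_fromhm(hm)
print(preds[0, 0])               # tensor([20.5000, 30.5000])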
generate.py
ADDED
@@ -0,0 +1,398 @@
+''' consistent initial noise for video generation'''
+import cv2
+import os
+from os.path import join, basename, dirname, splitext
+import shutil
+import argparse
+import numpy as np
+import random
+import torch, torchvision
+import subprocess
+from audio import audio
+import face_detection
+from tqdm import tqdm
+
+from guided_diffusion import dist_util, logger
+from guided_diffusion.resample import create_named_schedule_sampler
+from guided_diffusion.script_util import (
+    tfg_model_and_diffusion_defaults,
+    tfg_create_model_and_diffusion,
+    args_to_dict,
+    add_dict_to_argparser,
+)
+
+from guided_diffusion.tfg_data_util import (
+    tfg_process_batch,
+)
+
+def get_frame_id(frame):
+    return int(basename(frame).split('.')[0])
+
+def crop_audio_window(spec, start_frame, args):
+    if type(start_frame) == int:
+        start_frame_num = start_frame
+    else:
+        start_frame_num = get_frame_id(start_frame)
+    start_idx = int(args.mel_steps_per_sec * (start_frame_num / float(args.video_fps)))
+    end_idx = start_idx + args.syncnet_mel_step_size
+    return spec[start_idx : end_idx, :]
+
+def load_all_indiv_mels(path, args):
+    in_path = path
+    out_dir = join(args.sample_path, "temp", basename(in_path).replace(".mp4", ""))
+    os.makedirs(out_dir, exist_ok=True)
+    out_path = join(out_dir, "audio.wav")
+    command2 = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(in_path, out_path)
+    subprocess.call(command2, shell=True)
+    wav = audio.load_wav(out_path, args.sample_rate)
+    orig_mel = audio.melspectrogram(wav).T
+
+    all_indiv_mels = []
+    # i=0
+    i = 1
+    while True:
+        m = crop_audio_window(orig_mel.copy(), max(i - args.syncnet_T//2, 0), args)
+        if (m.shape[0] != args.syncnet_mel_step_size):
+            break
+        all_indiv_mels.append(m.T)
+        i += 1
+
+    # clean up
+    shutil.rmtree(join(args.sample_path, "temp"))
+
+    return all_indiv_mels, wav
+
+def load_video_frames(path, args):
+    in_path = path
+    out_dir = join(args.sample_path, "temp", basename(in_path).replace(".mp4", ""), "image")
+    os.makedirs(out_dir, exist_ok=True)
+
+    command = "ffmpeg -loglevel error -y -i {} -vf fps={} -q:v 2 -qmin 1 {}/%05d.jpg".format(in_path, args.video_fps, out_dir)
+    subprocess.call(command, shell=True)
+
+    video_frames = []
+    for i, img_name in enumerate(sorted(os.listdir(out_dir))):
+        img_path = join(out_dir, img_name)
+        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
+        video_frames.append(img)
+
+    # clean up
+    shutil.rmtree(join(args.sample_path, "temp"))
+
+    return video_frames
+
+
+def get_smoothened_boxes(boxes, T):
+    for i in range(len(boxes)):
+        if i + T > len(boxes):
+            window = boxes[len(boxes) - T:]
+        else:
+            window = boxes[i : i + T]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+def my_voxceleb2_crop(img):
+    return img[:-int(img.shape[0]*2.36/8), int(img.shape[1]*1.8/8): -int(img.shape[1]*1.8/8)]
+
+def my_voxceleb2_crop_bboxs(img):
+    return 0, img.shape[0]-int(img.shape[0]*2.36/8), int(img.shape[1]*1.8/8), img.shape[1]-int(img.shape[1]*1.8/8)
+
+def face_detect(images, detector, args, resize=False):
+    batch_size = args.face_det_batch_size
+
+    while 1:
+        predictions = []
+        try:
+            for i in range(0, len(images), batch_size):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError('Image too big to run face detection on GPU')
+            batch_size //= 2
+            args.face_det_batch_size = batch_size
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    if type(args.pads) == str:
+        args.pads = [int(x) for x in args.pads.split(",")]
+    pady1, pady2, padx1, padx2 = args.pads
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            raise ValueError('Face not detected!')
+
+        y1 = max(0, rect[1] - pady1)
+        y2 = min(image.shape[0], rect[3] + pady2)
+        x1 = max(0, rect[0] - padx1)
+        x2 = min(image.shape[1], rect[2] + padx2)
+
+        results.append([x1, y1, x2, y2])
+
+    boxes = get_smoothened_boxes(np.array(results), T=5)
+
+    if resize:
+        if args.is_voxceleb2:
+            results = [[cv2.resize(my_voxceleb2_crop(image), (args.image_size, args.image_size)), my_voxceleb2_crop_bboxs(image), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+        else:
+            results = [[cv2.resize(image[y1: y2, x1:x2], (args.image_size, args.image_size)), (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+    else:
+        results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+    return results
+
+def normalise(tensor):
+    """ [-1,1]->[0,1]"""
+    return ((tensor+1)*0.5).clamp(0, 1)
+
+def normalise2(tensor):
+    """ [0,1]->[-1,1]"""
+    return (tensor*2-1).clamp(-1, 1)
+
+
+def sample_batch(batch, model, diffusion, args):
+    B, F, C, H, W = batch['image'].shape
+    sample_shape = (B*F, C, H, W)
+
+    # generate fixed noise
+    init_noise = None
+    if args.sampling_seed:
+        state = torch.get_rng_state()
+        torch.manual_seed(args.sampling_seed)
+        torch.cuda.manual_seed_all(args.sampling_seed)
+        init_noise = torch.randn((1, C, H, W))
+        # repeat noise for all frames
+        init_noise = init_noise.repeat(B*F, 1, 1, 1)
+        torch.set_rng_state(state)
+
+    img_batch, model_kwargs = tfg_process_batch(batch, args.face_hide_percentage,
+                                                use_ref=args.use_ref,
+                                                use_audio=args.use_audio,
+                                                # sampling_use_gt_for_ref=args.sampling_use_gt_for_ref,
+                                                noise=init_noise)
+
+    img_batch = img_batch.to(dist_util.dev())
+    model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
+    init_noise = init_noise.to(dist_util.dev()) if init_noise is not None else None
+
+    sample_fn = (
+        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
+    )
+    sample = sample_fn(
+        model,
+        sample_shape,
+        clip_denoised=args.clip_denoised,
+        model_kwargs=model_kwargs,
+        noise=init_noise
+    )
+    return sample, img_batch, model_kwargs
+
+
+def generate(video_path, audio_path, model, diffusion, detector, args, out_path=None, save_orig=True):
+    video_frames = load_video_frames(video_path, args)
+    try:
+        face_det_results = face_detect(video_frames.copy(), detector, args, resize=True)
+    except Exception as e:
+        print("Error:", e, video_path, audio_path)
+        import traceback
+        print(traceback.format_exc())
+        raise  # face_det_results would be undefined below, so re-raise after logging
+    wrong_all_indiv_mels, wrong_audio_wavform = load_all_indiv_mels(audio_path, args)
+
+    min_frames = min(len(video_frames), len(wrong_all_indiv_mels))
+    video_frames = video_frames[:min_frames]
+    face_det_results = face_det_results[:min_frames]
+    face_bboxes = [face_det_results[i][1] for i in range(min_frames)]
+    face_frames = torch.FloatTensor(np.transpose(np.asarray([face_det_results[i][0] for i in range(min_frames)], dtype=np.float32)/255., (0, 3, 1, 2)))  # [N, C, H, W]
+    wrong_all_indiv_mels = torch.FloatTensor(np.asarray(wrong_all_indiv_mels[:min_frames])).unsqueeze(1)  # [N, 1, h, w]
+
+    if save_orig:
+        if out_path is None:
+            out_path_orig = os.path.join(args.sample_path, splitext(basename(video_path))[0]+"_"+splitext(basename(audio_path))[0]+"_orig.mp4")
+        else:
+            out_path_orig = out_path.replace(".mp4", "_orig.mp4")
+        torchvision.io.write_video(
+            out_path_orig,
+            video_array=torch.from_numpy(np.array(video_frames)), fps=args.video_fps, video_codec='libx264',
+            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
+        )
+
+    if args.sampling_ref_type == 'gt':
+        ref_frames = face_frames.clone()
+    elif args.sampling_ref_type == 'first_frame':
+        ref_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
+    elif args.sampling_ref_type == 'random':
+        rand_idx = random.Random(args.sampling_seed).randint(0, len(face_frames)-1)
+        ref_frames = face_frames[rand_idx:rand_idx+1].repeat(len(face_frames), 1, 1, 1)
+
+    if args.sampling_input_type == 'first_frame':
+        face_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
+        video_frames = np.array(video_frames[0:1]*len(video_frames))
+        face_bboxes = np.array(face_bboxes[0:1]*len(face_bboxes))
+
+    generated_video_frames = []
+    b_s = args.sampling_batch_size
+    for i in range(0, min_frames, b_s*args.nframes):
+        video_frames_batch = video_frames[i:i+b_s*args.nframes]
+        face_bboxes_batch = face_bboxes[i:i+b_s*args.nframes]
+
+        try:
+            img_batch = face_frames[i:i+b_s*args.nframes]  # [BF, C, H, W]
+            img_batch = img_batch.reshape(-1, args.nframes, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
+            ref_batch = ref_frames[i:i+b_s*args.nframes]
+            ref_batch = ref_batch.reshape(-1, args.nframes, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
+            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:i+b_s*args.nframes]  # [BF, 1, h, w]
+            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, args.nframes, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))
+        except Exception:  # for the last batch: if B*F % nframes != 0, the reshape above throws an error,
+            # but internally everything is going to get converted to BF anyway,
+            # i.e. (B, F, C, H, W) -> (B*F, C, H, W) and (B*F, 1, C, H, W) -> (B*F, C, H, W)
+            img_batch = face_frames[i:i+b_s*args.nframes]  # [BF, C, H, W]
+            img_batch = img_batch.reshape(-1, 1, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
+            ref_batch = ref_frames[i:i+b_s*args.nframes]
+            ref_batch = ref_batch.reshape(-1, 1, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
+            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:i+b_s*args.nframes]  # [BF, 1, h, w]
+            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, 1, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))
+
+        batch = {"image": img_batch,
+                 "ref_img": ref_batch,
+                 "indiv_mels": wrong_indiv_mel_batch}
+
+        sample, img_batch, model_kwargs = sample_batch(batch, model, diffusion, args)
+        mask = model_kwargs['mask']
+        recon_batch = sample * mask + (1. - mask)*img_batch  # [BF, C, H, W]
+        recon_batch = (normalise(recon_batch)*255).cpu().numpy().transpose(0, 2, 3, 1)  # [-1,1] -> [0,255]
+
+        for g, v, b in zip(recon_batch, video_frames_batch, face_bboxes_batch):
+            y1, y2, x1, x2 = b
+            g = cv2.resize(g.astype(np.uint8), (x2 - x1, y2 - y1))
+            v[y1:y2, x1:x2] = g
+            generated_video_frames.append(v)
+
+    print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
+    min_time = len(generated_video_frames)/args.video_fps  # the video is already shorter because it got chopped according to the mel array length
+    wrong_audio_wavform = wrong_audio_wavform[:int(min_time*args.sample_rate)]
+    print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
+    if out_path is None:
+        out_path = os.path.join(args.sample_path, splitext(basename(video_path))[0]+"_"+splitext(basename(audio_path))[0]+".mp4")
+    torchvision.io.write_video(
+        out_path,
+        video_array=torch.from_numpy(np.array(generated_video_frames)), fps=args.video_fps, video_codec='libx264',
+        audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
+    )
+
+
+def generate_from_filelist(test_video_dir, filelist, model, diffusion, detector, args):
+    video_names = []
+    audio_names = []
+    with open(filelist, "r") as f:
+        lines = f.readlines()
+    for line in tqdm(lines):
+        try:
+            audio_name, video_name = line.strip().split()
+            audio_path = join(test_video_dir, audio_name+'.mp4')
+            video_path = join(test_video_dir, video_name+'.mp4')
+            out_path = join(args.sample_path, audio_name.replace('/', '.')+"_"+video_name.replace('/', '.')+".mp4")
+            generate(video_path, audio_path, model, diffusion, detector, args, out_path=out_path, save_orig=args.save_orig)
+        except Exception as e:
+            print("Error:", e, video_path, audio_path)
+            import traceback
+            print(traceback.format_exc())
+
+
+def main():
+    args = create_argparser().parse_args()
+    dist_util.setup_dist()
+    logger.configure(dir=args.sample_path, format_strs=["stdout", "log"])
+
+    logger.log("creating model...")
+    model, diffusion = tfg_create_model_and_diffusion(
+        **args_to_dict(args, tfg_model_and_diffusion_defaults().keys())
+    )
+    model.load_state_dict(
+        dist_util.load_state_dict(args.model_path, map_location='cpu')
+    )
+    model.to(dist_util.dev())
+    if args.use_fp16:
+        model.convert_to_fp16()
+    model.eval()
+
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda' if torch.cuda.is_available() else 'cpu')
+
+    if args.generate_from_filelist:
+        generate_from_filelist(args.test_video_dir, args.filelist, model, diffusion, detector, args)
+    else:
+        generate(args.video_path, args.audio_path, model, diffusion, detector, args, out_path=args.out_path, save_orig=args.save_orig)
+
+
+def create_argparser():
+    defaults = dict(
+        # generate from a single audio-video pair
+        generate_from_filelist=False,
+        video_path="",
+        audio_path="",
+        out_path=None,
+        save_orig=True,
+
+        # generate from filelist: generate_from_filelist = True
+        test_video_dir="test_videos",
+        filelist="test_filelist.txt",
+
+        use_fp16=True,
+        # tfg specific
+        face_hide_percentage=0.5,
+        use_ref=False,
+        use_audio=False,
+        audio_as_style=False,
+        audio_as_style_encoder_mlp=False,
+
+        # data args
+        nframes=1,
+        nrefer=0,
+        image_size=128,
+        syncnet_T=5,
+        syncnet_mel_step_size=16,
+        audio_frames_per_video=16,  # for the tfg model, we use sound corresponding to 5 frames centred at that frame
+        audio_dim=80,
+        is_voxceleb2=True,
+
+        video_fps=25,
+        sample_rate=16000,  # audio sampling rate
+        mel_steps_per_sec=80.,
+
+        # sampling args
+        clip_denoised=True,  # not used in training
+        sampling_batch_size=2,
+        use_ddim=False,
+        model_path="",
+        sample_path="d2l_gen",
+        sample_partition="",
+        sampling_seed=None,
+        sampling_use_gt_for_ref=False,
+        sampling_ref_type='gt',  # one of ['gt', 'first_frame', 'random']
+        sampling_input_type='gt',  # one of ['gt', 'first_frame']
+
+        # face detection args
+        face_det_batch_size=64,
+        pads="0,0,0,0"
+    )
+    defaults.update(tfg_model_and_diffusion_defaults())
+    parser = argparse.ArgumentParser()
+    add_dict_to_argparser(parser, defaults)
+    return parser
+
+
+if __name__ == "__main__":
+    main()
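The point of the seeding block in sample_batch is that a single noise map is drawn once and repeated across all B*F frames, so every frame of a clip starts the reverse diffusion from identical initial noise (the module docstring's "consistent initial noise"), while the global RNG state is saved and restored around the draw. A minimal sketch of just that logic, with a hypothetical seed value:

import torch

B, F, C, H, W = 2, 5, 3, 128, 128
state = torch.get_rng_state()
torch.manual_seed(797)                                    # hypothetical sampling_seed
init_noise = torch.randn((1, C, H, W)).repeat(B * F, 1, 1, 1)
torch.set_rng_state(state)                                # leave the global RNG untouched
assert torch.equal(init_noise[0], init_noise[-1])         # every frame shares the same noise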
generate_dist.py
ADDED
@@ -0,0 +1,428 @@
"""Consistent initial noise for video generation."""
import cv2
import os
from os.path import join, basename, dirname, splitext
import shutil
import argparse
import numpy as np
import random
import torch, torchvision
import subprocess
from audio import audio
import face_detection
from tqdm import tqdm
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    tfg_model_and_diffusion_defaults,
    tfg_create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from time import time
import torch.distributed as dist
from guided_diffusion.tfg_data_util import (
    tfg_process_batch,
)


def get_frame_id(frame):
    return int(basename(frame).split('.')[0])


def crop_audio_window(spec, start_frame, args):
    if type(start_frame) == int:
        start_frame_num = start_frame
    else:
        start_frame_num = get_frame_id(start_frame)
    start_idx = int(args.mel_steps_per_sec * (start_frame_num / float(args.video_fps)))
    end_idx = start_idx + args.syncnet_mel_step_size
    return spec[start_idx:end_idx, :]


def load_all_indiv_mels(path, args):
    in_path = path
    out_dir = join(args.sample_path, "temp", str(dist.get_rank()), basename(in_path).replace(".mp4", ""))
    os.makedirs(out_dir, exist_ok=True)
    out_path = join(out_dir, "audio.wav")
    command2 = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(in_path, out_path)
    subprocess.call(command2, shell=True)
    wav = audio.load_wav(out_path, args.sample_rate)
    orig_mel = audio.melspectrogram(wav).T

    # one mel window per frame, centred on that frame, until the audio runs out
    all_indiv_mels = []
    i = 1
    while True:
        m = crop_audio_window(orig_mel.copy(), max(i - args.syncnet_T // 2, 0), args)
        if (m.shape[0] != args.syncnet_mel_step_size):
            break
        all_indiv_mels.append(m.T)
        i += 1

    # clean up
    shutil.rmtree(join(args.sample_path, "temp", str(dist.get_rank())))

    return all_indiv_mels, wav
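def _mel_window_example():
    # Illustrative only (never called): crop_audio_window maps a frame index
    # to a row range of the mel spectrogram. With the create_argparser()
    # defaults below (video_fps=25, mel_steps_per_sec=80.,
    # syncnet_mel_step_size=16), each video frame spans 80/25 = 3.2 mel steps
    # and one 16-step window covers 16/80 = 0.2 s of audio, i.e. the
    # syncnet_T = 5 frames the model is conditioned on.
    mel_steps_per_sec, video_fps, window = 80.0, 25, 16
    for frame in (0, 1, 25):
        start = int(mel_steps_per_sec * frame / video_fps)
        print(frame, (start, start + window))
    # frame 0  -> rows (0, 16); frame 1 -> rows (3, 19);
    # frame 25 -> rows (80, 96), i.e. exactly one second in.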
def load_video_frames(path, args):
    in_path = path
    out_dir = join(args.sample_path, "temp", str(dist.get_rank()), basename(in_path).replace(".mp4", ""), "image")
    os.makedirs(out_dir, exist_ok=True)

    command = "ffmpeg -loglevel error -y -i {} -vf fps={} -q:v 2 -qmin 1 {}/%05d.jpg".format(in_path, args.video_fps, out_dir)
    subprocess.call(command, shell=True)

    video_frames = []
    for img_name in sorted(os.listdir(out_dir)):
        img_path = join(out_dir, img_name)
        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        video_frames.append(img)

    # clean up
    shutil.rmtree(join(args.sample_path, "temp", str(dist.get_rank())))

    return video_frames


def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def my_voxceleb2_crop(img):
    return img[:-int(img.shape[0] * 2.36 / 8), int(img.shape[1] * 1.8 / 8): -int(img.shape[1] * 1.8 / 8)]


def my_voxceleb2_crop_bboxs(img):
    return 0, img.shape[0] - int(img.shape[0] * 2.36 / 8), int(img.shape[1] * 1.8 / 8), img.shape[1] - int(img.shape[1] * 1.8 / 8)


def face_detect(images, detector, args, resize=False):
    batch_size = args.face_det_batch_size

    # halve the batch size on CUDA OOM until detection fits on the GPU
    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU')
            batch_size //= 2
            args.face_det_batch_size = batch_size
            print('Recovering from OOM error; new batch size: {}'.format(batch_size))
            continue
        break

    results = []
    if type(args.pads) == str:
        args.pads = [int(x) for x in args.pads.split(",")]
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            raise ValueError('Face not detected!')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)

    if resize:
        if args.is_voxceleb2:
            results = [[cv2.resize(my_voxceleb2_crop(image), (args.image_size, args.image_size)), my_voxceleb2_crop_bboxs(image), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
        else:
            results = [[cv2.resize(image[y1:y2, x1:x2], (args.image_size, args.image_size)), (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    else:
        results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)]
    return results


def normalise(tensor):
    """[-1,1] -> [0,1]"""
    return ((tensor + 1) * 0.5).clamp(0, 1)


def normalise2(tensor):
    """[0,1] -> [-1,1]"""
    return (tensor * 2 - 1).clamp(-1, 1)


def sample_batch(batch, model, diffusion, args):
    B, F, C, H, W = batch['image'].shape
    sample_shape = (B * F, C, H, W)

    # generate fixed noise
    init_noise = None
    if args.sampling_seed:
        state = torch.get_rng_state()
        torch.manual_seed(args.sampling_seed)
        torch.cuda.manual_seed_all(args.sampling_seed)
        init_noise = torch.randn((1, C, H, W))
        # repeat the same noise map for all frames
        init_noise = init_noise.repeat(B * F, 1, 1, 1)
        torch.set_rng_state(state)

    img_batch, model_kwargs = tfg_process_batch(batch, args.face_hide_percentage,
                                                use_ref=args.use_ref,
                                                use_audio=args.use_audio,
                                                noise=init_noise)

    img_batch = img_batch.to(dist_util.dev())
    model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
    init_noise = init_noise.to(dist_util.dev()) if init_noise is not None else None

    sample_fn = (
        diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model,
        sample_shape,
        clip_denoised=args.clip_denoised,
        model_kwargs=model_kwargs,
        noise=init_noise,
    )
    return sample, img_batch, model_kwargs
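def _consistent_noise_example():
    # Illustrative only (never called): this is the "consistent initial noise"
    # trick used in sample_batch above. One noise map is drawn under a fixed
    # seed, repeated for every frame, and the global RNG state is restored
    # afterwards, so all frames start reverse diffusion from the same x_T,
    # which reduces frame-to-frame flicker without making the rest of
    # sampling deterministic.
    state = torch.get_rng_state()
    torch.manual_seed(7)  # illustrative seed
    x_T = torch.randn((1, 3, 128, 128)).repeat(5, 1, 1, 1)
    torch.set_rng_state(state)
    assert torch.equal(x_T[0], x_T[4])  # every frame shares the starting noise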
def generate(video_path, audio_path, model, diffusion, detector, args, out_path=None, save_orig=True):
    video_frames = load_video_frames(video_path, args)
    try:
        face_det_results = face_detect(video_frames.copy(), detector, args, resize=True)
    except Exception as e:
        print("Error:", e, video_path, audio_path)
        import traceback
        print(traceback.format_exc())
        raise  # without face detections the rest of this function cannot run
    wrong_all_indiv_mels, wrong_audio_wavform = load_all_indiv_mels(audio_path, args)

    min_frames = min(len(video_frames), len(wrong_all_indiv_mels))
    video_frames = video_frames[:min_frames]
    face_det_results = face_det_results[:min_frames]
    face_bboxes = [face_det_results[i][1] for i in range(min_frames)]
    face_frames = torch.FloatTensor(np.transpose(np.asarray([face_det_results[i][0] for i in range(min_frames)], dtype=np.float32) / 255., (0, 3, 1, 2)))  # [N, C, H, W]
    wrong_all_indiv_mels = torch.FloatTensor(np.asarray(wrong_all_indiv_mels[:min_frames])).unsqueeze(1)  # [N, 1, h, w]

    if save_orig:
        if out_path is None:
            out_path_orig = os.path.join(args.sample_path, splitext(basename(video_path))[0] + "_" + splitext(basename(audio_path))[0] + "_orig.mp4")
        else:
            out_path_orig = out_path.replace(".mp4", "_orig.mp4")
        torchvision.io.write_video(
            out_path_orig,
            video_array=torch.from_numpy(np.array(video_frames)), fps=args.video_fps, video_codec='libx264',
            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
        )

    if args.sampling_ref_type == 'gt':
        ref_frames = face_frames.clone()
    elif args.sampling_ref_type == 'first_frame':
        ref_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
    elif args.sampling_ref_type == 'random':
        rand_idx = random.Random(args.sampling_seed).randint(0, len(face_frames) - 1)
        ref_frames = face_frames[rand_idx:rand_idx + 1].repeat(len(face_frames), 1, 1, 1)

    if args.sampling_input_type == 'first_frame':
        face_frames = face_frames[0:1].repeat(len(face_frames), 1, 1, 1)
        video_frames = np.array(video_frames[0:1] * len(video_frames))
        face_bboxes = np.array(face_bboxes[0:1] * len(face_bboxes))

    # split the frames into contiguous chunks, one per rank
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    chunk_size = int(np.ceil(min_frames / world_size))
    start_idx = rank * chunk_size
    end_idx = min(start_idx + chunk_size, min_frames)
    generated_video_frames = []
    b_s = args.sampling_batch_size

    dist.barrier()
    torch.cuda.synchronize()
    t1 = time()
    for i in range(start_idx, end_idx, b_s * args.nframes):
        slice_end = min(i + b_s * args.nframes, end_idx)
        video_frames_batch = video_frames[i:slice_end]
        face_bboxes_batch = face_bboxes[i:slice_end]

        if (slice_end - i) % args.nframes == 0:
            img_batch = face_frames[i:slice_end]  # [BF, C, H, W]
            img_batch = img_batch.reshape(-1, args.nframes, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:slice_end]
            ref_batch = ref_batch.reshape(-1, args.nframes, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:slice_end]  # [BF, 1, h, w]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, args.nframes, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))
        else:
            # For the last batch, if B*F % nframes != 0, the reshape above would
            # fail. Internally everything gets flattened back to BF anyway,
            # i.e. (B, F, C, H, W) -> (B*F, C, H, W) and
            # (B*F, 1, C, H, W) -> (B*F, C, H, W).
            img_batch = face_frames[i:slice_end]  # [BF, C, H, W]
            img_batch = img_batch.reshape(-1, 1, img_batch.size(-3), img_batch.size(-2), img_batch.size(-1))
            ref_batch = ref_frames[i:slice_end]
            ref_batch = ref_batch.reshape(-1, 1, ref_batch.size(-3), ref_batch.size(-2), ref_batch.size(-1))
            wrong_indiv_mel_batch = wrong_all_indiv_mels[i:slice_end]  # [BF, 1, h, w]
            wrong_indiv_mel_batch = wrong_indiv_mel_batch.reshape(-1, 1, wrong_indiv_mel_batch.size(-3), wrong_indiv_mel_batch.size(-2), wrong_indiv_mel_batch.size(-1))

        batch = {"image": img_batch,
                 "ref_img": ref_batch,
                 "indiv_mels": wrong_indiv_mel_batch}

        sample, img_batch, model_kwargs = sample_batch(batch, model, diffusion, args)
        mask = model_kwargs['mask']
        recon_batch = sample * mask + (1. - mask) * img_batch  # [BF, C, H, W]
        recon_batch = (normalise(recon_batch) * 255).cpu().numpy().transpose(0, 2, 3, 1)  # [-1,1] -> [0,255]

        # paste each generated face crop back into its full frame
        for g, v, b in zip(recon_batch, video_frames_batch, face_bboxes_batch):
            y1, y2, x1, x2 = b
            g = cv2.resize(g.astype(np.uint8), (x2 - x1, y2 - y1))
            v[y1:y2, x1:x2] = g
            generated_video_frames.append(v)

    torch.cuda.synchronize()
    t3 = time()
    all_generated_video_frames = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(all_generated_video_frames, generated_video_frames)  # gather not supported with NCCL
    all_generated_video_frames_combined = []
    [all_generated_video_frames_combined.extend(gvf) for gvf in all_generated_video_frames]
    generated_video_frames = all_generated_video_frames_combined

    torch.cuda.synchronize()
    t2 = time()

    if dist.get_rank() == 0:
        print("Time taken for sampling:", t2 - t1,
              "| without all_gather:", t3 - t1,
              "| gathered frames:", len(generated_video_frames),
              "| total frames:", min_frames)
        print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
        # The video may be shorter than the audio because it was chopped
        # according to the mel array length, so trim the waveform to match.
        min_time = len(generated_video_frames) / args.video_fps
        wrong_audio_wavform = wrong_audio_wavform[:int(min_time * args.sample_rate)]
        print(wrong_audio_wavform.shape, np.array(generated_video_frames).shape)
        if out_path is None:
            out_path = os.path.join(args.sample_path, splitext(basename(video_path))[0] + "_" + splitext(basename(audio_path))[0] + ".mp4")
        torchvision.io.write_video(
            out_path,
            video_array=torch.from_numpy(np.array(generated_video_frames)), fps=args.video_fps, video_codec='libx264',
            audio_array=torch.from_numpy(wrong_audio_wavform).unsqueeze(0), audio_fps=args.sample_rate, audio_codec='aac'
        )
    dist.barrier()


def generate_from_filelist(test_video_dir, filelist, model, diffusion, detector, args):
    with open(filelist, "r") as f:
        lines = f.readlines()
    for line in tqdm(lines):
        try:
            audio_name, video_name = line.strip().split()
            audio_path = join(test_video_dir, audio_name + '.mp4')
            video_path = join(test_video_dir, video_name + '.mp4')
            out_path = join(args.sample_path, audio_name.replace('/', '.') + "_" + video_name.replace('/', '.') + ".mp4")
            generate(video_path, audio_path, model, diffusion, detector, args, out_path=out_path, save_orig=args.save_orig)
        except Exception as e:
            print("Error:", e, video_path, audio_path)
            import traceback
            print(traceback.format_exc())


def main():
    args = create_argparser().parse_args()
    dist_util.setup_dist()
    logger.configure(dir=args.sample_path, format_strs=["stdout", "log"])

    logger.log("creating model...")
    model, diffusion = tfg_create_model_and_diffusion(
        **args_to_dict(args, tfg_model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location='cpu')
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cuda' if torch.cuda.is_available() else 'cpu')

    if args.generate_from_filelist:
        generate_from_filelist(args.test_video_dir, args.filelist, model, diffusion, detector, args)
    else:
        generate(args.video_path, args.audio_path, model, diffusion, detector, args, out_path=args.out_path, save_orig=args.save_orig)


def create_argparser():
    defaults = dict(
        # generate from a single audio-video pair
        generate_from_filelist=False,
        video_path="",
        audio_path="",
        out_path=None,
        save_orig=True,

        # generate from a filelist: generate_from_filelist = True
        test_video_dir="test_videos",
        filelist="test_filelist.txt",

        use_fp16=True,
        # tfg specific
        face_hide_percentage=0.5,
        use_ref=False,
        use_audio=False,
        audio_as_style=False,
        audio_as_style_encoder_mlp=False,

        # data args
        nframes=1,
        nrefer=0,
        image_size=128,
        syncnet_T=5,
        syncnet_mel_step_size=16,
        audio_frames_per_video=16,  # for the tfg model, we use sound corresponding to 5 frames centred at that frame
        audio_dim=80,
        is_voxceleb2=True,

        video_fps=25,
        sample_rate=16000,  # audio sampling rate
        mel_steps_per_sec=80.,

        # sampling args
        clip_denoised=True,  # not used in training
        sampling_batch_size=2,
        use_ddim=False,
        model_path="",
        sample_path="d2l_gen",
        sample_partition="",
        sampling_seed=None,
        sampling_use_gt_for_ref=False,
        sampling_ref_type='gt',  # one of ['gt', 'first_frame', 'random']
        sampling_input_type='gt',  # one of ['gt', 'first_frame']

        # face detection args
        face_det_batch_size=64,
        pads="0,0,0,0",
    )
    defaults.update(tfg_model_and_diffusion_defaults())
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
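Because dist_util.setup_dist() (added below) takes rank and world size from MPI and pins each rank to GPU (rank % GPUS_PER_NODE), this script is meant to run with one MPI process per GPU. A launch sketch; the command shape and paths are illustrative only, not verified invocations:

    # mpiexec -n 2 python generate_dist.py \
    #     --model_path checkpoints/checkpoint.pt \
    #     --video_path identity.mp4 --audio_path speech.mp4 \
    #     --sample_path d2l_gen --sampling_seed 7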
guided-diffusion/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
guided-diffusion/guided_diffusion/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""
guided-diffusion/guided_diffusion/dist_util.py
ADDED
@@ -0,0 +1,94 @@
"""
Helpers for distributed training.
"""

import io
import os
import socket

import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8

SETUP_RETRY_COUNT = 3


def setup_dist():
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return
    print("MPI.COMM_WORLD.Get_rank()", MPI.COMM_WORLD.Get_rank())
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
    print('os.environ["CUDA_VISIBLE_DEVICES"]', os.environ["CUDA_VISIBLE_DEVICES"])
    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    chunk_size = 2 ** 30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()
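load_state_dict() above reads the checkpoint bytes once on rank 0 and broadcasts them in 2**30-byte chunks (mpi4py pickles each bcast payload, and MPI messages have a size limit); every other rank reassembles the same bytes before torch.load. A usage sketch, assuming a `model` object already exists on each rank:

    from guided_diffusion import dist_util

    dist_util.setup_dist()  # one process per GPU; ranks come from MPI
    state = dist_util.load_state_dict("checkpoints/checkpoint.pt", map_location="cpu")
    model.load_state_dict(state)          # `model` is assumed to exist
    model.to(dist_util.dev())
    dist_util.sync_params(model.parameters())  # broadcast rank 0's weights to all ranks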
guided-diffusion/guided_diffusion/fp16_util.py
ADDED
@@ -0,0 +1,237 @@
"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        for p in self.master_params:
            p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
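MixedPrecisionTrainer implements dynamic loss scaling: the model holds fp16 weights, the loss is multiplied by 2**lg_loss_scale before backward(), gradients are unscaled into flattened fp32 master parameters, and the scale drops by 1 whenever an overflowed gradient norm is seen while growing by fp16_scale_growth after every clean step. A minimal training-step sketch; `model` must expose convert_to_fp16() (as the UNet in this repo does), and the optimizer choice is illustrative:

    import torch as th
    from guided_diffusion.fp16_util import MixedPrecisionTrainer

    trainer = MixedPrecisionTrainer(model=model, use_fp16=True)  # `model` assumed
    opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

    def train_step(loss):
        trainer.zero_grad()
        trainer.backward(loss)        # scales the loss by 2**lg_loss_scale
        return trainer.optimize(opt)  # False means the step was skipped on overflow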
guided-diffusion/guided_diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,843 @@
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import math

import numpy as np
import torch as th
import os

from . import dist_util
from .nn import mean_flat
from .losses import normal_kl, discretized_gaussian_log_likelihood


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
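def _cosine_schedule_example():
    # Illustrative only (never called): betas_for_alpha_bar inverts the
    # cumulative product. Since alpha_bar(t2) / alpha_bar(t1) = 1 - beta for
    # the step from t1 to t2, each beta is 1 - alpha_bar(t2) / alpha_bar(t1),
    # capped at max_beta. A quick numerical check of the cosine schedule above:
    betas = betas_for_alpha_bar(
        1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    )
    alphas_cumprod = np.cumprod(1.0 - betas)
    # Away from the cap, the realized product equals alpha_bar(t) / alpha_bar(0):
    target = math.cos((0.5 + 0.008) / 1.008 * math.pi / 2) ** 2
    realized = alphas_cumprod[499] * math.cos(0.008 / 1.008 * math.pi / 2) ** 2
    assert abs(realized - target) < 1e-6
    assert betas[-1] == 0.999  # the final steps hit the max_beta cap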
class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    :param loss_variation: if True, then use composite loss
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
        loss_variation=False,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps
        self.loss_variation = loss_variation

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )
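    # A hedged self-check of the closed form implemented above (illustrative
    # only; `diffusion` stands for any constructed GaussianDiffusion with more
    # than 500 timesteps):
    #
    #     x_start = th.randn(4, 3, 128, 128)
    #     t = th.full((4,), 500, dtype=th.long)
    #     noise = th.randn_like(x_start)
    #     x_t = diffusion.q_sample(x_start, t, noise=noise)
    #     c1 = float(diffusion.sqrt_alphas_cumprod[500])
    #     c2 = float(diffusion.sqrt_one_minus_alphas_cumprod[500])
    #     assert th.allclose(x_t, c1 * x_start + c2 * noise, atol=1e-6)
    #
    # i.e. q_sample(x_0, t) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.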
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(
                    self.posterior_log_variance_clipped, t, x.shape
                )
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(
cond_fn, out, x, t, model_kwargs=model_kwargs
|
| 442 |
+
)
|
| 443 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
| 444 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
| 445 |
+
|
| 446 |
+
def p_sample_loop(
|
| 447 |
+
self,
|
| 448 |
+
model,
|
| 449 |
+
shape,
|
| 450 |
+
noise=None,
|
| 451 |
+
clip_denoised=True,
|
| 452 |
+
denoised_fn=None,
|
| 453 |
+
cond_fn=None,
|
| 454 |
+
model_kwargs=None,
|
| 455 |
+
device=None,
|
| 456 |
+
progress=False,
|
| 457 |
+
):
|
| 458 |
+
"""
|
| 459 |
+
Generate samples from the model.
|
| 460 |
+
|
| 461 |
+
:param model: the model module.
|
| 462 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
| 463 |
+
:param noise: if specified, the noise from the encoder to sample.
|
| 464 |
+
Should be of the same shape as `shape`.
|
| 465 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
| 466 |
+
:param denoised_fn: if not None, a function which applies to the
|
| 467 |
+
x_start prediction before it is used to sample.
|
| 468 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
| 469 |
+
similarly to the model.
|
| 470 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 471 |
+
pass to the model. This can be used for conditioning.
|
| 472 |
+
:param device: if specified, the device to create the samples on.
|
| 473 |
+
If not specified, use a model parameter's device.
|
| 474 |
+
:param progress: if True, show a tqdm progress bar.
|
| 475 |
+
:return: a non-differentiable batch of samples.
|
| 476 |
+
"""
|
| 477 |
+
final = None
|
| 478 |
+
for sample in self.p_sample_loop_progressive(
|
| 479 |
+
model,
|
| 480 |
+
shape,
|
| 481 |
+
noise=noise,
|
| 482 |
+
clip_denoised=clip_denoised,
|
| 483 |
+
denoised_fn=denoised_fn,
|
| 484 |
+
cond_fn=cond_fn,
|
| 485 |
+
model_kwargs=model_kwargs,
|
| 486 |
+
device=device,
|
| 487 |
+
progress=progress,
|
| 488 |
+
):
|
| 489 |
+
final = sample
|
| 490 |
+
return final["sample"]
|
| 491 |
+
|
| 492 |
+
def p_sample_loop_progressive(
|
| 493 |
+
self,
|
| 494 |
+
model,
|
| 495 |
+
shape,
|
| 496 |
+
noise=None,
|
| 497 |
+
clip_denoised=True,
|
| 498 |
+
denoised_fn=None,
|
| 499 |
+
cond_fn=None,
|
| 500 |
+
model_kwargs=None,
|
| 501 |
+
device=None,
|
| 502 |
+
progress=False,
|
| 503 |
+
):
|
| 504 |
+
"""
|
| 505 |
+
Generate samples from the model and yield intermediate samples from
|
| 506 |
+
each timestep of diffusion.
|
| 507 |
+
|
| 508 |
+
Arguments are the same as p_sample_loop().
|
| 509 |
+
Returns a generator over dicts, where each dict is the return value of
|
| 510 |
+
p_sample().
|
| 511 |
+
"""
|
| 512 |
+
if device is None:
|
| 513 |
+
device = next(model.parameters()).device
|
| 514 |
+
assert isinstance(shape, (tuple, list))
|
| 515 |
+
if noise is not None:
|
| 516 |
+
img = noise
|
| 517 |
+
else:
|
| 518 |
+
img = th.randn(*shape, device=device)
|
| 519 |
+
indices = list(range(self.num_timesteps))[::-1]
|
| 520 |
+
|
| 521 |
+
if progress:
|
| 522 |
+
# Lazy import so that we don't depend on tqdm.
|
| 523 |
+
from tqdm.auto import tqdm
|
| 524 |
+
|
| 525 |
+
indices = tqdm(indices)
|
| 526 |
+
|
| 527 |
+
for i in indices:
|
| 528 |
+
t = th.tensor([i] * shape[0], device=device)
|
| 529 |
+
with th.no_grad():
|
| 530 |
+
out = self.p_sample(
|
| 531 |
+
model,
|
| 532 |
+
img,
|
| 533 |
+
t,
|
| 534 |
+
clip_denoised=clip_denoised,
|
| 535 |
+
denoised_fn=denoised_fn,
|
| 536 |
+
cond_fn=cond_fn,
|
| 537 |
+
model_kwargs=model_kwargs,
|
| 538 |
+
)
|
| 539 |
+
yield out
|
| 540 |
+
img = out["sample"]
|
| 541 |
+
|
| 542 |
+
def ddim_sample(
|
| 543 |
+
self,
|
| 544 |
+
model,
|
| 545 |
+
x,
|
| 546 |
+
t,
|
| 547 |
+
clip_denoised=True,
|
| 548 |
+
denoised_fn=None,
|
| 549 |
+
cond_fn=None,
|
| 550 |
+
model_kwargs=None,
|
| 551 |
+
eta=0.0,
|
| 552 |
+
):
|
| 553 |
+
"""
|
| 554 |
+
Sample x_{t-1} from the model using DDIM.
|
| 555 |
+
|
| 556 |
+
Same usage as p_sample().
|
| 557 |
+
"""
|
| 558 |
+
out = self.p_mean_variance(
|
| 559 |
+
model,
|
| 560 |
+
x,
|
| 561 |
+
t,
|
| 562 |
+
clip_denoised=clip_denoised,
|
| 563 |
+
denoised_fn=denoised_fn,
|
| 564 |
+
model_kwargs=model_kwargs,
|
| 565 |
+
)
|
| 566 |
+
if cond_fn is not None:
|
| 567 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
| 568 |
+
|
| 569 |
+
# Usually our model outputs epsilon, but we re-derive it
|
| 570 |
+
# in case we used x_start or x_prev prediction.
|
| 571 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
| 572 |
+
|
| 573 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
| 574 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
| 575 |
+
sigma = (
|
| 576 |
+
eta
|
| 577 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
| 578 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
| 579 |
+
)
|
| 580 |
+
# Equation 12.
|
| 581 |
+
noise = th.randn_like(x)
|
| 582 |
+
mean_pred = (
|
| 583 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
| 584 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
| 585 |
+
)
|
| 586 |
+
nonzero_mask = (
|
| 587 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
| 588 |
+
) # no noise when t == 0
|
| 589 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
| 590 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
| 591 |
+
|
| 592 |
+
def ddim_reverse_sample(
|
| 593 |
+
self,
|
| 594 |
+
model,
|
| 595 |
+
x,
|
| 596 |
+
t,
|
| 597 |
+
clip_denoised=True,
|
| 598 |
+
denoised_fn=None,
|
| 599 |
+
model_kwargs=None,
|
| 600 |
+
eta=0.0,
|
| 601 |
+
):
|
| 602 |
+
"""
|
| 603 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
| 604 |
+
"""
|
| 605 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
| 606 |
+
out = self.p_mean_variance(
|
| 607 |
+
model,
|
| 608 |
+
x,
|
| 609 |
+
t,
|
| 610 |
+
clip_denoised=clip_denoised,
|
| 611 |
+
denoised_fn=denoised_fn,
|
| 612 |
+
model_kwargs=model_kwargs,
|
| 613 |
+
)
|
| 614 |
+
# Usually our model outputs epsilon, but we re-derive it
|
| 615 |
+
# in case we used x_start or x_prev prediction.
|
| 616 |
+
eps = (
|
| 617 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
| 618 |
+
- out["pred_xstart"]
|
| 619 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
| 620 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
| 621 |
+
|
| 622 |
+
# Equation 12. reversed
|
| 623 |
+
mean_pred = (
|
| 624 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_next)
|
| 625 |
+
+ th.sqrt(1 - alpha_bar_next) * eps
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
| 629 |
+
|
| 630 |
+
def ddim_sample_loop(
|
| 631 |
+
self,
|
| 632 |
+
model,
|
| 633 |
+
shape,
|
| 634 |
+
noise=None,
|
| 635 |
+
clip_denoised=True,
|
| 636 |
+
denoised_fn=None,
|
| 637 |
+
cond_fn=None,
|
| 638 |
+
model_kwargs=None,
|
| 639 |
+
device=None,
|
| 640 |
+
progress=False,
|
| 641 |
+
eta=0.0,
|
| 642 |
+
):
|
| 643 |
+
"""
|
| 644 |
+
Generate samples from the model using DDIM.
|
| 645 |
+
|
| 646 |
+
Same usage as p_sample_loop().
|
| 647 |
+
"""
|
| 648 |
+
final = None
|
| 649 |
+
for sample in self.ddim_sample_loop_progressive(
|
| 650 |
+
model,
|
| 651 |
+
shape,
|
| 652 |
+
noise=noise,
|
| 653 |
+
clip_denoised=clip_denoised,
|
| 654 |
+
denoised_fn=denoised_fn,
|
| 655 |
+
cond_fn=cond_fn,
|
| 656 |
+
model_kwargs=model_kwargs,
|
| 657 |
+
device=device,
|
| 658 |
+
progress=progress,
|
| 659 |
+
eta=eta,
|
| 660 |
+
):
|
| 661 |
+
final = sample
|
| 662 |
+
return final["sample"]
|
| 663 |
+
|
| 664 |
+
def ddim_sample_loop_progressive(
|
| 665 |
+
self,
|
| 666 |
+
model,
|
| 667 |
+
shape,
|
| 668 |
+
noise=None,
|
| 669 |
+
clip_denoised=True,
|
| 670 |
+
denoised_fn=None,
|
| 671 |
+
cond_fn=None,
|
| 672 |
+
model_kwargs=None,
|
| 673 |
+
device=None,
|
| 674 |
+
progress=False,
|
| 675 |
+
eta=0.0,
|
| 676 |
+
):
|
| 677 |
+
"""
|
| 678 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
| 679 |
+
each timestep of DDIM.
|
| 680 |
+
|
| 681 |
+
Same usage as p_sample_loop_progressive().
|
| 682 |
+
"""
|
| 683 |
+
if device is None:
|
| 684 |
+
device = next(model.parameters()).device
|
| 685 |
+
assert isinstance(shape, (tuple, list))
|
| 686 |
+
if noise is not None:
|
| 687 |
+
img = noise
|
| 688 |
+
else:
|
| 689 |
+
img = th.randn(*shape, device=device)
|
| 690 |
+
indices = list(range(self.num_timesteps))[::-1]
|
| 691 |
+
|
| 692 |
+
if progress:
|
| 693 |
+
# Lazy import so that we don't depend on tqdm.
|
| 694 |
+
from tqdm.auto import tqdm
|
| 695 |
+
|
| 696 |
+
indices = tqdm(indices)
|
| 697 |
+
|
| 698 |
+
for i in indices:
|
| 699 |
+
t = th.tensor([i] * shape[0], device=device)
|
| 700 |
+
with th.no_grad():
|
| 701 |
+
out = self.ddim_sample(
|
| 702 |
+
model,
|
| 703 |
+
img,
|
| 704 |
+
t,
|
| 705 |
+
clip_denoised=clip_denoised,
|
| 706 |
+
denoised_fn=denoised_fn,
|
| 707 |
+
cond_fn=cond_fn,
|
| 708 |
+
model_kwargs=model_kwargs,
|
| 709 |
+
eta=eta,
|
| 710 |
+
)
|
| 711 |
+
yield out
|
| 712 |
+
img = out["sample"]
|
| 713 |
+
|
| 714 |
+
def _vb_terms_bpd(
|
| 715 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
| 716 |
+
):
|
| 717 |
+
"""
|
| 718 |
+
Get a term for the variational lower-bound.
|
| 719 |
+
|
| 720 |
+
The resulting units are bits (rather than nats, as one might expect).
|
| 721 |
+
This allows for comparison to other papers.
|
| 722 |
+
|
| 723 |
+
:return: a dict with the following keys:
|
| 724 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
| 725 |
+
- 'pred_xstart': the x_0 predictions.
|
| 726 |
+
"""
|
| 727 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
| 728 |
+
x_start=x_start, x_t=x_t, t=t
|
| 729 |
+
)
|
| 730 |
+
out = self.p_mean_variance(
|
| 731 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
| 732 |
+
)
|
| 733 |
+
kl = normal_kl(
|
| 734 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
| 735 |
+
)
|
| 736 |
+
if ("cond_img" in model_kwargs) and ("mask" in model_kwargs): #added by soumik
|
| 737 |
+
kl = kl*model_kwargs["mask"]
|
| 738 |
+
kl = mean_flat(kl) / np.log(2.0)
|
| 739 |
+
|
| 740 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
| 741 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
| 742 |
+
)
|
| 743 |
+
assert decoder_nll.shape == x_start.shape
|
| 744 |
+
if ("cond_img" in model_kwargs) and ("mask" in model_kwargs): #added by soumik
|
| 745 |
+
decoder_nll=decoder_nll*model_kwargs["mask"]
|
| 746 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
| 747 |
+
|
| 748 |
+
# At the first timestep return the decoder NLL,
|
| 749 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
| 750 |
+
output = th.where((t == 0), decoder_nll, kl)
|
| 751 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
def _prior_bpd(self, x_start):
|
| 755 |
+
"""
|
| 756 |
+
Get the prior KL term for the variational lower-bound, measured in
|
| 757 |
+
bits-per-dim.
|
| 758 |
+
|
| 759 |
+
This term can't be optimized, as it only depends on the encoder.
|
| 760 |
+
|
| 761 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 762 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
| 763 |
+
"""
|
| 764 |
+
batch_size = x_start.shape[0]
|
| 765 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
| 766 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 767 |
+
kl_prior = normal_kl(
|
| 768 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
| 769 |
+
)
|
| 770 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
| 771 |
+
|
| 772 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
| 773 |
+
"""
|
| 774 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
| 775 |
+
as well as other related quantities.
|
| 776 |
+
|
| 777 |
+
:param model: the model to evaluate loss on.
|
| 778 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 779 |
+
:param clip_denoised: if True, clip denoised samples.
|
| 780 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 781 |
+
pass to the model. This can be used for conditioning.
|
| 782 |
+
|
| 783 |
+
:return: a dict containing the following keys:
|
| 784 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
| 785 |
+
- prior_bpd: the prior term in the lower-bound.
|
| 786 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
| 787 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
| 788 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
| 789 |
+
"""
|
| 790 |
+
device = x_start.device
|
| 791 |
+
batch_size = x_start.shape[0]
|
| 792 |
+
|
| 793 |
+
vb = []
|
| 794 |
+
xstart_mse = []
|
| 795 |
+
mse = []
|
| 796 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
| 797 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
| 798 |
+
noise = th.randn_like(x_start)
|
| 799 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
| 800 |
+
# Calculate VLB term at the current timestep
|
| 801 |
+
with th.no_grad():
|
| 802 |
+
out = self._vb_terms_bpd(
|
| 803 |
+
model,
|
| 804 |
+
x_start=x_start,
|
| 805 |
+
x_t=x_t,
|
| 806 |
+
t=t_batch,
|
| 807 |
+
clip_denoised=clip_denoised,
|
| 808 |
+
model_kwargs=model_kwargs,
|
| 809 |
+
)
|
| 810 |
+
vb.append(out["output"])
|
| 811 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
| 812 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
| 813 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
| 814 |
+
|
| 815 |
+
vb = th.stack(vb, dim=1)
|
| 816 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
| 817 |
+
mse = th.stack(mse, dim=1)
|
| 818 |
+
|
| 819 |
+
prior_bpd = self._prior_bpd(x_start)
|
| 820 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
| 821 |
+
return {
|
| 822 |
+
"total_bpd": total_bpd,
|
| 823 |
+
"prior_bpd": prior_bpd,
|
| 824 |
+
"vb": vb,
|
| 825 |
+
"xstart_mse": xstart_mse,
|
| 826 |
+
"mse": mse,
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
| 831 |
+
"""
|
| 832 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
| 833 |
+
|
| 834 |
+
:param arr: the 1-D numpy array.
|
| 835 |
+
:param timesteps: a tensor of indices into the array to extract.
|
| 836 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
| 837 |
+
dimension equal to the length of timesteps.
|
| 838 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
| 839 |
+
"""
|
| 840 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
| 841 |
+
while len(res.shape) < len(broadcast_shape):
|
| 842 |
+
res = res[..., None]
|
| 843 |
+
return res.expand(broadcast_shape)
|
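
Editor's note: a minimal sketch (not part of the uploaded file) of how the two sampling loops above are typically driven. The helper name, batch/image sizes, and the empty model_kwargs are illustrative assumptions; this repo's generate.py and generate_dist.py presumably supply the real talking-face conditioning inputs through model_kwargs.

import torch as th

def sample_batch(model, diffusion, use_ddim=False, batch_size=4, image_size=128):
    # shape follows the (N, C, H, W) convention documented in p_sample_loop.
    shape = (batch_size, 3, image_size, image_size)
    sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
    # p_sample_loop draws x_T ~ N(0, I) internally (noise=None) and runs
    # p_sample() from t = T-1 down to 0; ddim_sample_loop does the same with
    # the deterministic DDIM update (eta=0.0 by default).
    return sample_fn(model, shape, clip_denoised=True, model_kwargs={})
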
guided-diffusion/guided_diffusion/image_datasets.py
ADDED
@@ -0,0 +1,167 @@
import math
import random

from PIL import Image
import blobfile as bf
from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset


def load_data(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    random_crop=False,
    random_flip=True,
):
    """
    For a dataset, create a generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which map to a batched Tensor of their own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. If classes are not available and this is true, an
                       exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
        random_crop=random_crop,
        random_flip=random_flip,
    )
    if deterministic:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
        )
    else:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
        )
    while True:
        yield from loader


def _list_image_files_recursively(data_dir):
    results = []
    for entry in sorted(bf.listdir(data_dir)):
        full_path = bf.join(data_dir, entry)
        ext = entry.split(".")[-1]
        if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
            results.append(full_path)
        elif bf.isdir(full_path):
            results.extend(_list_image_files_recursively(full_path))
    return results


class ImageDataset(Dataset):
    def __init__(
        self,
        resolution,
        image_paths,
        classes=None,
        shard=0,
        num_shards=1,
        random_crop=False,
        random_flip=True,
    ):
        super().__init__()
        self.resolution = resolution
        self.local_images = image_paths[shard:][::num_shards]
        self.local_classes = None if classes is None else classes[shard:][::num_shards]
        self.random_crop = random_crop
        self.random_flip = random_flip

    def __len__(self):
        return len(self.local_images)

    def __getitem__(self, idx):
        path = self.local_images[idx]
        with bf.BlobFile(path, "rb") as f:
            pil_image = Image.open(f)
            pil_image.load()
        pil_image = pil_image.convert("RGB")

        if self.random_crop:
            arr = random_crop_arr(pil_image, self.resolution)
        else:
            arr = center_crop_arr(pil_image, self.resolution)

        if self.random_flip and random.random() < 0.5:
            arr = arr[:, ::-1]

        arr = arr.astype(np.float32) / 127.5 - 1

        out_dict = {}
        if self.local_classes is not None:
            out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        return np.transpose(arr, [2, 0, 1]), out_dict


def center_crop_arr(pil_image, image_size):
    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]


def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
    max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
    smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)

    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(*pil_image.size) >= 2 * smaller_dim_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = smaller_dim_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = random.randrange(arr.shape[0] - image_size + 1)
    crop_x = random.randrange(arr.shape[1] - image_size + 1)
    return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
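
Editor's note: a quick illustration (not part of the uploaded file) of the load_data generator above. The dataset path is a placeholder, and mpi4py must be importable since load_data shards files by MPI rank.

data = load_data(
    data_dir="/path/to/images",  # placeholder dataset root
    batch_size=8,
    image_size=128,
    class_cond=False,
)
# load_data loops forever, so training code pulls batches with next():
batch, cond = next(data)  # batch: (8, 3, 128, 128) float32 tensor in [-1, 1]
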
guided-diffusion/guided_diffusion/logger.py
ADDED
@@ -0,0 +1,491 @@
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""

import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager
from torch.utils.tensorboard import SummaryWriter

DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50


class KVWriter(object):
    def writekvs(self, kvs):
        raise NotImplementedError


class SeqWriter(object):
    def writeseq(self, seq):
        raise NotImplementedError


class HumanOutputFormat(KVWriter, SeqWriter):
    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            assert hasattr(filename_or_file, "read"), (
                "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        if self.own_file:
            self.file.close()


class JSONOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        for k, v in sorted(kvs.items()):
            if hasattr(v, "dtype"):
                kvs[k] = float(v)
        self.file.write(json.dumps(kvs) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()


class CSVOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        # Add our current row to the history
        extra_keys = list(kvs.keys() - self.keys)
        extra_keys.sort()
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(",")
                self.file.write(k)
            self.file.write("\n")
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write("\n")
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        self.file.close()


class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.
    """

    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = -1
        self.writer = SummaryWriter(self.dir)

    def writekvs(self, kvs):
        self.step = int(kvs["step"])
        for k, v in sorted(kvs.items()):
            self.writer.add_scalar(k, float(v), self.step)
        self.writer.flush()

    def writeimage(self, key, image_tensor):
        self.writer.add_image(key, image_tensor, self.step)
        self.writer.flush()

    def close(self):
        if self.writer:
            self.writer.close()
            self.writer = None


def make_output_format(format, ev_dir, log_suffix=""):
    os.makedirs(ev_dir, exist_ok=True)
    if format == "stdout":
        return HumanOutputFormat(sys.stdout)
    elif format == "log":
        return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
    elif format == "json":
        return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
    elif format == "csv":
        return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
    elif format == "tensorboard":
        return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
    else:
        raise ValueError("Unknown format specified: %s" % (format,))


# ================================================================
# API
# ================================================================


def logimage(key, image_tensor):
    """
    Log one image to tensorboard
    """
    for fmt in get_current().output_formats:
        if isinstance(fmt, TensorBoardOutputFormat):
            tb_logger = fmt
            tb_logger.writeimage(key, image_tensor)


def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
    get_current().logkv(key, val)


def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
    get_current().logkv_mean(key, val)


def logkvs(d):
    """
    Log a dictionary of key-value pairs
    """
    for (k, v) in d.items():
        logkv(k, v)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration
    """
    return get_current().dumpkvs()


def getkvs():
    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output
    files (if you've configured an output file).
    """
    get_current().log(*args, level=level)


def debug(*args):
    log(*args, level=DEBUG)


def info(*args):
    log(*args, level=INFO)


def warn(*args):
    log(*args, level=WARN)


def error(*args):
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    get_current().set_level(level)


def set_comm(comm):
    get_current().set_comm(comm)


def get_dir():
    """
    Get directory that log files are being written to.
    Will be None if there is no output directory (i.e., if you didn't call start).
    """
    return get_current().get_dir()


record_tabular = logkv
dump_tabular = dumpkvs


@contextmanager
def profile_kv(scopename):
    logkey = "wait_" + scopename
    tstart = time.time()
    try:
        yield
    finally:
        get_current().name2val[logkey] += time.time() - tstart


def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code
    """

    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name


# ================================================================
# Backend
# ================================================================


def get_current():
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    DEFAULT = None  # A logger with no output files. (See right below class definition)
    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.comm is None:
            d = self.name2val
        else:
            d = mpi_weighted_mean(
                self.comm,
                {
                    name: (val, self.name2cnt.get(name, 1))
                    for (name, val) in self.name2val.items()
                },
            )
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        out = d.copy()  # Return the dict for unit testing purposes
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))


def get_rank_without_mpi_import():
    # check environment variables here instead of importing mpi4py
    # to avoid calling MPI_Init() when this module is imported
    for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
        if varname in os.environ:
            return int(os.environ[varname])
    return 0


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank == 0:
        name2sum = defaultdict(float)
        name2count = defaultdict(float)
        for n2vc in all_name2valcount:
            for (name, (val, count)) in n2vc.items():
                try:
                    val = float(val)
                except ValueError:
                    if comm.rank == 0:
                        warnings.warn(
                            "WARNING: tried to compute mean on non-float {}={}".format(
                                name, val
                            )
                        )
                else:
                    name2sum[name] += val * count
                    name2count[name] += count
        return {name: name2sum[name] / name2count[name] for name in name2sum}
    else:
        return {}


def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    If comm is provided, average all numerical stats across that comm
    """
    if dir is None:
        dir = os.getenv("OPENAI_LOGDIR")
    if dir is None:
        dir = osp.join(
            tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
        )
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv,tensorboard").split(",")
        else:
            format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)


def _configure_default_logger():
    configure()
    Logger.DEFAULT = Logger.CURRENT


def reset():
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log("Reset logger")


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    prevlogger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger
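
Editor's note: a minimal sketch (not part of the uploaded file) of the logger API defined above. The import assumes the guided_diffusion package from setup.py is installed, and the log directory is a placeholder. logkv()/logkv_mean() accumulate values for one iteration, and dumpkvs() flushes a row to each configured format.

from guided_diffusion import logger

logger.configure(dir="/tmp/example_run")  # placeholder log directory
for step in range(3):
    logger.logkv("step", step)            # TensorBoardOutputFormat keys off "step"
    logger.logkv_mean("loss", 0.1 * step)
    logger.dumpkvs()                      # writes one row (stdout/log/csv/tensorboard)
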
guided-diffusion/guided_diffusion/losses.py
ADDED
@@ -0,0 +1,77 @@
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""

import numpy as np

import torch as th


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
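
Editor's note: a quick sanity check (not part of the uploaded file). For log-variances lv = log(v), the normal_kl expression above reduces to the textbook KL = 0.5 * (log(v2/v1) + v1/v2 + (m1-m2)^2/v2 - 1); e.g. KL(N(0,1) || N(1,1)) = 0.5.

import torch as th

m1, lv1, m2, lv2 = th.tensor(0.0), th.tensor(0.0), th.tensor(1.0), th.tensor(0.0)
# Same formula as normal_kl, written inline for the scalar case:
kl = 0.5 * (-1.0 + lv2 - lv1 + th.exp(lv1 - lv2) + ((m1 - m2) ** 2) * th.exp(-lv2))
assert abs(kl.item() - 0.5) < 1e-6  # KL(N(0,1) || N(1,1)) = 0.5
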
guided-diffusion/guided_diffusion/lpips.py
ADDED
@@ -0,0 +1,20 @@
from lpips_pytorch import LPIPS
import torch


class LPIPS1(LPIPS):
    r"""
    Overriding LPIPS to return the loss without reducing over the batch.
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
        # Forward the arguments instead of hard-coding 'alex'/'0.1', so the
        # documented net_type/version parameters actually take effect.
        super(LPIPS1, self).__init__(net_type=net_type, version=version)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)
        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
        # Sum over feature scales (dim 1), keeping one distance per batch item.
        # return torch.sum(torch.cat(res, 0), 0, True)
        return torch.sum(torch.cat(res, 1), 1, True)
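
Editor's note: a minimal usage sketch (not part of the uploaded file). Because LPIPS1 keeps the batch dimension, an (N, 3, H, W) pair yields an (N, 1, 1, 1) tensor of per-sample distances that can be masked or weighted before any reduction; input range conventions follow lpips_pytorch.

import torch

metric = LPIPS1(net_type='alex')
a = torch.rand(2, 3, 64, 64)
b = torch.rand(2, 3, 64, 64)
d = metric(a, b)  # shape (2, 1, 1, 1): one LPIPS distance per image
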
guided-diffusion/guided_diffusion/nn.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Various utilities for neural networks.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import torch as th
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
| 12 |
+
class SiLU(nn.Module):
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
return x * th.sigmoid(x)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GroupNorm32(nn.GroupNorm):
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return super().forward(x.float()).type(x.dtype)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def conv_nd(dims, *args, **kwargs):
|
| 23 |
+
"""
|
| 24 |
+
Create a 1D, 2D, or 3D convolution module.
|
| 25 |
+
"""
|
| 26 |
+
if dims == 1:
|
| 27 |
+
return nn.Conv1d(*args, **kwargs)
|
| 28 |
+
elif dims == 2:
|
| 29 |
+
return nn.Conv2d(*args, **kwargs)
|
| 30 |
+
elif dims == 3:
|
| 31 |
+
return nn.Conv3d(*args, **kwargs)
|
| 32 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def linear(*args, **kwargs):
|
| 36 |
+
"""
|
| 37 |
+
Create a linear module.
|
| 38 |
+
"""
|
| 39 |
+
return nn.Linear(*args, **kwargs)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
| 43 |
+
"""
|
| 44 |
+
Create a 1D, 2D, or 3D average pooling module.
|
| 45 |
+
"""
|
| 46 |
+
if dims == 1:
|
| 47 |
+
return nn.AvgPool1d(*args, **kwargs)
|
| 48 |
+
elif dims == 2:
|
| 49 |
+
return nn.AvgPool2d(*args, **kwargs)
|
| 50 |
+
elif dims == 3:
|
| 51 |
+
return nn.AvgPool3d(*args, **kwargs)
|
| 52 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def update_ema(target_params, source_params, rate=0.99):
|
| 56 |
+
"""
|
| 57 |
+
Update target parameters to be closer to those of source parameters using
|
| 58 |
+
an exponential moving average.
|
| 59 |
+
|
| 60 |
+
:param target_params: the target parameter sequence.
|
| 61 |
+
:param source_params: the source parameter sequence.
|
| 62 |
+
:param rate: the EMA rate (closer to 1 means slower).
|
| 63 |
+
"""
|
| 64 |
+
for targ, src in zip(target_params, source_params):
|
| 65 |
+
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def zero_module(module):
|
| 69 |
+
"""
|
| 70 |
+
Zero out the parameters of a module and return it.
|
| 71 |
+
"""
|
| 72 |
+
for p in module.parameters():
|
| 73 |
+
p.detach().zero_()
|
| 74 |
+
return module
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def scale_module(module, scale):
|
| 78 |
+
"""
|
| 79 |
+
Scale the parameters of a module and return it.
|
| 80 |
+
"""
|
| 81 |
+
for p in module.parameters():
|
| 82 |
+
p.detach().mul_(scale)
|
| 83 |
+
return module
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def mean_flat(tensor):
|
| 87 |
+
"""
|
| 88 |
+
Take the mean over all non-batch dimensions.
|
| 89 |
+
"""
|
| 90 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def normalization(channels):
|
| 94 |
+
"""
|
| 95 |
+
Make a standard normalization layer.
|
| 96 |
+
|
| 97 |
+
:param channels: number of input channels.
|
| 98 |
+
:return: an nn.Module for normalization.
|
| 99 |
+
"""
|
| 100 |
+
return GroupNorm32(32, channels)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
| 104 |
+
"""
|
| 105 |
+
Create sinusoidal timestep embeddings.
|
| 106 |
+
|
| 107 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 108 |
+
These may be fractional.
|
| 109 |
+
:param dim: the dimension of the output.
|
| 110 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 111 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
| 112 |
+
"""
|
| 113 |
+
half = dim // 2
|
| 114 |
+
freqs = th.exp(
|
| 115 |
+
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
|
| 116 |
+
).to(device=timesteps.device)
|
| 117 |
+
args = timesteps[:, None].float() * freqs[None]
|
| 118 |
+
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
| 119 |
+
if dim % 2:
|
| 120 |
+
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
| 121 |
+
return embedding
|
| 122 |
+
|
| 123 |
+
|
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
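A hedged usage sketch for `checkpoint` (the toy linear layer is hypothetical, not from this repo): with `flag=True` the forward pass runs under `no_grad` and activations are recomputed during backward, trading compute for memory.

import torch as th
import torch.nn as nn
from guided_diffusion.nn import checkpoint  # assumed install path

layer = nn.Linear(16, 16)
x = th.randn(2, 16, requires_grad=True)
y = checkpoint(layer, (x,), layer.parameters(), True)  # flag=False would just call layer(x)
y.sum().backward()
assert x.grad is not None  # gradients still flow despite the no_grad forward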
guided-diffusion/guided_diffusion/resample.py
ADDED
@@ -0,0 +1,154 @@
from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent here.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
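A minimal sketch of the sampler API (the `_Dummy` stand-in is hypothetical; any object exposing `num_timesteps` satisfies `UniformSampler`):

import torch as th
from guided_diffusion.resample import UniformSampler  # assumed install path

class _Dummy:
    num_timesteps = 1000  # the only attribute the sampler reads

t, w = UniformSampler(_Dummy()).sample(batch_size=8, device=th.device("cpu"))
assert t.shape == (8,) and th.allclose(w, th.ones(8))  # uniform sampling => unit loss weights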
guided-diffusion/guided_diffusion/respace.py
ADDED
@@ -0,0 +1,128 @@
import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)

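The docstring's own example, checked concretely (a sketch assuming the package import path):

from guided_diffusion.respace import space_timesteps  # assumed install path

steps = space_timesteps(300, [10, 15, 20])  # three 100-step sections
assert len(steps) == 45                     # 10 + 15 + 20 retained steps
assert space_timesteps(1000, "ddim25") == set(range(0, 1000, 40))  # fixed DDIM stride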
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        if self.rescale_timesteps:
            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
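A small illustration (hypothetical values) of the remapping `_WrappedModel` performs: the sampler indexes the shortened schedule, while the network still sees original-process timesteps.

import torch as th

timestep_map = list(range(0, 1000, 40))  # what SpacedDiffusion builds for "ddim25"
ts = th.tensor([0, 1, 24])               # indices into the respaced schedule
new_ts = th.tensor(timestep_map)[ts]     # what the wrapped model receives
assert new_ts.tolist() == [0, 40, 960]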
guided-diffusion/guided_diffusion/script_util.py
ADDED
@@ -0,0 +1,614 @@
import argparse
import inspect

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps
from .unet import SuperResModel, UNetModel, EncoderUNetModel, TFGModel

NUM_CLASSES = 1000


def diffusion_defaults():
    """
    Defaults for image and classifier training.
    """
    return dict(
        learn_sigma=False,
        diffusion_steps=1000,
        noise_schedule="linear",
        timestep_respacing="",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=False,
        rescale_learned_sigmas=False,
        loss_variation=0,  # added by soumik
    )


def classifier_defaults():
    """
    Defaults for classifier models.
    """
    return dict(
        image_size=64,
        classifier_use_fp16=False,
        classifier_width=128,
        classifier_depth=2,
        classifier_attention_resolutions="32,16,8",  # 16
        classifier_use_scale_shift_norm=True,  # False
        classifier_resblock_updown=True,  # False
        classifier_pool="attention",
    )


def model_and_diffusion_defaults():
    """
    Defaults for image training.
    """
    res = dict(
        image_size=64,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=-1,
        attention_resolutions="16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=False,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
    )
    res.update(diffusion_defaults())
    return res


def classifier_and_diffusion_defaults():
    res = classifier_defaults()
    res.update(diffusion_defaults())
    return res


def create_model_and_diffusion(
    image_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    use_new_attention_order,
):
    model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        channel_mult=channel_mult,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        use_new_attention_order=use_new_attention_order,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    if channel_mult == "":
        if image_size == 512:
            channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
        elif image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))

    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )

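Note the convention `attention_resolutions` uses: the string names feature-map resolutions, which `create_model` converts into downsample factors relative to the input size. A quick hypothetical check:

image_size = 64
attention_ds = [image_size // int(res) for res in "16,8".split(",")]
assert attention_ds == [4, 8]  # attention at the 16x16 and 8x8 feature maps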
def create_classifier_and_diffusion(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
    learn_sigma,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
):
    classifier = create_classifier(
        image_size,
        classifier_use_fp16,
        classifier_width,
        classifier_depth,
        classifier_attention_resolutions,
        classifier_use_scale_shift_norm,
        classifier_resblock_updown,
        classifier_pool,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return classifier, diffusion


def create_classifier(
    image_size,
    classifier_use_fp16,
    classifier_width,
    classifier_depth,
    classifier_attention_resolutions,
    classifier_use_scale_shift_norm,
    classifier_resblock_updown,
    classifier_pool,
):
    if image_size == 512:
        channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
    elif image_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif image_size == 128:
        channel_mult = (1, 1, 2, 3, 4)
    elif image_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported image size: {image_size}")

    attention_ds = []
    for res in classifier_attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    return EncoderUNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=classifier_width,
        out_channels=1000,
        num_res_blocks=classifier_depth,
        attention_resolutions=tuple(attention_ds),
        channel_mult=channel_mult,
        use_fp16=classifier_use_fp16,
        num_head_channels=64,
        use_scale_shift_norm=classifier_use_scale_shift_norm,
        resblock_updown=classifier_resblock_updown,
        pool=classifier_pool,
    )


def sr_model_and_diffusion_defaults():
    res = model_and_diffusion_defaults()
    res["large_size"] = 256
    res["small_size"] = 64
    arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
    for k in res.copy().keys():
        if k not in arg_names:
            del res[k]
    return res


def sr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
):
    model = sr_create_model(
        large_size,
        small_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    return model, diffusion


def sr_create_model(
    large_size,
    small_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,
    resblock_updown,
    use_fp16,
):
    _ = small_size  # hack to prevent unused variable

    if large_size == 512:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif large_size == 256:
        channel_mult = (1, 1, 2, 2, 4, 4)
    elif large_size == 64:
        channel_mult = (1, 2, 3, 4)
    else:
        raise ValueError(f"unsupported large size: {large_size}")

    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(large_size // int(res))

    return SuperResModel(
        image_size=large_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
    )


def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
    loss_variation=0,
):
    betas = gd.get_named_beta_schedule(noise_schedule, steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if not timestep_respacing:
        timestep_respacing = [steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        loss_variation=loss_variation,  # added by soumik
    )


def add_dict_to_argparser(parser, default_dict):
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


def args_to_dict(args, keys):
    return {k: getattr(args, k) for k in keys}


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")

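A sketch of how these argparse helpers compose (flag names come from the defaults dict; booleans round-trip through `str2bool`, so `--use_fp16 true` works on a command line):

import argparse
from guided_diffusion.script_util import (  # assumed install path
    add_dict_to_argparser, args_to_dict, model_and_diffusion_defaults,
)

parser = argparse.ArgumentParser()
defaults = model_and_diffusion_defaults()
add_dict_to_argparser(parser, defaults)
args = parser.parse_args(["--image_size", "128", "--use_fp16", "true"])
kwargs = args_to_dict(args, defaults.keys())
assert kwargs["image_size"] == 128 and kwargs["use_fp16"] is True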
+
#________________________________ tfg model ________________________________#
|
| 458 |
+
def tfg_model_and_diffusion_defaults():
|
| 459 |
+
res = model_and_diffusion_defaults()
|
| 460 |
+
arg_names = inspect.getfullargspec(tfg_create_model_and_diffusion)[0]
|
| 461 |
+
for k in res.copy().keys():
|
| 462 |
+
if k not in arg_names:
|
| 463 |
+
del res[k]
|
| 464 |
+
|
| 465 |
+
#tfg args
|
| 466 |
+
res["use_ref"]=False
|
| 467 |
+
res["nframes"]=1
|
| 468 |
+
res["nrefer"]=0
|
| 469 |
+
res["use_audio"]=False
|
| 470 |
+
res["audio_encoder_kwargs"]={}
|
| 471 |
+
res["audio_as_style"]=False
|
| 472 |
+
res["audio_as_style_encoder_mlp"]=False
|
| 473 |
+
return res
|
| 474 |
+
|
| 475 |
+
def tfg_create_model_and_diffusion(
|
| 476 |
+
image_size,
|
| 477 |
+
class_cond,
|
| 478 |
+
learn_sigma,
|
| 479 |
+
num_channels,
|
| 480 |
+
num_res_blocks,
|
| 481 |
+
num_heads,
|
| 482 |
+
num_head_channels,
|
| 483 |
+
num_heads_upsample,
|
| 484 |
+
attention_resolutions,
|
| 485 |
+
dropout,
|
| 486 |
+
diffusion_steps,
|
| 487 |
+
noise_schedule,
|
| 488 |
+
timestep_respacing,
|
| 489 |
+
use_kl,
|
| 490 |
+
predict_xstart,
|
| 491 |
+
rescale_timesteps,
|
| 492 |
+
rescale_learned_sigmas,
|
| 493 |
+
use_checkpoint,
|
| 494 |
+
use_scale_shift_norm,
|
| 495 |
+
resblock_updown,
|
| 496 |
+
use_fp16,
|
| 497 |
+
use_ref,
|
| 498 |
+
nframes,
|
| 499 |
+
nrefer,
|
| 500 |
+
use_audio,
|
| 501 |
+
audio_encoder_kwargs,
|
| 502 |
+
audio_as_style,
|
| 503 |
+
audio_as_style_encoder_mlp,
|
| 504 |
+
loss_variation,
|
| 505 |
+
):
|
| 506 |
+
model = tfg_create_model(
|
| 507 |
+
image_size,
|
| 508 |
+
num_channels,
|
| 509 |
+
num_res_blocks,
|
| 510 |
+
learn_sigma=learn_sigma,
|
| 511 |
+
class_cond=class_cond,
|
| 512 |
+
use_checkpoint=use_checkpoint,
|
| 513 |
+
attention_resolutions=attention_resolutions,
|
| 514 |
+
num_heads=num_heads,
|
| 515 |
+
num_head_channels=num_head_channels,
|
| 516 |
+
num_heads_upsample=num_heads_upsample,
|
| 517 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 518 |
+
dropout=dropout,
|
| 519 |
+
resblock_updown=resblock_updown,
|
| 520 |
+
use_fp16=use_fp16,
|
| 521 |
+
use_ref=use_ref,
|
| 522 |
+
nframes=nframes,
|
| 523 |
+
nrefer=nrefer,
|
| 524 |
+
use_audio=use_audio,
|
| 525 |
+
audio_encoder_kwargs=audio_encoder_kwargs,
|
| 526 |
+
audio_as_style=audio_as_style,
|
| 527 |
+
audio_as_style_encoder_mlp=audio_as_style_encoder_mlp,
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
diffusion = create_gaussian_diffusion(
|
| 531 |
+
steps=diffusion_steps,
|
| 532 |
+
learn_sigma=learn_sigma,
|
| 533 |
+
noise_schedule=noise_schedule,
|
| 534 |
+
use_kl=use_kl,
|
| 535 |
+
predict_xstart=predict_xstart,
|
| 536 |
+
rescale_timesteps=rescale_timesteps,
|
| 537 |
+
rescale_learned_sigmas=rescale_learned_sigmas,
|
| 538 |
+
timestep_respacing=timestep_respacing,
|
| 539 |
+
loss_variation=loss_variation,
|
| 540 |
+
)
|
| 541 |
+
return model, diffusion
|
| 542 |
+
|
| 543 |
+
def tfg_create_model(
|
| 544 |
+
image_size,
|
| 545 |
+
num_channels,
|
| 546 |
+
num_res_blocks,
|
| 547 |
+
learn_sigma,
|
| 548 |
+
class_cond,
|
| 549 |
+
use_checkpoint,
|
| 550 |
+
attention_resolutions,
|
| 551 |
+
num_heads,
|
| 552 |
+
num_head_channels,
|
| 553 |
+
num_heads_upsample,
|
| 554 |
+
use_scale_shift_norm,
|
| 555 |
+
dropout,
|
| 556 |
+
resblock_updown,
|
| 557 |
+
use_fp16,
|
| 558 |
+
use_ref,
|
| 559 |
+
nframes,
|
| 560 |
+
nrefer,
|
| 561 |
+
use_audio,
|
| 562 |
+
audio_encoder_kwargs,
|
| 563 |
+
audio_as_style,
|
| 564 |
+
audio_as_style_encoder_mlp,
|
| 565 |
+
):
|
| 566 |
+
|
| 567 |
+
if image_size == 512:
|
| 568 |
+
channel_mult = (1, 1, 2, 2, 4, 4)
|
| 569 |
+
elif image_size == 256:
|
| 570 |
+
channel_mult = (1, 1, 2, 3, 4, 4)
|
| 571 |
+
elif image_size == 128:
|
| 572 |
+
channel_mult = (1, 1, 2, 3, 4)
|
| 573 |
+
elif image_size == 64:
|
| 574 |
+
channel_mult = (1, 2, 3, 4)
|
| 575 |
+
else:
|
| 576 |
+
raise ValueError(f"unsupported large size: {image_size}")
|
| 577 |
+
|
| 578 |
+
attention_ds = []
|
| 579 |
+
if "-1" not in attention_resolutions: # -1 = no attention
|
| 580 |
+
for res in attention_resolutions.split(","):
|
| 581 |
+
attention_ds.append(image_size // int(res))
|
| 582 |
+
|
| 583 |
+
return TFGModel(
|
| 584 |
+
image_size=image_size,
|
| 585 |
+
in_channels=3,
|
| 586 |
+
model_channels=num_channels,
|
| 587 |
+
out_channels=(3 if not learn_sigma else 6),
|
| 588 |
+
num_res_blocks=num_res_blocks,
|
| 589 |
+
attention_resolutions=tuple(attention_ds),
|
| 590 |
+
dropout=dropout,
|
| 591 |
+
channel_mult=channel_mult,
|
| 592 |
+
num_classes=(NUM_CLASSES if class_cond else None),
|
| 593 |
+
use_checkpoint=use_checkpoint,
|
| 594 |
+
use_fp16=use_fp16,
|
| 595 |
+
num_heads=num_heads,
|
| 596 |
+
num_head_channels=num_head_channels,
|
| 597 |
+
num_heads_upsample=num_heads_upsample,
|
| 598 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 599 |
+
resblock_updown=resblock_updown,
|
| 600 |
+
use_ref=use_ref,
|
| 601 |
+
nframes=nframes,
|
| 602 |
+
nrefer=nrefer,
|
| 603 |
+
use_audio=use_audio,
|
| 604 |
+
audio_encoder_kwargs=audio_encoder_kwargs,
|
| 605 |
+
audio_as_style=audio_as_style,
|
| 606 |
+
audio_as_style_encoder_mlp=audio_as_style_encoder_mlp
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
|
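A minimal construction sketch. The settings here are only the defaults above; real runs come from scripts/inference.sh, which supplies the configuration matching the shipped checkpoint, so treat this as illustrative rather than a working recipe:

from guided_diffusion.script_util import (  # assumed install path
    tfg_model_and_diffusion_defaults, tfg_create_model_and_diffusion,
)

opts = tfg_model_and_diffusion_defaults()
model, diffusion = tfg_create_model_and_diffusion(**opts)  # plain TFG UNet, no audio/ref conditioning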
guided-diffusion/guided_diffusion/tfg_data_util.py
ADDED
@@ -0,0 +1,75 @@
import torch

def normalise2(tensor):
    '''[0,1] -> [-1,1]'''
    return (tensor * 2 - 1.).clamp(-1, 1)

def tfg_data(dataloader, face_hide_percentage, use_ref, use_audio):  # , sampling_use_gt_for_ref=False, noise=None
    def inf_gen(generator):
        while True:
            yield from generator
    data = inf_gen(dataloader)
    for batch in data:
        img_batch, model_kwargs = tfg_process_batch(batch, face_hide_percentage, use_ref, use_audio)
        yield img_batch, model_kwargs


def tfg_process_batch(batch, face_hide_percentage, use_ref=False, use_audio=False, sampling_use_gt_for_ref=False, noise=None):
    model_kwargs = {}
    B, F, C, H, W = batch["image"].shape
    img_batch = normalise2(batch["image"].reshape(B * F, C, H, W).contiguous())
    model_kwargs = tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise)
    if use_ref:
        model_kwargs = tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref)
    if use_audio:
        model_kwargs = tfg_add_audio(batch, model_kwargs)
    return img_batch, model_kwargs

def tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref=False):
    # assuming nrefer = 1
    # [B, nframes, C, H, W] -> [B*nframes, C, H, W]
    if sampling_use_gt_for_ref:
        B, F, C, H, W = batch["image"].shape
        img_batch = normalise2(batch["image"].reshape(B * F, C, H, W).contiguous())
        model_kwargs["ref_img"] = img_batch
    else:
        _, _, C, H, W = batch["ref_img"].shape
        ref_img = normalise2(batch["ref_img"].reshape(-1, C, H, W).contiguous())
        model_kwargs["ref_img"] = ref_img
    return model_kwargs

def tfg_add_audio(batch, model_kwargs):
    # unet needs [BF, h, w] as input
    B, F, _, h, w = batch["indiv_mels"].shape
    indiv_mels = batch["indiv_mels"]  # [B, F, 1, h, w]
    indiv_mels = indiv_mels.squeeze(dim=2).reshape(B * F, h, w)
    model_kwargs["indiv_mels"] = indiv_mels
    # syncloss needs [B, 1, 80, 16] as input
    if "mel" in batch:
        mel = batch["mel"]  # [B, 1, h, w]
        model_kwargs["mel"] = mel
    return model_kwargs

def tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise=None):
    B, C, H, W = img_batch.shape
    mask = torch.zeros(B, 1, H, W)
    mask_start_idx = int(H * (1 - face_hide_percentage))
    mask[:, :, mask_start_idx:, :] = 1.
    if noise is None:
        noise = torch.randn_like(img_batch)
    assert noise.shape == img_batch.shape, "Noise shape != Image shape"
    cond_img = img_batch * (1. - mask) + mask * noise

    model_kwargs["cond_img"] = cond_img
    model_kwargs["mask"] = mask
    return model_kwargs


def get_n_params(model):
    pp = 0
    for p in list(model.parameters()):
        nn = 1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp
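A concrete sketch of the mask built by `tfg_add_cond_inputs`: with `face_hide_percentage=0.5`, the bottom half of each frame (the mouth region) is replaced by noise for the model to inpaint. The random input here is a hypothetical stand-in for real normalised frames:

import torch
from guided_diffusion.tfg_data_util import tfg_add_cond_inputs  # assumed install path

img_batch = torch.rand(2, 3, 128, 128) * 2 - 1  # pretend normalise2 output
kwargs = tfg_add_cond_inputs(img_batch, {}, face_hide_percentage=0.5)
assert kwargs["mask"][:, :, :64, :].sum() == 0    # top half kept from the input
assert (kwargs["mask"][:, :, 64:, :] == 1).all()  # bottom half noised out in cond_img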
guided-diffusion/guided_diffusion/unet.py
ADDED
@@ -0,0 +1,1275 @@
from abc import abstractmethod

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

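A toy illustration (the `AddEmb` module is hypothetical) of the dispatch rule: `TimestepEmbedSequential` hands `emb` only to `TimestepBlock` children, so plain layers can be mixed in freely.

import torch as th
import torch.nn as nn
# uses TimestepBlock / TimestepEmbedSequential as defined above

class AddEmb(TimestepBlock):
    def forward(self, x, emb):
        return x + emb  # emb must broadcast against x

seq = TimestepEmbedSequential(nn.Identity(), AddEmb(), nn.ReLU())
x, emb = th.randn(2, 4), th.randn(2, 4)
assert th.equal(seq(x, emb), th.relu(x + emb))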
| 81 |
+
class Upsample(nn.Module):
|
| 82 |
+
"""
|
| 83 |
+
An upsampling layer with an optional convolution.
|
| 84 |
+
|
| 85 |
+
:param channels: channels in the inputs and outputs.
|
| 86 |
+
:param use_conv: a bool determining if a convolution is applied.
|
| 87 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
| 88 |
+
upsampling occurs in the inner-two dimensions.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.channels = channels
|
| 94 |
+
self.out_channels = out_channels or channels
|
| 95 |
+
self.use_conv = use_conv
|
| 96 |
+
self.dims = dims
|
| 97 |
+
if use_conv:
|
| 98 |
+
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
| 99 |
+
|
| 100 |
+
def forward(self, x):
|
| 101 |
+
assert x.shape[1] == self.channels
|
| 102 |
+
if self.dims == 3:
|
| 103 |
+
x = F.interpolate(
|
| 104 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
| 105 |
+
)
|
| 106 |
+
else:
|
| 107 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
| 108 |
+
if self.use_conv:
|
| 109 |
+
x = self.conv(x)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Downsample(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
A downsampling layer with an optional convolution.
|
| 116 |
+
|
| 117 |
+
:param channels: channels in the inputs and outputs.
|
| 118 |
+
:param use_conv: a bool determining if a convolution is applied.
|
| 119 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
| 120 |
+
downsampling occurs in the inner-two dimensions.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, stride=None):
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.channels = channels
|
| 126 |
+
self.out_channels = out_channels or channels
|
| 127 |
+
self.use_conv = use_conv
|
| 128 |
+
self.dims = dims
|
| 129 |
+
if stride is None:
|
| 130 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
| 131 |
+
|
| 132 |
+
if use_conv:
|
| 133 |
+
self.op = conv_nd(
|
| 134 |
+
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
assert self.channels == self.out_channels
|
| 138 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
| 139 |
+
|
| 140 |
+
def forward(self, x):
|
| 141 |
+
assert x.shape[1] == self.channels
|
| 142 |
+
return self.op(x)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class ResBlock(TimestepBlock):
|
| 146 |
+
"""
|
| 147 |
+
A residual block that can optionally change the number of channels.
|
| 148 |
+
|
| 149 |
+
:param channels: the number of input channels.
|
| 150 |
+
:param emb_channels: the number of timestep embedding channels.
|
| 151 |
+
:param dropout: the rate of dropout.
|
| 152 |
+
:param out_channels: if specified, the number of out channels.
|
| 153 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
| 154 |
+
convolution instead of a smaller 1x1 convolution to change the
|
| 155 |
+
channels in the skip connection.
|
| 156 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
| 157 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
| 158 |
+
:param up: if True, use this block for upsampling.
|
| 159 |
+
:param down: if True, use this block for downsampling.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
channels,
|
| 165 |
+
emb_channels,
|
| 166 |
+
dropout,
|
| 167 |
+
out_channels=None,
|
| 168 |
+
use_conv=False,
|
| 169 |
+
use_scale_shift_norm=False,
|
| 170 |
+
dims=2,
|
| 171 |
+
use_checkpoint=False,
|
| 172 |
+
up=False,
|
| 173 |
+
down=False,
|
| 174 |
+
down_stride=None,
|
| 175 |
+
):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.channels = channels
|
| 178 |
+
self.emb_channels = emb_channels
|
| 179 |
+
self.dropout = dropout
|
| 180 |
+
self.out_channels = out_channels or channels
|
| 181 |
+
self.use_conv = use_conv
|
| 182 |
+
self.use_checkpoint = use_checkpoint
|
| 183 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
| 184 |
+
self.dims = dims
|
| 185 |
+
|
| 186 |
+
self.in_layers = nn.Sequential(
|
| 187 |
+
normalization(channels),
|
| 188 |
+
nn.SiLU(),
|
| 189 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
self.updown = up or down
|
| 193 |
+
|
| 194 |
+
if up:
|
| 195 |
+
self.h_upd = Upsample(channels, False, dims)
|
| 196 |
+
self.x_upd = Upsample(channels, False, dims)
|
| 197 |
+
elif down:
|
| 198 |
+
self.h_upd = Downsample(channels, False, dims, stride=down_stride)
|
| 199 |
+
self.x_upd = Downsample(channels, False, dims, stride=down_stride)
|
| 200 |
+
else:
|
| 201 |
+
self.h_upd = self.x_upd = nn.Identity()
|
| 202 |
+
|
| 203 |
+
self.emb_layers = nn.Sequential(
|
| 204 |
+
nn.SiLU(),
|
| 205 |
+
linear(
|
| 206 |
+
emb_channels,
|
| 207 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
| 208 |
+
),
|
| 209 |
+
)
|
| 210 |
+
self.out_layers = nn.Sequential(
|
| 211 |
+
normalization(self.out_channels),
|
| 212 |
+
nn.SiLU(),
|
| 213 |
+
nn.Dropout(p=dropout),
|
| 214 |
+
zero_module(
|
| 215 |
+
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
| 216 |
+
),
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
if self.out_channels == channels:
|
| 220 |
+
self.skip_connection = nn.Identity()
|
| 221 |
+
elif use_conv:
|
| 222 |
+
self.skip_connection = conv_nd(
|
| 223 |
+
dims, channels, self.out_channels, 3, padding=1
|
| 224 |
+
)
|
| 225 |
+
else:
|
| 226 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
| 227 |
+
|
| 228 |
+
def forward(self, x, emb):
|
| 229 |
+
"""
|
| 230 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
| 231 |
+
|
| 232 |
+
:param x: an [N x C x ...] Tensor of features.
|
| 233 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
| 234 |
+
:return: an [N x C x ...] Tensor of outputs.
|
| 235 |
+
"""
|
| 236 |
+
return checkpoint(
|
| 237 |
+
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def _forward(self, x, emb):
|
| 241 |
+
if self.updown:
|
| 242 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
| 243 |
+
h = in_rest(x)
|
| 244 |
+
h = self.h_upd(h)
|
| 245 |
+
x = self.x_upd(x)
|
| 246 |
+
h = in_conv(h)
|
| 247 |
+
else:
|
| 248 |
+
h = self.in_layers(x)
|
| 249 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
| 250 |
+
while len(emb_out.shape) < len(h.shape):
|
| 251 |
+
emb_out = emb_out[..., None]
|
| 252 |
+
if self.use_scale_shift_norm:
|
| 253 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
| 254 |
+
scale, shift = th.chunk(emb_out, 2, dim=1)
|
| 255 |
+
h = out_norm(h) * (1 + scale) + shift
|
| 256 |
+
h = out_rest(h)
|
| 257 |
+
else:
|
| 258 |
+
h = h + emb_out
|
| 259 |
+
h = self.out_layers(h)
|
| 260 |
+
return self.skip_connection(x) + h
|
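The use_scale_shift_norm branch in _forward above is FiLM-style conditioning: the embedding is projected to 2 * out_channels, then split into a scale and a shift that modulate the normalized features. Isolated as a minimal sketch, with illustrative shapes and normalization() as defined in this package's nn helpers:

import torch as th

h = th.randn(2, 64, 32, 32)                 # features after in_layers
emb_out = th.randn(2, 128, 1, 1)            # emb_layers output, broadcastable over H, W
scale, shift = th.chunk(emb_out, 2, dim=1)  # two [2, 64, 1, 1] halves
h = normalization(64)(h) * (1 + scale) + shift  # normalize, then FiLM-modulate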
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class AttentionBlock(nn.Module):
|
| 264 |
+
"""
|
| 265 |
+
An attention block that allows spatial positions to attend to each other.
|
| 266 |
+
|
| 267 |
+
Originally ported from here, but adapted to the N-d case.
|
| 268 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
def __init__(
|
| 272 |
+
self,
|
| 273 |
+
channels,
|
| 274 |
+
num_heads=1,
|
| 275 |
+
num_head_channels=-1,
|
| 276 |
+
use_checkpoint=False,
|
| 277 |
+
use_new_attention_order=False,
|
| 278 |
+
):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.channels = channels
|
| 281 |
+
if num_head_channels == -1:
|
| 282 |
+
self.num_heads = num_heads
|
| 283 |
+
else:
|
| 284 |
+
assert (
|
| 285 |
+
channels % num_head_channels == 0
|
| 286 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
| 287 |
+
self.num_heads = channels // num_head_channels
|
| 288 |
+
self.use_checkpoint = use_checkpoint
|
| 289 |
+
self.norm = normalization(channels)
|
| 290 |
+
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
| 291 |
+
if use_new_attention_order:
|
| 292 |
+
# split qkv before split heads
|
| 293 |
+
self.attention = QKVAttention(self.num_heads)
|
| 294 |
+
else:
|
| 295 |
+
# split heads before split qkv
|
| 296 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
| 297 |
+
|
| 298 |
+
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
| 299 |
+
|
| 300 |
+
def forward(self, x):
|
| 301 |
+
return checkpoint(self._forward, (x,), self.parameters(), True)
|
| 302 |
+
|
| 303 |
+
def _forward(self, x):
|
| 304 |
+
b, c, *spatial = x.shape
|
| 305 |
+
x = x.reshape(b, c, -1)
|
| 306 |
+
qkv = self.qkv(self.norm(x))
|
| 307 |
+
h = self.attention(qkv)
|
| 308 |
+
h = self.proj_out(h)
|
| 309 |
+
return (x + h).reshape(b, c, *spatial)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def count_flops_attn(model, _x, y):
|
| 313 |
+
"""
|
| 314 |
+
A counter for the `thop` package to count the operations in an
|
| 315 |
+
attention operation.
|
| 316 |
+
Meant to be used like:
|
| 317 |
+
macs, params = thop.profile(
|
| 318 |
+
model,
|
| 319 |
+
inputs=(inputs, timestamps),
|
| 320 |
+
custom_ops={QKVAttention: QKVAttention.count_flops},
|
| 321 |
+
)
|
| 322 |
+
"""
|
| 323 |
+
b, c, *spatial = y[0].shape
|
| 324 |
+
num_spatial = int(np.prod(spatial))
|
| 325 |
+
# We perform two matmuls with the same number of ops.
|
| 326 |
+
# The first computes the weight matrix, the second computes
|
| 327 |
+
# the combination of the value vectors.
|
| 328 |
+
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
| 329 |
+
model.total_ops += th.DoubleTensor([matmul_ops])
|
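A hedged usage sketch of the docstring's recipe, assuming the third-party thop package is installed and `model` is one of the UNets defined below; both attention classes expose count_flops as the custom op.

import thop
import torch as th

inputs = th.randn(1, 3, 64, 64)
timestamps = th.tensor([10])
macs, params = thop.profile(
    model,
    inputs=(inputs, timestamps),
    custom_ops={
        QKVAttention: QKVAttention.count_flops,
        QKVAttentionLegacy: QKVAttentionLegacy.count_flops,
    },
)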
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class QKVAttentionLegacy(nn.Module):
|
| 333 |
+
"""
|
| 334 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
|
| 335 |
+
"""
|
| 336 |
+
|
| 337 |
+
def __init__(self, n_heads):
|
| 338 |
+
super().__init__()
|
| 339 |
+
self.n_heads = n_heads
|
| 340 |
+
|
| 341 |
+
def forward(self, qkv):
|
| 342 |
+
"""
|
| 343 |
+
Apply QKV attention.
|
| 344 |
+
|
| 345 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
| 346 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
| 347 |
+
"""
|
| 348 |
+
bs, width, length = qkv.shape
|
| 349 |
+
assert width % (3 * self.n_heads) == 0
|
| 350 |
+
ch = width // (3 * self.n_heads)
|
| 351 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
| 352 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
| 353 |
+
weight = th.einsum(
|
| 354 |
+
"bct,bcs->bts", q * scale, k * scale
|
| 355 |
+
) # More stable with f16 than dividing afterwards
|
| 356 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
| 357 |
+
a = th.einsum("bts,bcs->bct", weight, v)
|
| 358 |
+
return a.reshape(bs, -1, length)
|
| 359 |
+
|
| 360 |
+
@staticmethod
|
| 361 |
+
def count_flops(model, _x, y):
|
| 362 |
+
return count_flops_attn(model, _x, y)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class QKVAttention(nn.Module):
|
| 366 |
+
"""
|
| 367 |
+
A module which performs QKV attention and splits in a different order.
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
def __init__(self, n_heads):
|
| 371 |
+
super().__init__()
|
| 372 |
+
self.n_heads = n_heads
|
| 373 |
+
|
| 374 |
+
def forward(self, qkv):
|
| 375 |
+
"""
|
| 376 |
+
Apply QKV attention.
|
| 377 |
+
|
| 378 |
+
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
| 379 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
| 380 |
+
"""
|
| 381 |
+
bs, width, length = qkv.shape
|
| 382 |
+
assert width % (3 * self.n_heads) == 0
|
| 383 |
+
ch = width // (3 * self.n_heads)
|
| 384 |
+
q, k, v = qkv.chunk(3, dim=1)
|
| 385 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
| 386 |
+
weight = th.einsum(
|
| 387 |
+
"bct,bcs->bts",
|
| 388 |
+
(q * scale).view(bs * self.n_heads, ch, length),
|
| 389 |
+
(k * scale).view(bs * self.n_heads, ch, length),
|
| 390 |
+
) # More stable with f16 than dividing afterwards
|
| 391 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
| 392 |
+
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
| 393 |
+
return a.reshape(bs, -1, length)
|
| 394 |
+
|
| 395 |
+
@staticmethod
|
| 396 |
+
def count_flops(model, _x, y):
|
| 397 |
+
return count_flops_attn(model, _x, y)
|
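The two attention modules assume different channel layouts for the fused qkv tensor: the legacy class expects a heads-major [H * (3C)] layout, the new one a q/k/v-major [3 * (H * C)] layout. A small sketch of the two slicings; for the same raw tensor they select different channels, which is why checkpoints must match use_new_attention_order.

import torch as th

bs, n_heads, ch, length = 2, 4, 8, 10
qkv = th.randn(bs, 3 * n_heads * ch, length)

# legacy order: fold heads into the batch, then split q/k/v within each head
q1, k1, v1 = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)

# new order: split q/k/v on the full width first, then view per head
q2, k2, v2 = [t.reshape(bs * n_heads, ch, length) for t in qkv.chunk(3, dim=1)]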
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class UNetModel(nn.Module):
|
| 401 |
+
"""
|
| 402 |
+
The full UNet model with attention and timestep embedding.
|
| 403 |
+
|
| 404 |
+
:param in_channels: channels in the input Tensor.
|
| 405 |
+
:param model_channels: base channel count for the model.
|
| 406 |
+
:param out_channels: channels in the output Tensor.
|
| 407 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
| 408 |
+
:param attention_resolutions: a collection of downsample rates at which
|
| 409 |
+
attention will take place. May be a set, list, or tuple.
|
| 410 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
| 411 |
+
will be used.
|
| 412 |
+
:param dropout: the dropout probability.
|
| 413 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
| 414 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
| 415 |
+
downsampling.
|
| 416 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
| 417 |
+
:param num_classes: if specified (as an int), then this model will be
|
| 418 |
+
class-conditional with `num_classes` classes.
|
| 419 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
| 420 |
+
:param num_heads: the number of attention heads in each attention layer.
|
| 421 |
+
:param num_head_channels: if specified, ignore num_heads and instead use
|
| 422 |
+
a fixed channel width per attention head.
|
| 423 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
| 424 |
+
of heads for upsampling. Deprecated.
|
| 425 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
| 426 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
| 427 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
| 428 |
+
increased efficiency.
|
| 429 |
+
"""
|
| 430 |
+
|
| 431 |
+
def __init__(
|
| 432 |
+
self,
|
| 433 |
+
image_size,
|
| 434 |
+
in_channels,
|
| 435 |
+
model_channels,
|
| 436 |
+
out_channels,
|
| 437 |
+
num_res_blocks,
|
| 438 |
+
attention_resolutions,
|
| 439 |
+
dropout=0,
|
| 440 |
+
channel_mult=(1, 2, 4, 8),
|
| 441 |
+
conv_resample=True,
|
| 442 |
+
dims=2,
|
| 443 |
+
num_classes=None,
|
| 444 |
+
use_checkpoint=False,
|
| 445 |
+
use_fp16=False,
|
| 446 |
+
num_heads=1,
|
| 447 |
+
num_head_channels=-1,
|
| 448 |
+
num_heads_upsample=-1,
|
| 449 |
+
use_scale_shift_norm=False,
|
| 450 |
+
resblock_updown=False,
|
| 451 |
+
use_new_attention_order=False,
|
| 452 |
+
):
|
| 453 |
+
super().__init__()
|
| 454 |
+
|
| 455 |
+
if num_heads_upsample == -1:
|
| 456 |
+
num_heads_upsample = num_heads
|
| 457 |
+
|
| 458 |
+
self.image_size = image_size
|
| 459 |
+
self.in_channels = in_channels
|
| 460 |
+
self.model_channels = model_channels
|
| 461 |
+
self.out_channels = out_channels
|
| 462 |
+
self.num_res_blocks = num_res_blocks
|
| 463 |
+
self.attention_resolutions = attention_resolutions
|
| 464 |
+
self.dropout = dropout
|
| 465 |
+
self.channel_mult = channel_mult
|
| 466 |
+
self.conv_resample = conv_resample
|
| 467 |
+
self.dims = dims
|
| 468 |
+
self.num_classes = num_classes
|
| 469 |
+
self.use_checkpoint = use_checkpoint
|
| 470 |
+
self.use_fp16 = use_fp16
|
| 471 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
| 472 |
+
self.num_heads = num_heads
|
| 473 |
+
self.num_head_channels = num_head_channels
|
| 474 |
+
self.num_heads_upsample = num_heads_upsample
|
| 475 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
| 476 |
+
self.resblock_updown = resblock_updown
|
| 477 |
+
|
| 478 |
+
time_embed_dim = model_channels * 4
|
| 479 |
+
self.time_embed_dim = time_embed_dim
|
| 480 |
+
self.time_embed = nn.Sequential(
|
| 481 |
+
linear(model_channels, time_embed_dim),
|
| 482 |
+
nn.SiLU(),
|
| 483 |
+
linear(time_embed_dim, time_embed_dim),
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if self.num_classes is not None:
|
| 487 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
| 488 |
+
|
| 489 |
+
ch = input_ch = int(channel_mult[0] * model_channels)
|
| 490 |
+
self.input_blocks = nn.ModuleList(
|
| 491 |
+
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
|
| 492 |
+
)
|
| 493 |
+
self._feature_size = ch
|
| 494 |
+
input_block_chans = [ch]
|
| 495 |
+
ds = 1
|
| 496 |
+
for level, mult in enumerate(channel_mult):
|
| 497 |
+
for _ in range(num_res_blocks):
|
| 498 |
+
layers = [
|
| 499 |
+
ResBlock(
|
| 500 |
+
ch,
|
| 501 |
+
time_embed_dim,
|
| 502 |
+
dropout,
|
| 503 |
+
out_channels=int(mult * model_channels),
|
| 504 |
+
dims=dims,
|
| 505 |
+
use_checkpoint=use_checkpoint,
|
| 506 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 507 |
+
)
|
| 508 |
+
]
|
| 509 |
+
ch = int(mult * model_channels)
|
| 510 |
+
if ds in attention_resolutions:
|
| 511 |
+
layers.append(
|
| 512 |
+
AttentionBlock(
|
| 513 |
+
ch,
|
| 514 |
+
use_checkpoint=use_checkpoint,
|
| 515 |
+
num_heads=num_heads,
|
| 516 |
+
num_head_channels=num_head_channels,
|
| 517 |
+
use_new_attention_order=use_new_attention_order,
|
| 518 |
+
)
|
| 519 |
+
)
|
| 520 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
| 521 |
+
self._feature_size += ch
|
| 522 |
+
input_block_chans.append(ch)
|
| 523 |
+
if level != len(channel_mult) - 1:
|
| 524 |
+
out_ch = ch
|
| 525 |
+
self.input_blocks.append(
|
| 526 |
+
TimestepEmbedSequential(
|
| 527 |
+
ResBlock(
|
| 528 |
+
ch,
|
| 529 |
+
time_embed_dim,
|
| 530 |
+
dropout,
|
| 531 |
+
out_channels=out_ch,
|
| 532 |
+
dims=dims,
|
| 533 |
+
use_checkpoint=use_checkpoint,
|
| 534 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 535 |
+
down=True,
|
| 536 |
+
)
|
| 537 |
+
if resblock_updown
|
| 538 |
+
else Downsample(
|
| 539 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
| 540 |
+
)
|
| 541 |
+
)
|
| 542 |
+
)
|
| 543 |
+
ch = out_ch
|
| 544 |
+
input_block_chans.append(ch)
|
| 545 |
+
ds *= 2
|
| 546 |
+
self._feature_size += ch
|
| 547 |
+
|
| 548 |
+
self.middle_block = TimestepEmbedSequential(
|
| 549 |
+
ResBlock(
|
| 550 |
+
ch,
|
| 551 |
+
time_embed_dim,
|
| 552 |
+
dropout,
|
| 553 |
+
dims=dims,
|
| 554 |
+
use_checkpoint=use_checkpoint,
|
| 555 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 556 |
+
),
|
| 557 |
+
AttentionBlock(
|
| 558 |
+
ch,
|
| 559 |
+
use_checkpoint=use_checkpoint,
|
| 560 |
+
num_heads=num_heads,
|
| 561 |
+
num_head_channels=num_head_channels,
|
| 562 |
+
use_new_attention_order=use_new_attention_order,
|
| 563 |
+
),
|
| 564 |
+
ResBlock(
|
| 565 |
+
ch,
|
| 566 |
+
time_embed_dim,
|
| 567 |
+
dropout,
|
| 568 |
+
dims=dims,
|
| 569 |
+
use_checkpoint=use_checkpoint,
|
| 570 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 571 |
+
),
|
| 572 |
+
)
|
| 573 |
+
self._feature_size += ch
|
| 574 |
+
|
| 575 |
+
self.output_blocks = nn.ModuleList([])
|
| 576 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
| 577 |
+
for i in range(num_res_blocks + 1):
|
| 578 |
+
ich = input_block_chans.pop()
|
| 579 |
+
layers = [
|
| 580 |
+
ResBlock(
|
| 581 |
+
ch + ich,
|
| 582 |
+
time_embed_dim,
|
| 583 |
+
dropout,
|
| 584 |
+
out_channels=int(model_channels * mult),
|
| 585 |
+
dims=dims,
|
| 586 |
+
use_checkpoint=use_checkpoint,
|
| 587 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 588 |
+
)
|
| 589 |
+
]
|
| 590 |
+
ch = int(model_channels * mult)
|
| 591 |
+
if ds in attention_resolutions:
|
| 592 |
+
layers.append(
|
| 593 |
+
AttentionBlock(
|
| 594 |
+
ch,
|
| 595 |
+
use_checkpoint=use_checkpoint,
|
| 596 |
+
num_heads=num_heads_upsample,
|
| 597 |
+
num_head_channels=num_head_channels,
|
| 598 |
+
use_new_attention_order=use_new_attention_order,
|
| 599 |
+
)
|
| 600 |
+
)
|
| 601 |
+
if level and i == num_res_blocks:
|
| 602 |
+
out_ch = ch
|
| 603 |
+
layers.append(
|
| 604 |
+
ResBlock(
|
| 605 |
+
ch,
|
| 606 |
+
time_embed_dim,
|
| 607 |
+
dropout,
|
| 608 |
+
out_channels=out_ch,
|
| 609 |
+
dims=dims,
|
| 610 |
+
use_checkpoint=use_checkpoint,
|
| 611 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 612 |
+
up=True,
|
| 613 |
+
)
|
| 614 |
+
if resblock_updown
|
| 615 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
| 616 |
+
)
|
| 617 |
+
ds //= 2
|
| 618 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
| 619 |
+
self._feature_size += ch
|
| 620 |
+
|
| 621 |
+
self.out = nn.Sequential(
|
| 622 |
+
normalization(ch),
|
| 623 |
+
nn.SiLU(),
|
| 624 |
+
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
def convert_to_fp16(self):
|
| 628 |
+
"""
|
| 629 |
+
Convert the torso of the model to float16.
|
| 630 |
+
"""
|
| 631 |
+
self.input_blocks.apply(convert_module_to_f16)
|
| 632 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 633 |
+
self.output_blocks.apply(convert_module_to_f16)
|
| 634 |
+
|
| 635 |
+
def convert_to_fp32(self):
|
| 636 |
+
"""
|
| 637 |
+
Convert the torso of the model to float32.
|
| 638 |
+
"""
|
| 639 |
+
self.input_blocks.apply(convert_module_to_f32)
|
| 640 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 641 |
+
self.output_blocks.apply(convert_module_to_f32)
|
| 642 |
+
|
| 643 |
+
def forward(self, x, timesteps, y=None):
|
| 644 |
+
"""
|
| 645 |
+
Apply the model to an input batch.
|
| 646 |
+
|
| 647 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
| 648 |
+
:param timesteps: a 1-D batch of timesteps.
|
| 649 |
+
:param y: an [N] Tensor of labels, if class-conditional.
|
| 650 |
+
:return: an [N x C x ...] Tensor of outputs.
|
| 651 |
+
"""
|
| 652 |
+
assert (y is not None) == (
|
| 653 |
+
self.num_classes is not None
|
| 654 |
+
), "must specify y if and only if the model is class-conditional"
|
| 655 |
+
|
| 656 |
+
hs = []
|
| 657 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
| 658 |
+
|
| 659 |
+
if self.num_classes is not None:
|
| 660 |
+
assert y.shape == (x.shape[0],)
|
| 661 |
+
emb = emb + self.label_emb(y)
|
| 662 |
+
|
| 663 |
+
h = x.type(self.dtype)
|
| 664 |
+
for module in self.input_blocks:
|
| 665 |
+
h = module(h, emb)
|
| 666 |
+
hs.append(h)
|
| 667 |
+
h = self.middle_block(h, emb)
|
| 668 |
+
for module in self.output_blocks:
|
| 669 |
+
h = th.cat([h, hs.pop()], dim=1)
|
| 670 |
+
h = module(h, emb)
|
| 671 |
+
h = h.type(x.dtype)
|
| 672 |
+
return self.out(h)
|
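A minimal construction sketch; these hyperparameters are illustrative, not the repo's trained configuration (the inference scripts below use num_channels 128, num_head_channels 64, attention at 32,16,8, and learn_sigma, which doubles out_channels).

import torch as th

net = UNetModel(
    image_size=64, in_channels=3, model_channels=128, out_channels=3,
    num_res_blocks=2, attention_resolutions=(2, 4),  # attend at 2x and 4x downsampling
    channel_mult=(1, 2, 4), num_head_channels=64,
)
x = th.randn(2, 3, 64, 64)
t = th.randint(0, 1000, (2,))
eps = net(x, t)  # [2, 3, 64, 64]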
| 673 |
+
|
| 674 |
+
|
| 675 |
+
class SuperResModel(UNetModel):
|
| 676 |
+
"""
|
| 677 |
+
A UNetModel that performs super-resolution.
|
| 678 |
+
|
| 679 |
+
Expects an extra kwarg `low_res` to condition on a low-resolution image.
|
| 680 |
+
"""
|
| 681 |
+
|
| 682 |
+
def __init__(self, image_size, in_channels, *args, **kwargs):
|
| 683 |
+
super().__init__(image_size, in_channels * 2, *args, **kwargs)
|
| 684 |
+
|
| 685 |
+
def forward(self, x, timesteps, low_res=None, **kwargs):
|
| 686 |
+
_, _, new_height, new_width = x.shape
|
| 687 |
+
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
| 688 |
+
x = th.cat([x, upsampled], dim=1)
|
| 689 |
+
return super().forward(x, timesteps, **kwargs)
|
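Usage sketch: the low-res conditioning image is bilinearly upsampled to the target resolution and concatenated on the channel axis, which is why the constructor doubles in_channels (shapes illustrative).

import torch as th

sr = SuperResModel(
    image_size=64, in_channels=3, model_channels=64, out_channels=3,
    num_res_blocks=1, attention_resolutions=(4,), channel_mult=(1, 2, 4),
)
x_t = th.randn(2, 3, 64, 64)        # noisy high-res sample
low_res = th.randn(2, 3, 16, 16)    # low-res conditioning image
out = sr(x_t, th.tensor([5, 7]), low_res=low_res)  # [2, 3, 64, 64]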
| 690 |
+
|
| 691 |
+
|
| 692 |
+
class EncoderUNetModel(nn.Module):
|
| 693 |
+
"""
|
| 694 |
+
The half UNet model with attention and timestep embedding.
|
| 695 |
+
|
| 696 |
+
For usage, see UNet.
|
| 697 |
+
"""
|
| 698 |
+
|
| 699 |
+
def __init__(
|
| 700 |
+
self,
|
| 701 |
+
image_size,
|
| 702 |
+
in_channels,
|
| 703 |
+
model_channels,
|
| 704 |
+
out_channels,
|
| 705 |
+
num_res_blocks,
|
| 706 |
+
attention_resolutions,
|
| 707 |
+
dropout=0,
|
| 708 |
+
channel_mult=(1, 2, 4, 8),
|
| 709 |
+
conv_resample=True,
|
| 710 |
+
dims=2,
|
| 711 |
+
use_checkpoint=False,
|
| 712 |
+
use_fp16=False,
|
| 713 |
+
num_heads=1,
|
| 714 |
+
num_head_channels=-1,
|
| 715 |
+
num_heads_upsample=-1,
|
| 716 |
+
use_scale_shift_norm=False,
|
| 717 |
+
resblock_updown=False,
|
| 718 |
+
use_new_attention_order=False,
|
| 719 |
+
pool="adaptive",
|
| 720 |
+
):
|
| 721 |
+
super().__init__()
|
| 722 |
+
|
| 723 |
+
if num_heads_upsample == -1:
|
| 724 |
+
num_heads_upsample = num_heads
|
| 725 |
+
|
| 726 |
+
self.in_channels = in_channels
|
| 727 |
+
self.model_channels = model_channels
|
| 728 |
+
self.out_channels = out_channels
|
| 729 |
+
self.num_res_blocks = num_res_blocks
|
| 730 |
+
self.attention_resolutions = attention_resolutions
|
| 731 |
+
self.dropout = dropout
|
| 732 |
+
self.channel_mult = channel_mult
|
| 733 |
+
self.conv_resample = conv_resample
|
| 734 |
+
self.use_checkpoint = use_checkpoint
|
| 735 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
| 736 |
+
self.num_heads = num_heads
|
| 737 |
+
self.num_head_channels = num_head_channels
|
| 738 |
+
self.num_heads_upsample = num_heads_upsample
|
| 739 |
+
|
| 740 |
+
time_embed_dim = model_channels * 4
|
| 741 |
+
self.time_embed = nn.Sequential(
|
| 742 |
+
linear(model_channels, time_embed_dim),
|
| 743 |
+
nn.SiLU(),
|
| 744 |
+
linear(time_embed_dim, time_embed_dim),
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
ch = int(channel_mult[0] * model_channels)
|
| 748 |
+
self.input_blocks = nn.ModuleList(
|
| 749 |
+
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
|
| 750 |
+
)
|
| 751 |
+
self._feature_size = ch
|
| 752 |
+
input_block_chans = [ch]
|
| 753 |
+
ds = 1
|
| 754 |
+
for level, mult in enumerate(channel_mult):
|
| 755 |
+
for _ in range(num_res_blocks):
|
| 756 |
+
layers = [
|
| 757 |
+
ResBlock(
|
| 758 |
+
ch,
|
| 759 |
+
time_embed_dim,
|
| 760 |
+
dropout,
|
| 761 |
+
out_channels=int(mult * model_channels),
|
| 762 |
+
dims=dims,
|
| 763 |
+
use_checkpoint=use_checkpoint,
|
| 764 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 765 |
+
)
|
| 766 |
+
]
|
| 767 |
+
ch = int(mult * model_channels)
|
| 768 |
+
if ds in attention_resolutions:
|
| 769 |
+
layers.append(
|
| 770 |
+
AttentionBlock(
|
| 771 |
+
ch,
|
| 772 |
+
use_checkpoint=use_checkpoint,
|
| 773 |
+
num_heads=num_heads,
|
| 774 |
+
num_head_channels=num_head_channels,
|
| 775 |
+
use_new_attention_order=use_new_attention_order,
|
| 776 |
+
)
|
| 777 |
+
)
|
| 778 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
| 779 |
+
self._feature_size += ch
|
| 780 |
+
input_block_chans.append(ch)
|
| 781 |
+
if level != len(channel_mult) - 1:
|
| 782 |
+
out_ch = ch
|
| 783 |
+
self.input_blocks.append(
|
| 784 |
+
TimestepEmbedSequential(
|
| 785 |
+
ResBlock(
|
| 786 |
+
ch,
|
| 787 |
+
time_embed_dim,
|
| 788 |
+
dropout,
|
| 789 |
+
out_channels=out_ch,
|
| 790 |
+
dims=dims,
|
| 791 |
+
use_checkpoint=use_checkpoint,
|
| 792 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 793 |
+
down=True,
|
| 794 |
+
)
|
| 795 |
+
if resblock_updown
|
| 796 |
+
else Downsample(
|
| 797 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
| 798 |
+
)
|
| 799 |
+
)
|
| 800 |
+
)
|
| 801 |
+
ch = out_ch
|
| 802 |
+
input_block_chans.append(ch)
|
| 803 |
+
ds *= 2
|
| 804 |
+
self._feature_size += ch
|
| 805 |
+
|
| 806 |
+
self.middle_block = TimestepEmbedSequential(
|
| 807 |
+
ResBlock(
|
| 808 |
+
ch,
|
| 809 |
+
time_embed_dim,
|
| 810 |
+
dropout,
|
| 811 |
+
dims=dims,
|
| 812 |
+
use_checkpoint=use_checkpoint,
|
| 813 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 814 |
+
),
|
| 815 |
+
AttentionBlock(
|
| 816 |
+
ch,
|
| 817 |
+
use_checkpoint=use_checkpoint,
|
| 818 |
+
num_heads=num_heads,
|
| 819 |
+
num_head_channels=num_head_channels,
|
| 820 |
+
use_new_attention_order=use_new_attention_order,
|
| 821 |
+
),
|
| 822 |
+
ResBlock(
|
| 823 |
+
ch,
|
| 824 |
+
time_embed_dim,
|
| 825 |
+
dropout,
|
| 826 |
+
dims=dims,
|
| 827 |
+
use_checkpoint=use_checkpoint,
|
| 828 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 829 |
+
),
|
| 830 |
+
)
|
| 831 |
+
self._feature_size += ch
|
| 832 |
+
self.pool = pool
|
| 833 |
+
if pool == "adaptive":
|
| 834 |
+
self.out = nn.Sequential(
|
| 835 |
+
normalization(ch),
|
| 836 |
+
nn.SiLU(),
|
| 837 |
+
nn.AdaptiveAvgPool2d((1, 1)),
|
| 838 |
+
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
| 839 |
+
nn.Flatten(),
|
| 840 |
+
)
|
| 841 |
+
elif pool == "attention":
|
| 842 |
+
assert num_head_channels != -1
|
| 843 |
+
self.out = nn.Sequential(
|
| 844 |
+
normalization(ch),
|
| 845 |
+
nn.SiLU(),
|
| 846 |
+
AttentionPool2d(
|
| 847 |
+
(image_size // ds), ch, num_head_channels, out_channels
|
| 848 |
+
),
|
| 849 |
+
)
|
| 850 |
+
elif pool == "spatial":
|
| 851 |
+
self.out = nn.Sequential(
|
| 852 |
+
nn.Linear(self._feature_size, 2048),
|
| 853 |
+
nn.ReLU(),
|
| 854 |
+
nn.Linear(2048, self.out_channels),
|
| 855 |
+
)
|
| 856 |
+
elif pool == "spatial_v2":
|
| 857 |
+
self.out = nn.Sequential(
|
| 858 |
+
nn.Linear(self._feature_size, 2048),
|
| 859 |
+
normalization(2048),
|
| 860 |
+
nn.SiLU(),
|
| 861 |
+
nn.Linear(2048, self.out_channels),
|
| 862 |
+
)
|
| 863 |
+
else:
|
| 864 |
+
raise NotImplementedError(f"Unexpected {pool} pooling")
|
| 865 |
+
|
| 866 |
+
def convert_to_fp16(self):
|
| 867 |
+
"""
|
| 868 |
+
Convert the torso of the model to float16.
|
| 869 |
+
"""
|
| 870 |
+
self.input_blocks.apply(convert_module_to_f16)
|
| 871 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 872 |
+
|
| 873 |
+
def convert_to_fp32(self):
|
| 874 |
+
"""
|
| 875 |
+
Convert the torso of the model to float32.
|
| 876 |
+
"""
|
| 877 |
+
self.input_blocks.apply(convert_module_to_f32)
|
| 878 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 879 |
+
|
| 880 |
+
def forward(self, x, timesteps):
|
| 881 |
+
"""
|
| 882 |
+
Apply the model to an input batch.
|
| 883 |
+
|
| 884 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
| 885 |
+
:param timesteps: a 1-D batch of timesteps.
|
| 886 |
+
:return: an [N x K] Tensor of outputs.
|
| 887 |
+
"""
|
| 888 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
| 889 |
+
|
| 890 |
+
results = []
|
| 891 |
+
h = x.type(self.dtype)
|
| 892 |
+
for module in self.input_blocks:
|
| 893 |
+
h = module(h, emb)
|
| 894 |
+
if self.pool.startswith("spatial"):
|
| 895 |
+
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
| 896 |
+
h = self.middle_block(h, emb)
|
| 897 |
+
if self.pool.startswith("spatial"):
|
| 898 |
+
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
| 899 |
+
h = th.cat(results, axis=-1)
|
| 900 |
+
return self.out(h)
|
| 901 |
+
else:
|
| 902 |
+
h = h.type(x.dtype)
|
| 903 |
+
return self.out(h)
|
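The pool argument decides how the feature maps are reduced to a single [N x K] vector: "adaptive" global-average-pools and projects, "attention" uses AttentionPool2d, and the "spatial" variants mean-pool every saved resolution over H and W and concatenate (which is what _feature_size is sized for). A sketch with the adaptive head:

import torch as th

enc = EncoderUNetModel(
    image_size=64, in_channels=3, model_channels=64, out_channels=10,
    num_res_blocks=1, attention_resolutions=(), channel_mult=(1, 2),
    pool="adaptive",
)
logits = enc(th.randn(2, 3, 64, 64), th.tensor([0, 0]))  # [2, 10]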
| 904 |
+
|
| 905 |
+
|
| 906 |
+
#________________________________ tfg model ________________________________#
|
| 907 |
+
class TFGModel(UNetModel):
|
| 908 |
+
'''
|
| 909 |
+
Talking Face Generation using UNet model
|
| 910 |
+
'''
|
| 911 |
+
def __init__(self,
|
| 912 |
+
image_size,
|
| 913 |
+
in_channels,
|
| 914 |
+
model_channels,
|
| 915 |
+
out_channels,
|
| 916 |
+
*args,
|
| 917 |
+
use_ref=False,
|
| 918 |
+
nframes=1,
|
| 919 |
+
nrefer=0,
|
| 920 |
+
use_audio=False,
|
| 921 |
+
audio_encoder_kwargs=None,
|
| 922 |
+
audio_as_style=False,  # condition on audio as a style vector added to the timestep embedding, instead of concatenating features in the middle block
|
| 923 |
+
audio_as_style_encoder_mlp=False,  # use a plain MLP in place of the convolutional audio encoder
|
| 924 |
+
**kwargs
|
| 925 |
+
):
|
| 926 |
+
if use_ref:
|
| 927 |
+
super().__init__(image_size, in_channels * (1+1+nrefer), model_channels, out_channels * 1, *args, **kwargs)
|
| 928 |
+
else:
|
| 929 |
+
super().__init__(image_size, in_channels * (1+1), model_channels, out_channels * 1, *args, **kwargs)
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
self.use_ref = use_ref
|
| 933 |
+
self.nframes = nframes
|
| 934 |
+
self.nrefer = nrefer
|
| 935 |
+
self.use_audio = use_audio
|
| 936 |
+
|
| 937 |
+
if self.use_audio:
|
| 938 |
+
if audio_encoder_kwargs is not None:
|
| 939 |
+
self.audio_encoder_kwargs = audio_encoder_kwargs
|
| 940 |
+
else:
|
| 941 |
+
self.audio_encoder_kwargs = {}
|
| 942 |
+
|
| 943 |
+
self.audio_as_style = audio_as_style
|
| 944 |
+
self.audio_as_style_encoder_mlp = audio_as_style_encoder_mlp
|
| 945 |
+
|
| 946 |
+
self.audio_encoder = TFGAudioEncoder(
|
| 947 |
+
nframes = self.nframes,
|
| 948 |
+
dropout = self.dropout,
|
| 949 |
+
conv_resample = self.conv_resample,
|
| 950 |
+
dims = self.dims,
|
| 951 |
+
use_checkpoint=self.use_checkpoint,  # the encoder's kwarg is use_checkpoint; a misspelling would be silently swallowed by **kwargs
|
| 952 |
+
use_fp16 = self.use_fp16,
|
| 953 |
+
use_scale_shift_norm = self.use_scale_shift_norm,
|
| 954 |
+
resblock_updown = self.resblock_updown,
|
| 955 |
+
**self.audio_encoder_kwargs
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
if not self.audio_as_style:
|
| 959 |
+
# concatenate the audio encoding onto the video encoding in the middle block
|
| 960 |
+
old_middle_block_head = self.middle_block[0]
|
| 961 |
+
mid_img_ch = old_middle_block_head.channels
|
| 962 |
+
mid_aud_ch = self.audio_encoder.out_channels
|
| 963 |
+
self.middle_block[0] = ResBlock(
|
| 964 |
+
mid_img_ch + mid_aud_ch, #combined image and audio channels
|
| 965 |
+
old_middle_block_head.emb_channels,
|
| 966 |
+
old_middle_block_head.dropout,
|
| 967 |
+
out_channels = old_middle_block_head.out_channels,
|
| 968 |
+
dims = old_middle_block_head.dims,
|
| 969 |
+
use_checkpoint=old_middle_block_head.use_checkpoint,
|
| 970 |
+
use_scale_shift_norm=old_middle_block_head.use_scale_shift_norm,
|
| 971 |
+
)
|
| 972 |
+
else: # audio as style
|
| 973 |
+
if self.audio_as_style_encoder_mlp:
|
| 974 |
+
old_conv_encoder = self.audio_encoder
|
| 975 |
+
audio_dim = old_conv_encoder.audio_dim
|
| 976 |
+
audio_frames_per_video = old_conv_encoder.audio_frames_per_video
|
| 977 |
+
self.audio_encoder = nn.Sequential(
|
| 978 |
+
nn.Flatten(),
|
| 979 |
+
linear(audio_dim+audio_frames_per_video, self.time_embed_dim),
|
| 980 |
+
normalization(self.time_embed_dim),
|
| 981 |
+
nn.SiLU(),
|
| 982 |
+
linear(self.time_embed_dim, self.time_embed_dim),
|
| 983 |
+
)
|
| 984 |
+
else: # use conv_encoder+mlp to get style
|
| 985 |
+
# similar to the classifier defined
|
| 986 |
+
self.audio_encoder_to_style = nn.Sequential(
|
| 987 |
+
normalization(self.audio_encoder.out_channels),
|
| 988 |
+
nn.SiLU(),
|
| 989 |
+
nn.AdaptiveAvgPool2d((1,1)),
|
| 990 |
+
zero_module(  # zero-initializes the conv weights
|
| 991 |
+
conv_nd(self.dims, self.audio_encoder.out_channels, self.time_embed_dim, 1)
|
| 992 |
+
),
|
| 993 |
+
nn.Flatten(),
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
def convert_to_fp16(self):
|
| 997 |
+
"""
|
| 998 |
+
Convert the torso of the model to float16.
|
| 999 |
+
"""
|
| 1000 |
+
self.input_blocks.apply(convert_module_to_f16)
|
| 1001 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 1002 |
+
self.output_blocks.apply(convert_module_to_f16)
|
| 1003 |
+
if self.use_audio:
|
| 1004 |
+
self.audio_encoder.apply(convert_module_to_f16)
|
| 1005 |
+
if self.audio_as_style:
|
| 1006 |
+
self.audio_encoder_to_style.apply(convert_module_to_f16)
|
| 1007 |
+
|
| 1008 |
+
def convert_to_fp32(self):
|
| 1009 |
+
"""
|
| 1010 |
+
Convert the torso of the model to float32.
|
| 1011 |
+
"""
|
| 1012 |
+
self.input_blocks.apply(convert_module_to_f32)
|
| 1013 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 1014 |
+
self.output_blocks.apply(convert_module_to_f32)
|
| 1015 |
+
if self.use_audio:
|
| 1016 |
+
self.audio_encoder.apply(convert_module_to_f32)
|
| 1017 |
+
if self.audio_as_style:
|
| 1018 |
+
self.audio_encoder_to_style.apply(convert_module_to_f32)
|
| 1019 |
+
|
| 1020 |
+
def forward(self, x, timesteps, cond_img=None, mask=None, ref_img=None, indiv_mels=None, **kwargs):
|
| 1021 |
+
|
| 1022 |
+
# preprocessing
|
| 1023 |
+
x = x * mask + (1. - mask) * cond_img  # paste cond_img into the unmasked (visible) region
|
| 1024 |
+
x = th.cat([x, cond_img], dim=1)
|
| 1025 |
+
if self.use_ref:
|
| 1026 |
+
x = th.cat([x, ref_img], dim=1)
|
| 1027 |
+
|
| 1028 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
if self.use_audio:
|
| 1032 |
+
if self.audio_as_style:
|
| 1033 |
+
#audio encoder
|
| 1034 |
+
if self.audio_as_style_encoder_mlp:  # the MLP path runs in fp32
|
| 1035 |
+
a = self.audio_encoder(indiv_mels)
|
| 1036 |
+
a = self.audio_encoder_to_style(a)
|
| 1037 |
+
a = a.type(self.dtype)
|
| 1038 |
+
else:  # the conv encoder path runs in fp16
|
| 1039 |
+
a = indiv_mels.type(self.dtype)
|
| 1040 |
+
a = self.audio_encoder(a)
|
| 1041 |
+
a = self.audio_encoder_to_style(a)
|
| 1042 |
+
#combine
|
| 1043 |
+
emb = emb + a
|
| 1044 |
+
#video encoder
|
| 1045 |
+
hs = []
|
| 1046 |
+
h = x.type(self.dtype)
|
| 1047 |
+
for module in self.input_blocks:
|
| 1048 |
+
h = module(h, emb)
|
| 1049 |
+
hs.append(h)
|
| 1050 |
+
|
| 1051 |
+
else: # concat audio in the middle
|
| 1052 |
+
#audio encoder
|
| 1053 |
+
a = indiv_mels.type(self.dtype)
|
| 1054 |
+
a = self.audio_encoder(a)
|
| 1055 |
+
#video encoder
|
| 1056 |
+
hs = []
|
| 1057 |
+
h = x.type(self.dtype)
|
| 1058 |
+
for module in self.input_blocks:
|
| 1059 |
+
h = module(h, emb)
|
| 1060 |
+
hs.append(h)
|
| 1061 |
+
#combine
|
| 1062 |
+
h = th.cat([h, a], dim=1)
|
| 1063 |
+
|
| 1064 |
+
#middle block
|
| 1065 |
+
h = self.middle_block(h, emb)
|
| 1066 |
+
|
| 1067 |
+
# decoder
|
| 1068 |
+
for module in self.output_blocks:
|
| 1069 |
+
h = th.cat([h, hs.pop()], dim=1)
|
| 1070 |
+
h = module(h, emb)
|
| 1071 |
+
h = h.type(x.dtype)
|
| 1072 |
+
return self.out(h)
|
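The conditioning composition at the top of TFGModel.forward, spelled out: the visible region of x is replaced by cond_img, and the channel axis then stacks [masked x | cond_img | ref_img], matching in_channels * (1 + 1 + nrefer) from __init__. A sketch with toy shapes; which half the mask hides is an assumption based on face_hide_percentage 0.5 in the inference scripts.

import torch as th

B, C, H, W = 4, 3, 128, 128
x = th.randn(B, C, H, W)          # noisy target frames (batch folded with frames)
cond_img = th.randn(B, C, H, W)   # conditioning frames with the region to inpaint
ref_img = th.randn(B, C, H, W)    # identity reference frame(s), nrefer=1
mask = th.ones(B, 1, H, W)
mask[:, :, :H // 2, :] = 0.0      # 0 = keep cond_img (top half), 1 = generate (mouth half)

x = x * mask + (1.0 - mask) * cond_img
x = th.cat([x, cond_img, ref_img], dim=1)  # -> [B, 3 * C, H, W]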
| 1073 |
+
|
| 1074 |
+
|
| 1075 |
+
class TFGAudioEncoder(nn.Module):
|
| 1076 |
+
"""
|
| 1077 |
+
Audio Encoder
|
| 1078 |
+
|
| 1079 |
+
with audio_dim = 80,
|
| 1080 |
+
audio_frames_per_video = 16
|
| 1081 |
+
init_spatial_dim = 64
|
| 1082 |
+
model_channels=32
|
| 1083 |
+
channel_mult=(1,2,3,4)
|
| 1084 |
+
|
| 1085 |
+
following are the output shapes ->
|
| 1086 |
+
init: [BF, 80, 16]
|
| 1087 |
+
after in_block: [BF, 64, 16]
|
| 1088 |
+
adding new dim [BF, 1, 64, 16]
|
| 1089 |
+
encoder block before entering the loop: [BF, 32, 64, 16]
|
| 1090 |
+
level: 0
|
| 1091 |
+
0 _ 0 : [BF, 32, 64, 16]
|
| 1092 |
+
0 _ 1 : [BF, 32, 64, 16]
|
| 1093 |
+
0 _ 2 : [BF, 32, 32, 16]
|
| 1094 |
+
level: 1
|
| 1095 |
+
1 _ 0 : [BF, 64, 32, 16]
|
| 1096 |
+
1 _ 1 : [BF, 64, 32, 16]
|
| 1097 |
+
1 _ 2 : [BF, 64, 16, 16]
|
| 1098 |
+
level: 2
|
| 1099 |
+
2 _ 0 : [BF, 96, 16, 16]
|
| 1100 |
+
2 _ 1 : [BF, 96, 16, 16]
|
| 1101 |
+
2 _ 2 : [BF, 96, 8, 8]
|
| 1102 |
+
level: 3
|
| 1103 |
+
3 _ 0 : [BF, 128, 8, 8]
|
| 1104 |
+
3 _ 1 : [BF, 128, 8, 8]
|
| 1105 |
+
middle block: [BF, 128, 8, 8]
|
| 1106 |
+
out: [BF, 128, 8, 8]
|
| 1107 |
+
"""
|
| 1108 |
+
def __init__(
|
| 1109 |
+
self,
|
| 1110 |
+
audio_dim = 80,
|
| 1111 |
+
audio_frames_per_video = 16,
|
| 1112 |
+
nframes=1,
|
| 1113 |
+
|
| 1114 |
+
init_spatial_dim = 64,
|
| 1115 |
+
model_channels=32,
|
| 1116 |
+
out_channels=-1,
|
| 1117 |
+
num_res_blocks=2,
|
| 1118 |
+
dropout=0,
|
| 1119 |
+
channel_mult=(1,2,3,4), #(1,1,2,4,8),
|
| 1120 |
+
conv_resample=True,
|
| 1121 |
+
dims=2,
|
| 1122 |
+
use_checkpoint = False,
|
| 1123 |
+
use_fp16=False,
|
| 1124 |
+
use_scale_shift_norm=False,
|
| 1125 |
+
resblock_updown=False,
|
| 1126 |
+
**kwargs
|
| 1127 |
+
):
|
| 1128 |
+
super().__init__()
|
| 1129 |
+
self.audio_dim = audio_dim
|
| 1130 |
+
self.audio_frames_per_video = audio_frames_per_video
|
| 1131 |
+
self.nframes = nframes
|
| 1132 |
+
self.model_channels = model_channels
|
| 1133 |
+
self.out_channels = out_channels if out_channels > 0 else model_channels * channel_mult[-1]
|
| 1134 |
+
self.num_res_blocks = num_res_blocks
|
| 1135 |
+
self.dropout = dropout
|
| 1136 |
+
self.channel_mult = channel_mult
|
| 1137 |
+
self.conv_resample = conv_resample
|
| 1138 |
+
self.use_checkpoint = use_checkpoint
|
| 1139 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
| 1140 |
+
|
| 1141 |
+
time_embed_dim = model_channels
|
| 1142 |
+
self.time_embed = nn.Sequential(
|
| 1143 |
+
linear(model_channels, time_embed_dim),
|
| 1144 |
+
nn.SiLU(),
|
| 1145 |
+
linear(time_embed_dim, time_embed_dim),
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
ch = int(channel_mult[0] * model_channels)
|
| 1149 |
+
# init_spatial_dim = 4 * ( 2** (len(channel_mult)-1))
|
| 1150 |
+
|
| 1151 |
+
# convert spatial dim 80->64 using Conv1D: [N*F,80,16] -> [N*F, 64, 16]
|
| 1152 |
+
_conv_dim, _in_channels, _out_channels = 1, self.audio_dim, init_spatial_dim
|
| 1153 |
+
self.input_block = TimestepEmbedSequential(
|
| 1154 |
+
conv_nd(_conv_dim, _in_channels, _out_channels, 3, padding=1),
|
| 1155 |
+
normalization(_out_channels),
|
| 1156 |
+
nn.SiLU()
|
| 1157 |
+
)
|
| 1158 |
+
|
| 1159 |
+
# manually reshape [N*F, 64, 16] -> [N*F, 1, 64, 16] in forward()
|
| 1160 |
+
|
| 1161 |
+
# [N*F, 1, 64, 16] -> [N*F, model_channels*channel_mult[0], 64, 16]
|
| 1162 |
+
# can't use a ResBlock here: normalization() uses 32 groups, which a 1-channel input cannot satisfy
|
| 1163 |
+
self.encoder_blocks = nn.ModuleList(
|
| 1164 |
+
[
|
| 1165 |
+
TimestepEmbedSequential(
|
| 1166 |
+
conv_nd(dims, 1, ch, 3, padding=1 )
|
| 1167 |
+
)
|
| 1168 |
+
|
| 1169 |
+
]
|
| 1170 |
+
)
|
| 1171 |
+
|
| 1172 |
+
self._feature_size = ch
|
| 1173 |
+
input_block_chans = [ch]
|
| 1174 |
+
|
| 1175 |
+
ds = 1
|
| 1176 |
+
for level, mult in enumerate(channel_mult):
|
| 1177 |
+
for _ in range(num_res_blocks):
|
| 1178 |
+
layers = [
|
| 1179 |
+
ResBlock(
|
| 1180 |
+
ch,
|
| 1181 |
+
time_embed_dim,
|
| 1182 |
+
dropout,
|
| 1183 |
+
out_channels=int(mult*model_channels),
|
| 1184 |
+
dims = dims,
|
| 1185 |
+
use_checkpoint=use_checkpoint,
|
| 1186 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 1187 |
+
)
|
| 1188 |
+
]
|
| 1189 |
+
ch = int(mult*model_channels)
|
| 1190 |
+
self.encoder_blocks.append(TimestepEmbedSequential(*layers))
|
| 1191 |
+
self._feature_size += ch
|
| 1192 |
+
input_block_chans.append(ch)
|
| 1193 |
+
if level != len(channel_mult)-1:
|
| 1194 |
+
out_ch = ch
|
| 1195 |
+
self.encoder_blocks.append(
|
| 1196 |
+
TimestepEmbedSequential(
|
| 1197 |
+
ResBlock(
|
| 1198 |
+
ch,
|
| 1199 |
+
time_embed_dim,
|
| 1200 |
+
dropout,
|
| 1201 |
+
out_channels=out_ch,
|
| 1202 |
+
dims = dims,
|
| 1203 |
+
use_checkpoint=use_checkpoint,
|
| 1204 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 1205 |
+
down = True,
|
| 1206 |
+
down_stride = (2,1) if (init_spatial_dim//ds) > self.audio_frames_per_video else (2,2),
|
| 1207 |
+
)
|
| 1208 |
+
if resblock_updown
|
| 1209 |
+
else Downsample(
|
| 1210 |
+
ch, conv_resample, dims=dims, out_channels=out_ch,
|
| 1211 |
+
stride=(2, 1) if (init_spatial_dim // ds) > self.audio_frames_per_video else (2, 2),  # Downsample's kwarg is stride; down_stride would raise a TypeError
|
| 1212 |
+
)
|
| 1213 |
+
)
|
| 1214 |
+
)
|
| 1215 |
+
ch = out_ch
|
| 1216 |
+
input_block_chans.append(ch)
|
| 1217 |
+
ds *= 2
|
| 1218 |
+
self._feature_size += ch
|
| 1219 |
+
|
| 1220 |
+
self.middle_block = TimestepEmbedSequential(
|
| 1221 |
+
ResBlock(
|
| 1222 |
+
ch,
|
| 1223 |
+
time_embed_dim,
|
| 1224 |
+
dropout,
|
| 1225 |
+
out_channels=self.out_channels,
|
| 1226 |
+
dims=dims,
|
| 1227 |
+
use_checkpoint=use_checkpoint,
|
| 1228 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 1229 |
+
),
|
| 1230 |
+
)
|
| 1231 |
+
self._feature_size += ch
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
# self.out = Upsample(self.out_channels, False, dims)
|
| 1235 |
+
self.out = nn.Identity()
|
| 1236 |
+
|
| 1237 |
+
def convert_to_fp16(self):
|
| 1238 |
+
"""
|
| 1239 |
+
Convert the torso of the model to float16.
|
| 1240 |
+
"""
|
| 1241 |
+
self.input_block.apply(convert_module_to_f16)  # this encoder defines input_block (singular)
|
| 1242 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 1243 |
+
self.encoder_blocks.apply(convert_module_to_f16)  # this encoder has encoder_blocks, not output_blocks
|
| 1244 |
+
|
| 1245 |
+
def convert_to_fp32(self):
|
| 1246 |
+
"""
|
| 1247 |
+
Convert the torso of the model to float32.
|
| 1248 |
+
"""
|
| 1249 |
+
self.input_block.apply(convert_module_to_f32)
|
| 1250 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 1251 |
+
self.encoder_blocks.apply(convert_module_to_f32)
|
| 1252 |
+
|
| 1253 |
+
|
| 1254 |
+
def forward(self, x):
|
| 1255 |
+
h = x.type(self.dtype)
|
| 1256 |
+
BF_in, H_in, W_in = h.shape
|
| 1257 |
+
|
| 1258 |
+
# fixed (all-zero) timestep embedding, so the timestep-conditioned blocks can be reused unchanged
|
| 1259 |
+
t = th.zeros(BF_in, dtype=th.long, device = x.device)
|
| 1260 |
+
emb = self.time_embed(timestep_embedding(t, self.model_channels))
|
| 1261 |
+
|
| 1262 |
+
# 80 -> 64 using Conv1D
|
| 1263 |
+
h = self.input_block(h, emb)
|
| 1264 |
+
_, H, W = h.shape
|
| 1265 |
+
h = h.reshape(BF_in, 1, H, W)  # [B*F, 64, 16] -> [B*F, 1, 64, 16]
|
| 1266 |
+
# call encoder blocks
|
| 1267 |
+
for module in self.encoder_blocks:
|
| 1268 |
+
h = module(h, emb)
|
| 1269 |
+
h = self.middle_block(h, emb)
|
| 1270 |
+
h = h.type(x.dtype)
|
| 1271 |
+
return self.out(h)  # -> [B*F, out_channels, 8, 8]
|
| 1272 |
+
#______________________________________________________________________#
|
| 1273 |
+
|
| 1274 |
+
|
| 1275 |
+
|
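A shape check matching the walkthrough in the TFGAudioEncoder docstring, using the class defaults (audio_dim 80, 16 audio frames per video frame, model_channels 32, channel_mult (1,2,3,4)):

import torch as th

aud_enc = TFGAudioEncoder()   # defaults as documented above
mels = th.randn(6, 80, 16)    # [B*F, audio_dim, audio_frames_per_video]
feat = aud_enc(mels)
print(feat.shape)             # torch.Size([6, 128, 8, 8])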
guided-diffusion/setup.py
ADDED
|
@@ -0,0 +1,7 @@
|
| 1 |
+
from setuptools import setup
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name="guided-diffusion",
|
| 5 |
+
py_modules=["guided_diffusion"],
|
| 6 |
+
install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
|
| 7 |
+
)
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
librosa==0.9.2
|
| 2 |
+
opencv-python==4.5.5.64
|
| 3 |
+
opencv-contrib-python==4.6.0.66
|
| 4 |
+
tensorboard==2.11.0
|
| 5 |
+
tqdm==4.64.1
|
| 6 |
+
mpi4py-mpich==3.1.2
|
| 7 |
+
av==9.2.0
|
| 8 |
+
torch --extra-index-url https://download.pytorch.org/whl/cu113
|
| 9 |
+
torchvision --extra-index-url https://download.pytorch.org/whl/cu113
|
| 10 |
+
torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
|
| 11 |
+
-e ./guided-diffusion
|
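Typical setup is a single install from the repo root; the last line above installs the bundled guided-diffusion package in editable mode via its setup.py, and the --extra-index-url lines pull CUDA 11.3 PyTorch wheels (adjust to your CUDA version if needed):

pip install -r requirements.txt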
scripts/inference.sh
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
#set paths and arguments
|
| 4 |
+
real_video_root='dataset/VoxCeleb2/vox2_test_mp4/mp4/'
|
| 5 |
+
model_path="checkpoints/checkpoint.pt"
|
| 6 |
+
sample_path="output_dir"
|
| 7 |
+
sample_mode="cross" # or "reconstruction"
|
| 8 |
+
NUM_GPUS=2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
#cross vs reconstruction
|
| 15 |
+
filelist_recon='dataset/filelists/voxceleb2_test_n_5000_reconstruction_5k.txt'
|
| 16 |
+
filelist_cross='dataset/filelists/voxceleb2_test_n_5000_seed_797_cross_5K.txt'
|
| 17 |
+
if [ "$sample_mode" = "reconstruction" ]; then
|
| 18 |
+
sample_input_flags="--sampling_input_type=first_frame --sampling_ref_type=first_frame"
|
| 19 |
+
filelist=$filelist_recon
|
| 20 |
+
elif [ "$sample_mode" = "cross" ]; then
|
| 21 |
+
sample_input_flags="--sampling_input_type=gt --sampling_ref_type=gt"
|
| 22 |
+
filelist=$filelist_cross
|
| 23 |
+
else
|
| 24 |
+
echo "Error: sample_mode can only be \"cross\" or \"reconstruction\""
|
| 25 |
+
exit 1
|
| 26 |
+
fi
|
| 27 |
+
test_video_dir=$real_video_root
|
| 28 |
+
mkdir -p $sample_path
|
| 29 |
+
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
|
| 30 |
+
DIFFUSION_FLAGS="--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
|
| 31 |
+
SAMPLE_FLAGS="--sampling_seed=7 $sample_input_flags --timestep_respacing ddim25 --use_ddim True --model_path=$model_path --sample_path=$sample_path"
|
| 32 |
+
DATA_FLAGS="--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32 "
|
| 33 |
+
TFG_FLAGS="--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
|
| 34 |
+
GEN_FLAGS="--generate_from_filelist 1 --test_video_dir=$test_video_dir --filelist=$filelist --save_orig=False --face_det_batch_size 64 --pads 0,0,0,0"
|
| 35 |
+
|
| 36 |
+
if [ "$NUM_GPUS" -gt 1 ]; then
|
| 37 |
+
mpiexec -n $NUM_GPUS python generate_dist.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
|
| 38 |
+
else
|
| 39 |
+
python generate.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
|
| 40 |
+
fi
|
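Usage note: with NUM_GPUS greater than 1 (2 above), sampling is sharded across ranks via mpiexec and generate_dist.py; setting NUM_GPUS=1 falls back to single-process generate.py with the same flags. Run from the repo root, e.g. bash scripts/inference.sh.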
scripts/inference_single_video.sh
ADDED
|
@@ -0,0 +1,35 @@
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
#set paths and arguments
|
| 4 |
+
sample_mode="cross" # or "reconstruction"
|
| 5 |
+
NUM_GPUS=1
|
| 6 |
+
generate_from_filelist=0
|
| 7 |
+
video_path="path/to/video.mp4"
|
| 8 |
+
audio_path="path/to/audio.mp4"
|
| 9 |
+
out_path="path/to/output.mp4"
|
| 10 |
+
model_path="path/to/model.pt"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
#cross vs reconstruction
|
| 15 |
+
if [ "$sample_mode" = "reconstruction" ]; then
|
| 16 |
+
sample_input_flags="--sampling_input_type=first_frame --sampling_ref_type=first_frame"
|
| 17 |
+
elif [ "$sample_mode" = "cross" ]; then
|
| 18 |
+
sample_input_flags="--sampling_input_type=gt --sampling_ref_type=gt"
|
| 19 |
+
else
|
| 20 |
+
echo "Error: sample_mode can only be \"cross\" or \"reconstruction\""
|
| 21 |
+
exit 1
|
| 22 |
+
fi
|
| 23 |
+
|
| 24 |
+
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
|
| 25 |
+
DIFFUSION_FLAGS="--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
|
| 26 |
+
SAMPLE_FLAGS="--sampling_seed=7 $sample_input_flags --timestep_respacing ddim25 --use_ddim True --model_path=$model_path"
|
| 27 |
+
DATA_FLAGS="--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32 "
|
| 28 |
+
TFG_FLAGS="--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
|
| 29 |
+
GEN_FLAGS="--generate_from_filelist $generate_from_filelist --video_path=$video_path --audio_path=$audio_path --out_path=$out_path --save_orig=False --face_det_batch_size 64 --pads 0,0,0,0 --is_voxceleb2=False"
|
| 30 |
+
|
| 31 |
+
if [ "$NUM_GPUS" -gt 1 ]; then
|
| 32 |
+
mpiexec -n $NUM_GPUS python generate_dist.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
|
| 33 |
+
else
|
| 34 |
+
python generate.py $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
|
| 35 |
+
fi
|
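Usage note: set video_path, audio_path, out_path, and model_path at the top before running. With is_voxceleb2=False the script treats the inputs as arbitrary face videos rather than VoxCeleb2 file IDs (an inference from the flag name). Run: bash scripts/inference_single_video.sh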