"""Template-driven lip-sync inference pipeline (synchronous and async variants)."""
| import asyncio | |
| import json | |
| from pathlib import Path | |
| import asyncstdlib | |
| import numpy as np | |
| import pandas as pd | |
| from pydub import AudioSegment | |
| from stf_alternative.compose import get_compose_func_without_keying, get_keying_func | |
| from stf_alternative.dataset import LipGanAudio, LipGanImage, LipGanRemoteImage | |
| from stf_alternative.inference import ( | |
| adictzip, | |
| ainference_model_remote, | |
| audio_encode, | |
| dictzip, | |
| get_head_box, | |
| inference_model, | |
| inference_model_remote, | |
| ) | |
| from stf_alternative.preprocess_dir.utils import face_finder as ff | |
| from stf_alternative.readers import ( | |
| AsyncProcessPoolBatchIterator, | |
| ProcessPoolBatchIterator, | |
| get_image_folder_async_process_reader, | |
| get_image_folder_process_reader, | |
| ) | |
| from stf_alternative.util import ( | |
| acycle, | |
| get_crop_mp4_dir, | |
| get_frame_dir, | |
| get_preprocess_dir, | |
| icycle, | |
| read_config, | |
| ) | |
def calc_audio_std(audio_segment):
    """Return ``(normalized_std, sample_count)`` for an audio segment.

    The standard deviation of the raw samples is divided by the maximum
    value representable at the segment's sample width, yielding a
    loudness measure in [0, 1] that is comparable across bit depths.
    """
    samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
    # Pick the integer dtype matching the sample width (1 -> int8,
    # 2 -> int16, anything else -> int32) to find full-scale amplitude.
    width_to_dtype = {1: np.int8, 2: np.int16}
    dtype = width_to_dtype.get(audio_segment.sample_width, np.int32)
    full_scale = np.iinfo(dtype).max
    return samples.std() / full_scale, len(samples)
class RunningAudioNormalizer:
    """Rescales audio segments toward a reference loudness.

    Keeps an exponentially decayed running estimate of the incoming
    audio's normalized std (via ``calc_audio_std``) and scales each
    segment so its loudness approaches the reference segment's.
    """

    def __init__(self, ref_audio_segment, decay_rate=0.01):
        """
        Args:
            ref_audio_segment: segment whose normalized std is the target.
            decay_rate: exponential decay applied to the running variance
                and count on every call (expected in (0, 1)).
        """
        self.ref_std, _ = calc_audio_std(ref_audio_segment)
        self.running_var = np.float64(0)
        self.running_cnt = 0
        self.decay_rate = decay_rate

    def __call__(self, audio_segment):
        """Return a copy of ``audio_segment`` rescaled toward the reference loudness."""
        std, cnt = calc_audio_std(audio_segment)
        # Decayed running sums: variance mass and effective sample count.
        self.running_var = (self.running_var + (std**2) * cnt) * (1 - self.decay_rate)
        self.running_cnt = (self.running_cnt + cnt) * (1 - self.decay_rate)
        # BUG FIX: the original divided by `self.std` (the bound method
        # object, not its value) and applied `/` to an array.array, which
        # supports neither — both raised TypeError. Convert to an ndarray
        # and call std(); the original `.astype(np.int16)` shows an ndarray
        # was intended all along.
        samples = np.asarray(audio_segment.get_array_of_samples())
        scaled = samples / self.std() * self.ref_std
        return audio_segment._spawn(scaled.astype(np.int16).tobytes())

    def std(self):
        """Current running std estimate (sqrt of decayed variance over count)."""
        return np.sqrt(self.running_var / self.running_cnt)
def get_video_metadata(preprocess_dir):
    """Load the video metadata dict (fps, width, height, ...) from
    ``metadata.json`` inside *preprocess_dir*."""
    metadata_path = preprocess_dir / "metadata.json"
    return json.loads(metadata_path.read_text())
class Template:
    """Lip-sync inference pipeline bound to one preprocessed template video.

    Builds face/mel datasets and frame readers from the preprocessing
    output directory, runs the (local or remote) inference function in
    batches, and exposes generators that yield inferred mouth crops or
    fully composed frames, each paired with its matching audio chunk.
    """

    def __init__(
        self,
        config_path,
        model,
        template_video_path,
        wav_std=False,
        ref_wav=None,
        verbose=False,
    ):
        """
        Args:
            config_path: path handed to ``read_config``.
            model: loaded inference model; ``model.args`` supplies
                ``name``, ``model_type``, and ``batch_size``, and
                ``model.work_root_path`` / ``model.device`` are read here
                or by the inference helpers.
            template_video_path: source video the template was built from.
            wav_std: when True, loudness-normalize incoming audio against
                ``ref_wav`` with a RunningAudioNormalizer.
            ref_wav: reference AudioSegment; used only when ``wav_std``.
            verbose: stored for callers/debugging; not otherwise used in
                this class.
        """
        self.config = read_config(config_path)
        self.model = model
        self.template_video_path = Path(template_video_path)
        # Directory layout produced by the preprocessing step.
        self.preprocess_dir = Path(
            get_preprocess_dir(model.work_root_path, model.args.name)
        )
        self.crop_mp4_dir = Path(
            get_crop_mp4_dir(self.preprocess_dir, template_video_path)
        )
        self.dataset_dir = self.crop_mp4_dir / f"{Path(template_video_path).stem}_000"
        self.template_frames_path = Path(
            get_frame_dir(self.preprocess_dir, template_video_path, ratio=1.0)
        )
        self.verbose = verbose
        self.remote = self.model.args.model_type == "remote"
        # When wav_std is off, normalization is the identity function.
        self.audio_normalizer = (
            RunningAudioNormalizer(ref_wav) if wav_std else lambda x: x
        )
        # Per-frame face metadata table; its length is the number of
        # template frames and is used for wrap-around indexing below.
        self.df = pd.read_pickle(self.dataset_dir / "df_fan.pickle")
        metadata = get_video_metadata(self.preprocess_dir)
        self.fps = metadata["fps"]
        self.width, self.height = metadata["width"], metadata["height"]
        self.keying_func = get_keying_func(self)
        self.compose_func = get_compose_func_without_keying(self, ratio=1.0)
        # Truthy only if the config explicitly sets a truthy "move" key.
        self.move = "move" in self.config.keys() and self.config.move
        self.inference_func = inference_model_remote if self.remote else inference_model
        self.batch_size = self.model.args.batch_size
        # Duration of one video frame in milliseconds.
        self.unit = 1000 / self.fps

    def _get_reader(self, num_skip_frames):
        """Process-pool reader over the extracted template frame images."""
        assert self.template_frames_path.exists()
        return get_image_folder_process_reader(
            data_path=self.template_frames_path,
            num_skip_frames=num_skip_frames,
            preload=self.batch_size,
        )

    def _get_local_face_dataset(self, num_skip_frames):
        """Face dataset variant used when inference runs locally."""
        return LipGanImage(
            args=self.model.args,
            path=self.dataset_dir,
            num_skip_frames=num_skip_frames,
        )

    def _get_remote_face_dataset(self, num_skip_frames):
        """Face dataset variant used when inference runs remotely."""
        return LipGanRemoteImage(
            args=self.model.args,
            path=self.dataset_dir,
            num_skip_frames=num_skip_frames,
        )

    def _get_mel_dataset(self, audio_segment):
        """Encode ``audio_segment`` to mel features and wrap them as a dataset
        with one entry per video frame."""
        image_count = round(
            audio_segment.duration_seconds * self.fps
        )  # audio was padded, so this divides evenly by batch_size
        ids = list(range(image_count))
        mel = audio_encode(
            model=self.model,
            audio_segment=audio_segment,
            device=self.model.device,
        )
        return LipGanAudio(
            args=self.model.args,
            id_list=ids,
            mel=mel,
            fps=self.fps,
        )

    def _get_face_dataset(self, num_skip_frames):
        """Dispatch to the remote or local face dataset based on model type."""
        if self.remote:
            return self._get_remote_face_dataset(num_skip_frames=num_skip_frames)
        else:
            return self._get_local_face_dataset(num_skip_frames=num_skip_frames)

    def _wrap_reader(self, reader):
        # Cycle forever so the template frames loop when the audio is
        # longer than the template video.
        reader = icycle(reader)
        return reader

    def _wrap_dataset(self, dataset):
        """Wrap a dataset in a process-pool batch iterator of batch_size."""
        dataloader = ProcessPoolBatchIterator(
            dataset=dataset,
            batch_size=self.batch_size,
        )
        return dataloader

    def get_reader(self, num_skip_frames=0):
        """Infinite (cycled) frame reader starting at ``num_skip_frames``."""
        reader = self._get_reader(num_skip_frames=num_skip_frames)
        reader = self._wrap_reader(reader)
        return reader

    def get_mel_loader(self, audio_segment):
        """Batched loader over mel features for ``audio_segment``."""
        mel_dataset = self._get_mel_dataset(audio_segment)
        return self._wrap_dataset(mel_dataset)

    def get_face_loader(self, num_skip_frames=0):
        """Batched loader over face data starting at ``num_skip_frames``."""
        face_dataset = self._get_face_dataset(num_skip_frames=num_skip_frames)
        return self._wrap_dataset(face_dataset)  # need cycle

    # padding according to batch size.
    def pad(self, audio_segment):
        """Right-pad with silence so the frame count is a multiple of
        batch_size.

        NOTE(review): ``num_frames`` is a float, so the silence duration is
        computed with float modulo; relies on round() in _get_mel_dataset
        landing on an exact multiple — confirm for unusual fps values.
        """
        num_frames = audio_segment.duration_seconds * self.fps
        pad = AudioSegment.silent(
            (self.batch_size - (num_frames % self.batch_size)) * (1000 / self.fps)
        )
        return audio_segment + pad

    def _prepare_data(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        """Pad the audio and build the face/mel loaders for inference.

        Returns:
            (padded_audio, face_loader, mel_loader)
        """
        # Wrap the start offset into the template's frame range.
        video_start_offset_frame = video_start_offset_frame % len(self.df)
        padded = self.pad(audio_segment)
        face_dataset = self._get_face_dataset(num_skip_frames=video_start_offset_frame)
        mel_dataset = self._get_mel_dataset(audio_segment=padded)
        n_frames = len(mel_dataset)
        assert n_frames % self.batch_size == 0
        face_loader = self._wrap_dataset(face_dataset)
        mel_loader = self._wrap_dataset(mel_dataset)
        return padded, face_loader, mel_loader

    def gen_infer(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        """Yield ``(inferred_item, audio_chunk)`` pairs, one per frame."""
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            inferred = self.inference_func(self.model, v, self.model.device)
            for j, it in enumerate(inferred):
                # Millisecond offset of this frame within the padded audio
                # (self.unit is ms per frame).
                chunk_pivot = i * self.unit * self.batch_size + j * self.unit
                chunk = padded[chunk_pivot : chunk_pivot + self.unit]
                yield it, chunk

    def gen_infer_batch(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        """Yield ``(inferred_batch, audio_chunk)`` pairs, one per batch."""
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            inferred = self.inference_func(self.model, v, self.model.device)
            yield inferred, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    def gen_infer_batch_future(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        """Submit every batch to ``pool`` up front, then yield
        ``(future, audio_chunk)`` pairs in submission order."""
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        futures = []
        # Submit all batches first so inference overlaps with consumption.
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            futures.append(
                pool.submit(self.inference_func, self.model, v, self.model.device)
            )
        for i, future in enumerate(futures):
            yield future, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    def gen_infer_concurrent(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        """Like gen_infer, but batches run concurrently on ``pool``;
        yields per-frame ``(inferred_item, audio_chunk)`` pairs."""
        for future, chunk in self.gen_infer_batch_future(
            pool, audio_segment, video_start_offset_frame
        ):
            for i, inferred in enumerate(future.result()):
                yield inferred, chunk[i * self.unit : (i + 1) * self.unit]

    def compose(
        self,
        idx,
        frame,
        output,
    ):
        """Key the inferred mouth ``output`` and composite it onto ``frame``.

        ``idx`` is the absolute frame index; it wraps around the template
        frame table so playback can loop.
        """
        head_box_idx = idx % len(self.df)
        head_box = get_head_box(
            self.df,
            move=self.move,
            head_box_idx=head_box_idx,
        )
        alpha2 = self.keying_func(output, head_box_idx, head_box)
        # frame[:, :, :4]: first four channels of the template frame —
        # presumably BGRA/RGBA; confirm against the reader's output format.
        frame = self.compose_func(alpha2, frame[:, :, :4], head_box_idx)
        return frame

    def gen_frames(
        self,
        audio_segment,
        video_start_offset_frame,
        reader=None,
    ):
        """Yield ``(composed_frame, audio_chunk)`` pairs for the whole audio."""
        reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
        gen_infer = self.gen_infer(audio_segment, video_start_offset_frame)
        for idx, ((o, a), f) in enumerate(
            zip(gen_infer, reader), video_start_offset_frame
        ):
            composed = self.compose(idx, f, o)
            yield composed, a

    def gen_frames_concurrent(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
        reader=None,
    ):
        """Concurrent variant of gen_frames using ``pool`` for inference."""
        reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
        gen_infer = self.gen_infer_concurrent(
            pool,
            audio_segment,
            video_start_offset_frame,
        )
        for idx, ((o, a), f) in enumerate(
            zip(gen_infer, reader), video_start_offset_frame
        ):
            yield self.compose(idx, f, o), a
class AsyncTemplate(Template):
    """Async variant of Template: remote-only inference with async
    process-pool loaders and an async cycled frame reader."""

    async def agen_infer_batch_future(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        """Async counterpart of gen_infer_batch_future.

        Creates one asyncio task per batch up front, then yields
        ``(task, audio_chunk)`` pairs in submission order. Remote-only.
        """
        assert self.remote
        padded, face_loader, mel_loader = await self._aprepare_data(
            pool,
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        futures = []
        async for i, v in asyncstdlib.enumerate(
            adictzip(aiter(mel_loader), aiter(face_loader))
        ):
            futures.append(
                asyncio.create_task(
                    ainference_model_remote(pool, self.model, v, self.model.device)
                )
            )
        for i, future in enumerate(futures):
            yield future, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    async def _awrap_dataset(self, dataset):
        """Wrap a dataset in an async process-pool batch iterator."""
        dataloader = AsyncProcessPoolBatchIterator(
            dataset=dataset,
            batch_size=self.batch_size,
        )
        return dataloader

    async def _aprepare_data(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        """Async counterpart of _prepare_data.

        Builds the face and mel datasets concurrently on ``pool`` (they are
        CPU-bound), then wraps them in async loaders.
        """
        video_start_offset_frame = video_start_offset_frame % len(self.df)
        padded = self.pad(audio_segment)
        loop = asyncio.get_running_loop()
        face_dataset, mel_dataset = await asyncio.gather(
            loop.run_in_executor(
                pool, self._get_face_dataset, video_start_offset_frame
            ),
            loop.run_in_executor(pool, self._get_mel_dataset, padded),
        )
        n_frames = len(mel_dataset)
        assert n_frames % self.batch_size == 0
        face_loader = await self._awrap_dataset(face_dataset)
        mel_loader = await self._awrap_dataset(mel_dataset)
        return padded, face_loader, mel_loader

    def _aget_reader(self, num_skip_frames):
        """Async process-pool reader over the template frame images."""
        assert self.template_frames_path.exists()
        return get_image_folder_async_process_reader(
            data_path=self.template_frames_path,
            num_skip_frames=num_skip_frames,
            preload=self.batch_size,
        )

    def _awrap_reader(self, reader):
        # Cycle forever so template frames loop for long audio.
        reader = acycle(reader)
        return reader

    def aget_reader(self, num_skip_frames=0):
        """Infinite (cycled) async frame reader starting at ``num_skip_frames``."""
        reader = self._aget_reader(num_skip_frames=num_skip_frames)
        reader = self._awrap_reader(reader)
        return reader