Spaces:
Configuration error
Configuration error
| from pathlib import Path | |
| from typing import Optional | |
| from tqdm import tqdm | |
| import hydra | |
| from omegaconf import DictConfig | |
| import numpy as np | |
| from src.simswap import SimSwap | |
| from src.DataManager.ImageDataManager import ImageDataManager | |
| from src.DataManager.VideoDataManager import VideoDataManager | |
| from src.DataManager.utils import imread_rgb | |
| class Application: | |
| def __init__(self, config: DictConfig): | |
| id_image_path = Path(config.data.id_image) | |
| specific_id_image_path = Path(config.data.specific_id_image) | |
| att_image_path = Path(config.data.att_image) | |
| att_video_path = Path(config.data.att_video) | |
| output_dir = Path(config.data.output_dir) | |
| assert id_image_path.exists(), f"Can't find {id_image_path} file!" | |
| self.id_image: Optional[np.ndarray] = imread_rgb(id_image_path) | |
| self.specific_id_image: Optional[np.ndarray] = ( | |
| imread_rgb(specific_id_image_path) | |
| if specific_id_image_path and specific_id_image_path.is_file() | |
| else None | |
| ) | |
| self.att_image: Optional[ImageDataManager] = None | |
| if att_image_path and (att_image_path.is_file() or att_image_path.is_dir()): | |
| self.att_image: Optional[ImageDataManager] = ImageDataManager( | |
| src_data=att_image_path, output_dir=output_dir | |
| ) | |
| self.att_video: Optional[VideoDataManager] = None | |
| if att_video_path and att_video_path.is_file(): | |
| self.att_video: Optional[VideoDataManager] = VideoDataManager( | |
| src_data=att_video_path, output_dir=output_dir, clean_work_dir=config.data.clean_work_dir | |
| ) | |
| assert not (self.att_video and self.att_image), "Only one attribute source can be used!" | |
| self.data_manager = self.att_video if self.att_video else self.att_image | |
| self.model = SimSwap( | |
| config=config.pipeline, | |
| id_image=self.id_image, | |
| specific_image=self.specific_id_image, | |
| ) | |
| def run(self): | |
| for _ in tqdm(range(len(self.data_manager))): | |
| att_img = self.data_manager.get() | |
| output = self.model(att_img) | |
| self.data_manager.save(output) | |
| def main(config: DictConfig): | |
| app = Application(config) | |
| app.run() | |
| if __name__ == "__main__": | |
| main() | |