| | import json |
| | import os |
| | from typing import Optional, Sequence, Tuple |
| |
|
| | from src.video_util import get_frame_count |
| |
|
| |
|
class RerenderConfig:
    """Settings holder for a single re-rendering run.

    A fresh instance carries no attributes. Populate it with
    :meth:`create_from_parameters` (explicit arguments) or
    :meth:`create_from_path` (a JSON file); both also create the working
    directories on disk.
    """

    # Optional JSON keys that map one-to-one onto keyword parameters of
    # ``create_from_parameters``. The canny thresholds are handled
    # separately because they are only forwarded for the 'canny' control
    # type.
    _OPTIONAL_KEYS = (
        'work_dir',
        'key_subdir',
        'frame_count',
        'interval',
        'crop',
        'sd_model',
        'a_prompt',
        'n_prompt',
        'ddim_steps',
        'scale',
        'control_type',
        'control_strength',
        'seed',
        'image_resolution',
        'x0_strength',
        'style_update_freq',
        'cross_period',
        'warp_period',
        'mask_period',
        'ada_period',
        'mask_strength',
        'inner_strength',
        'smooth_boundary',
        'color_preserve',
    )
    _CANNY_KEYS = ('canny_low', 'canny_high')

    def __init__(self):
        ...

    def create_from_parameters(self,
                               input_path: str,
                               output_path: str,
                               prompt: str,
                               work_dir: Optional[str] = None,
                               key_subdir: str = 'keys',
                               frame_count: Optional[int] = None,
                               interval: int = 10,
                               crop: Sequence[int] = (0, 0, 0, 0),
                               sd_model: Optional[str] = None,
                               a_prompt: str = '',
                               n_prompt: str = '',
                               ddim_steps=20,
                               scale=7.5,
                               control_type: str = 'HED',
                               control_strength=1,
                               seed: int = -1,
                               image_resolution: int = 512,
                               x0_strength: float = -1,
                               style_update_freq: int = 10,
                               cross_period: Tuple[float, float] = (0, 1),
                               warp_period: Tuple[float, float] = (0, 0.1),
                               mask_period: Tuple[float, float] = (0.5, 0.8),
                               ada_period: Tuple[float, float] = (1.0, 1.0),
                               mask_strength: float = 0.5,
                               inner_strength: float = 0.9,
                               smooth_boundary: bool = True,
                               color_preserve: bool = True,
                               **kwargs):
        """Populate the config from explicit values and create its folders.

        Args:
            input_path: Path of the input video; must be an existing file.
            output_path: Path of the output; its directory is the default
                working directory.
            prompt: Stored as ``self.prompt``.
            work_dir: Working directory. Defaults to the directory part of
                ``output_path`` (which may be the empty string, meaning the
                current directory).
            key_subdir: Name of the sub-folder of ``work_dir`` used for
                ``self.key_dir``.
            frame_count: Number of frames. When ``None`` it is obtained
                from ``get_frame_count(input_path)``.
            control_type: When equal to ``'canny'``, ``canny_low`` and
                ``canny_high`` are taken from ``kwargs`` (defaults 100 and
                200); otherwise both attributes are set to ``None``.
            **kwargs: Only ``canny_low`` / ``canny_high`` are read; any
                other extra keyword is ignored.

        All remaining parameters are stored unchanged on the instance
        under the same name; their interpretation is left to the code
        that consumes the config.

        Raises:
            FileNotFoundError: If ``input_path`` is not an existing file.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.prompt = prompt

        if work_dir is None:
            work_dir = os.path.dirname(output_path)
        self.work_dir = work_dir
        self.key_dir = os.path.join(self.work_dir, key_subdir)
        self.first_dir = os.path.join(self.work_dir, 'first')

        if not os.path.isfile(input_path):
            raise FileNotFoundError(f'Cannot find video file {input_path}')
        self.input_dir = os.path.join(self.work_dir, 'video')

        if frame_count is None:
            frame_count = get_frame_count(self.input_path)
        self.frame_count = frame_count

        self.interval = interval
        self.crop = crop
        self.sd_model = sd_model
        self.a_prompt = a_prompt
        self.n_prompt = n_prompt
        self.ddim_steps = ddim_steps
        self.scale = scale

        self.control_type = control_type
        if control_type == 'canny':
            self.canny_low = kwargs.get('canny_low', 100)
            self.canny_high = kwargs.get('canny_high', 200)
        else:
            self.canny_low = None
            self.canny_high = None

        self.control_strength = control_strength
        self.seed = seed
        self.image_resolution = image_resolution
        self.x0_strength = x0_strength
        self.style_update_freq = style_update_freq
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.warp_period = warp_period
        self.ada_period = ada_period
        self.mask_strength = mask_strength
        self.inner_strength = inner_strength
        self.smooth_boundary = smooth_boundary
        self.color_preserve = color_preserve

        # An empty path means "current directory": it already exists and
        # os.makedirs('') would raise FileNotFoundError, so skip it.
        for directory in (self.input_dir, self.work_dir, self.key_dir,
                          self.first_dir):
            if directory:
                os.makedirs(directory, exist_ok=True)

    def create_from_path(self, cfg_path: str):
        """Populate the config from a JSON file.

        The JSON object must contain ``input``, ``output`` and ``prompt``.
        Every other recognised key is optional; keys that are missing or
        ``null`` fall back to the defaults of
        :meth:`create_from_parameters`. ``canny_low`` / ``canny_high`` are
        only forwarded when ``control_type`` is ``'canny'``.

        Args:
            cfg_path: Path of the JSON configuration file (read as UTF-8).

        Raises:
            KeyError: If ``input``, ``output`` or ``prompt`` is missing.
            FileNotFoundError: If the config file or the input video does
                not exist.
        """
        with open(cfg_path, 'r', encoding='utf-8') as fp:
            cfg = json.load(fp)

        kwargs = {
            'input_path': cfg['input'],
            'output_path': cfg['output'],
            'prompt': cfg['prompt'],
        }

        def append_if_not_none(key, target=None):
            # Forward cfg[key] only when present and not null so that the
            # defaults of create_from_parameters stay in effect otherwise.
            value = cfg.get(key, None)
            if value is not None:
                kwargs[target or key] = value

        for key in self._OPTIONAL_KEYS:
            append_if_not_none(key)

        if kwargs.get('control_type', '') == 'canny':
            for key in self._CANNY_KEYS:
                append_if_not_none(key)

        # Older versions looked up the misspelled key 'color_perserve'.
        # Honour it as a fallback so configs written against that spelling
        # take effect; the correctly spelled key wins when both are given.
        if 'color_preserve' not in kwargs:
            append_if_not_none('color_perserve', target='color_preserve')

        self.create_from_parameters(**kwargs)

    @property
    def use_warp(self):
        """True when ``warp_period[0] <= warp_period[1]``."""
        return self.warp_period[0] <= self.warp_period[1]

    @property
    def use_mask(self):
        """True when ``mask_period[0] <= mask_period[1]``."""
        return self.mask_period[0] <= self.mask_period[1]

    @property
    def use_ada(self):
        """True when ``ada_period[0] <= ada_period[1]``."""
        return self.ada_period[0] <= self.ada_period[1]
| |
|