| | import os |
| | import gradio as gr |
| | from sd_model_cfg import model_dict |
| | from app import process, process0, process1, process2, get_frame_count, cfg_to_input |
| |
|
# Markdown shown at the top of the UI.  Typo fixes: "before run" ->
# "before running", "propogate" -> "propagate", "unuse" -> "disable",
# "photorealstic" -> "photorealistic" (x2), and a stray closing </p> removed.
DESCRIPTION = '''
## Rerender A Video
### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
### To avoid overload, we set limitations to the maximum frame number (8) and the maximum frame resolution (512x768).
### The running time of a video of size 512x640 is about 1 minute per keyframe under T4 GPU.
### How to use:
1. **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video.
2. **Run Key Frames**: translate all the key frames based on the settings of the first frame
3. **Run All**: **Run 1st Key Frame** and **Run Key Frames**
4. **Run Propagation**: propagate the key frames to other frames for full video translation. This part will be released upon the publication of the paper.
### Tips:
1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or quick motions.
3. Try different color-aware AdaIN settings and even disable it to avoid color jittering.
4. `revAnimated_v11` model for non-photorealistic style, `realisticVisionV20_v20` model for photorealistic style.
5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffuser/Automatic1111 models to the original one.

**This code is for research purpose and non-commercial use only.**

<a href="https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.
'''
| |
|
| | MAX_KEYFRAME = 100000000 |
| |
|
# Build the Gradio UI.  Component creation order matters: it determines the
# on-screen layout inside the nested Row / Column / Accordion contexts.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: the input video, prompt, run buttons and all
        # translation options.
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    # Typo fix: 'rsolution' -> 'resolution'.
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                # Typo fix: 'ControNet' -> 'ControlNet'.
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    # Fix: trailing space added so the two concatenated
                    # sentences no longer run together as "input.1.05:".
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny'],
                                               label='Control type',
                                               value='HED')
                    # The Canny thresholds only take effect when
                    # control_type == 'canny'.
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                # Model choices come from the project's sd_model_cfg registry.
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
| | with gr.Accordion('Advanced options for the key fame translation', |
| | open=False): |
| | interval = gr.Slider( |
| | label='Key frame frequency (K)', |
| | minimum=1, |
| | maximum=1, |
| | value=1, |
| | step=1, |
| | info='Uniformly sample the key frames every K frames') |
| | keyframe_count = gr.Slider( |
| | label='Number of key frames', |
| | minimum=1, |
| | maximum=1, |
| | value=1, |
| | step=1, |
| | info='To avoid overload, maximum 8 key frames') |
| |
|
| | use_constraints = gr.CheckboxGroup( |
| | [ |
| | 'shape-aware fusion', 'pixel-aware fusion', |
| | 'color-aware AdaIN' |
| | ], |
| | label='Select the cross-frame contraints to be used', |
| | value=[ |
| | 'shape-aware fusion', 'pixel-aware fusion', |
| | 'color-aware AdaIN' |
| | ]), |
| | with gr.Row(): |
| | cross_start = gr.Slider( |
| | label='Cross-frame attention start', |
| | minimum=0, |
| | maximum=1, |
| | value=0, |
| | step=0.05) |
| | cross_end = gr.Slider(label='Cross-frame attention end', |
| | minimum=0, |
| | maximum=1, |
| | value=1, |
| | step=0.05) |
| | style_update_freq = gr.Slider( |
| | label='Cross-frame attention update frequency', |
| | minimum=1, |
| | maximum=100, |
| | value=1, |
| | step=1, |
| | info= |
| | ('Update the key and value for ' |
| | 'cross-frame attention every N key frames (recommend N*K>=10)' |
| | )) |
| | with gr.Row(): |
| | warp_start = gr.Slider(label='Shape-aware fusion start', |
| | minimum=0, |
| | maximum=1, |
| | value=0, |
| | step=0.05) |
| | warp_end = gr.Slider(label='Shape-aware fusion end', |
| | minimum=0, |
| | maximum=1, |
| | value=0.1, |
| | step=0.05) |
| | with gr.Row(): |
| | mask_start = gr.Slider(label='Pixel-aware fusion start', |
| | minimum=0, |
| | maximum=1, |
| | value=0.5, |
| | step=0.05) |
| | mask_end = gr.Slider(label='Pixel-aware fusion end', |
| | minimum=0, |
| | maximum=1, |
| | value=0.8, |
| | step=0.05) |
| | with gr.Row(): |
| | ada_start = gr.Slider(label='Color-aware AdaIN start', |
| | minimum=0, |
| | maximum=1, |
| | value=0.8, |
| | step=0.05) |
| | ada_end = gr.Slider(label='Color-aware AdaIN end', |
| | minimum=0, |
| | maximum=1, |
| | value=1, |
| | step=0.05) |
| | mask_strength = gr.Slider(label='Pixel-aware fusion stength', |
| | minimum=0, |
| | maximum=1, |
| | value=0.5, |
| | step=0.01) |
| | inner_strength = gr.Slider( |
| | label='Pixel-aware fusion detail level', |
| | minimum=0.5, |
| | maximum=1, |
| | value=0.9, |
| | step=0.01, |
| | info='Use a low value to prevent artifacts') |
| | smooth_boundary = gr.Checkbox( |
| | label='Smooth fusion boundary', |
| | value=True, |
| | info='Select to prevent artifacts at boundary') |
| |
|
| | with gr.Accordion('Example configs', open=True): |
| | config_dir = 'config' |
| | config_list = os.listdir(config_dir) |
| | args_list = [] |
| | for config in config_list: |
| | try: |
| | config_path = os.path.join(config_dir, config) |
| | args = cfg_to_input(config_path) |
| | args_list.append(args) |
| | except FileNotFoundError: |
| | |
| | pass |
| |
|
| | ips = [ |
| | prompt, image_resolution, control_strength, color_preserve, |
| | left_crop, right_crop, top_crop, bottom_crop, control_type, |
| | low_threshold, high_threshold, ddim_steps, scale, seed, |
| | sd_model, a_prompt, n_prompt, interval, keyframe_count, |
| | x0_strength, use_constraints[0], cross_start, cross_end, |
| | style_update_freq, warp_start, warp_end, mask_start, |
| | mask_end, ada_start, ada_end, mask_strength, |
| | inner_strength, smooth_boundary |
| | ] |
| |
|
        # Right column: the translated outputs.
        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)
    # Full-width examples row; each example row feeds the same inputs as the
    # run buttons (video path first, then the shared `ips` list) through
    # process0, and the results are cached on the server.
    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)
| |
|
| | def input_uploaded(path): |
| | frame_count = get_frame_count(path) |
| | if frame_count <= 2: |
| | raise gr.Error('The input video is too short!' |
| | 'Please input another video.') |
| |
|
| | default_interval = min(10, frame_count - 2) |
| | max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME) |
| |
|
| | global video_frame_count |
| | video_frame_count = frame_count |
| | global global_video_path |
| | global_video_path = path |
| |
|
| | return gr.Slider.update(value=default_interval, |
| | maximum=MAX_KEYFRAME), gr.Slider.update( |
| | value=max_keyframe, maximum=max_keyframe) |
| |
|
| | def input_changed(path): |
| | frame_count = get_frame_count(path) |
| | if frame_count <= 2: |
| | return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1) |
| |
|
| | default_interval = min(10, frame_count - 2) |
| | max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME) |
| |
|
| | global video_frame_count |
| | video_frame_count = frame_count |
| | global global_video_path |
| | global_video_path = path |
| |
|
| | return gr.Slider.update(maximum=max_keyframe), \ |
| | gr.Slider.update(maximum=max_keyframe) |
| |
|
| | def interval_changed(interval): |
| | global video_frame_count |
| | if video_frame_count is None: |
| | return gr.Slider.update() |
| |
|
| | max_keyframe = (video_frame_count - 2) // interval |
| |
|
| | return gr.Slider.update(value=max_keyframe, maximum=max_keyframe) |
| |
|
    # Both changing and uploading a video refresh the interval /
    # keyframe-count slider bounds from the clip's frame count.
    input_path.change(input_changed, input_path, [interval, keyframe_count])
    input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
    interval.change(interval_changed, interval, keyframe_count)

    # 'Run All' produces both outputs; the other buttons run one stage each.
    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
    run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])
| |
|
    def process3():
        """Placeholder for the propagation stage: always raises a gr.Error
        telling the user the full-video code is not yet released."""
        raise gr.Error(
            "Coming Soon. Full code for full video translation will be "
            "released upon the publication of the paper.")

    run_button3.click(fn=process3, outputs=[result_keyframe])
| |
|
# Single worker with a bounded request queue to avoid overloading the GPU.
# NOTE(review): queue() is also called on the Blocks at creation time; this
# call re-configures it with explicit limits — confirm both are intended.
block.queue(concurrency_count=1, max_size=20)
block.launch(server_name='0.0.0.0')
| |
|