"""Gradio demo that relights a folder of video frames with Stable Video Diffusion.

The models are loaded once at import time; ``generate`` is the Gradio callback
that loads the frames, runs the pipeline and returns the path of a GIF.
"""

import numbers
import os
import tempfile

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKLTemporalDecoder
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from utils.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
# Imported after the diffusers class of the same name, so the project-local
# pipeline is the one bound to ``StableVideoDiffusionPipeline`` below.
from utils.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline  # noqa: F811

# NOTE(review): ``combine_masks``, ``convert_colors`` and
# ``resize_and_pad_image`` are called by ``load_images_from_folder`` but are
# neither defined nor imported in this file. TODO: import them from the module
# that defines them, otherwise the calls raise NameError.

BASE_MODEL_ID = "stabilityai/stable-video-diffusion-img2vid"
RELIGHT_MODEL_ID = "quantum-whisper/edl-relight"
DEVICE = "cuda"

# File extensions treated as frames; add or remove extensions as needed.
VALID_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"})

# Per-frame GIF duration used when ``fps`` is not a positive number. This is
# the value that was previously hard-coded for every export.
DEFAULT_GIF_FRAME_MS = 500

# 1. Load once at startup
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    RELIGHT_MODEL_ID,
    subfolder="unet",
    low_cpu_mem_usage=True,
).to(DEVICE)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    BASE_MODEL_ID,
    subfolder="image_encoder",
    revision=None,
)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    BASE_MODEL_ID,
    subfolder="vae",
    revision=None,
    variant="fp16",
).to(DEVICE)
pipeline = StableVideoDiffusionPipeline.from_pretrained(
    BASE_MODEL_ID,
    unet=unet,
    image_encoder=image_encoder,
    vae=vae,
    revision=None,
    torch_dtype=torch.float16,
)
# The image encoder was loaded on the CPU while the UNet and VAE live on the
# GPU; move the whole pipeline so every component shares one device.
pipeline.to(DEVICE)


def _frame_sort_key(filename):
    """Return a sort key that orders ``frame_<N>.<ext>`` files numerically.

    Files that do not follow that pattern (or whose ``<N>`` is not an integer)
    sort after all numbered frames; the filename breaks ties so the order is
    deterministic.
    """
    number = float("inf")
    parts = filename.split("_")
    if len(parts) > 1 and parts[0] == "frame":
        try:
            number = int(parts[1].split(".")[0])
        except ValueError:
            number = float("inf")
    return (number, filename)


def load_images_from_folder(folder, mask_folder, is_condition=False):
    """Load the image frames stored in ``folder`` in frame order.

    Args:
        folder: Directory containing the frame images.
        mask_folder: Directory of masks, or ``None``. When it is an existing
            directory, each frame is multiplied by its mask and the remainder
            is filled with 255.
        is_condition: When true, each frame is passed through
            ``convert_colors`` before resizing.

    Returns:
        A list with one entry per frame, each the result of
        ``resize_and_pad_image``.
    """
    frame_files = sorted(
        (
            name
            for name in os.listdir(folder)
            if os.path.splitext(name)[1].lower() in VALID_EXTENSIONS
        ),
        key=_frame_sort_key,
    )

    # ``os.path.isdir(None)`` raises TypeError, so test for None explicitly.
    use_masks = mask_folder is not None and os.path.isdir(mask_folder)
    # The masks do not depend on the frame being processed: build them once.
    masks = combine_masks(mask_folder) if use_masks and frame_files else None

    images = []
    for index, filename in enumerate(frame_files):
        img = Image.open(os.path.join(folder, filename))

        if masks is not None:
            # Repeat the 2-D mask over three channels to match the image.
            # NOTE(review): assumes a 3-channel image and mask values in
            # {0, 1}, so masked-out pixels become white - TODO confirm.
            mask_3d = np.expand_dims(masks[index], axis=-1).repeat(3, axis=-1)
            image_array = np.array(img)
            masked = (image_array * mask_3d).astype(np.uint8)
            masked = masked + ((1 - mask_3d) * 255).astype(np.uint8)
            img = Image.fromarray(masked)

        if is_condition:
            img = convert_colors(img)

        img = resize_and_pad_image(img)
        images.append(img)

    return images


def export_to_gif(frames, output_gif_path, fps):
    """Export a sequence of frames to an animated GIF.

    Args:
        frames: Frames as numpy arrays or PIL Image objects.
        output_gif_path: Destination path; any ``.mp4`` in it is replaced by
            ``.gif``.
        fps: Frames per second. When it is not a positive number, each frame
            is shown for ``DEFAULT_GIF_FRAME_MS`` milliseconds instead.

    Returns:
        The path the GIF was written to.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [
        Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
        for frame in frames
    ]
    if not pil_frames:
        raise ValueError("export_to_gif needs at least one frame")

    if isinstance(fps, numbers.Real) and fps > 0:
        duration_ms = max(1, round(1000 / fps))
    else:
        duration_ms = DEFAULT_GIF_FRAME_MS

    gif_path = output_gif_path.replace('.mp4', '.gif')
    pil_frames[0].save(
        gif_path,
        format='GIF',
        append_images=pil_frames[1:],
        save_all=True,
        duration=duration_ms,
        loop=0,
    )
    return gif_path


def generate(
    video_folder: str,
    num_frames: int = 4,
    height: int = 320,
    width: int = 512,
) -> str:
    """Run the pipeline on a folder of frames and return the GIF path.

    Args:
        video_folder: Path to a folder of image frames (frame_0000.png, ...).
        num_frames: Number of frames requested from the pipeline.
        height: Output height passed to the pipeline.
        width: Output width passed to the pipeline.

    Returns:
        Path of the animated GIF written to a temporary file.

    Raises:
        gr.Error: If the folder does not exist or contains no image frames.
    """
    if not video_folder or not os.path.isdir(video_folder):
        raise gr.Error(f"Frame folder not found: {video_folder!r}")

    frames = load_images_from_folder(
        video_folder, mask_folder=None, is_condition=False
    )
    if not frames:
        raise gr.Error(f"No image frames found in: {video_folder!r}")

    # run the pipeline; slider values may arrive as floats, so cast them
    output = pipeline(
        frames,
        num_frames=int(num_frames),
        height=int(height),
        width=int(width),
    ).frames[0]

    # Reserve a temp file name; the file is kept so Gradio can serve it.
    handle, gif_path = tempfile.mkstemp(suffix=".gif")
    os.close(handle)
    return export_to_gif(output, gif_path, fps=7)


# 2. Build the Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Video-frame folder path"),
        gr.Slider(1, 16, value=4, step=1, label="Number of output frames"),
        gr.Slider(128, 1024, value=320, step=32, label="Height"),
        gr.Slider(128, 1024, value=512, step=32, label="Width"),
    ],
    # The result is an animated GIF, which an image component can display;
    # a video component expects a video file.
    outputs=gr.Image(label="Relit Video"),
    title="Stable Video Diffusion Demo",
    description="Upload a folder of frames and get back your relit video.",
)


if __name__ == "__main__":
    iface.launch()