# edl-relight / app.py
# Last change by jayhsu0627 — commit 65698c4: "change model path"
import os
import tempfile

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKLTemporalDecoder
from diffusers import StableVideoDiffusionPipeline
from transformers import CLIPVisionModelWithProjection

# NOTE: this local import intentionally comes after the diffusers one, so the
# project's StableVideoDiffusionPipeline shadows diffusers' class of the same name.
from utils.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from utils.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline

# NOTE(review): `Image` (PIL), `combine_masks`, `convert_colors` and
# `resize_and_pad_image` are used below but are not imported or defined in this
# file; presumably they should come from PIL / the project's `utils` package.
# 1. Load once at startup
# Fine-tuned relighting UNet, plus the stock SVD image encoder and temporal VAE.
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "quantum-whisper/edl-relight",
    subfolder="unet",
    low_cpu_mem_usage=True,
).to("cuda")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="image_encoder",
    revision=None,
)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="vae",
    revision=None,
    variant="fp16",
).to("cuda")
# NOTE(review): torch_dtype presumably only affects components diffusers loads
# itself, not the unet/image_encoder/vae instances passed in here — confirm the
# intended precision of those three models.
pipeline = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    unet=unet,
    image_encoder=image_encoder,
    vae=vae,
    revision=None,
    torch_dtype=torch.float16,
)
# image_encoder was never moved off the CPU above while unet/vae were moved to
# CUDA; move the whole pipeline so every component lives on the same device.
pipeline = pipeline.to("cuda")
def load_images_from_folder(folder, mask_folder, is_condition=False):
    """Load the image frames in ``folder`` in frame-number order.

    Args:
        folder: Directory holding the frames. Files named ``frame_<N>.<ext>``
            are ordered by the integer ``N``; other names come after all
            numbered frames, ordered by name. Files whose extension is not in
            ``valid_extensions`` (and sub-directories) are skipped.
        mask_folder: Optional mask directory. If it is an existing directory,
            ``combine_masks(mask_folder)`` is called once and its ``i``-th mask
            is applied to the ``i``-th loaded frame: ``pixel * mask`` plus
            ``(1 - mask) * 255``. ``None`` or a non-directory path disables
            masking.
        is_condition: If True, each frame is passed through ``convert_colors``.

    Returns:
        A list with one ``resize_and_pad_image(...)`` result per frame.
    """
    valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".tif", ".webp"}

    def frame_number(filename):
        """Return N for ``frame_N.<ext>``, or +inf for any other name."""
        parts = filename.split('_')
        if len(parts) > 1 and parts[0] == 'frame':
            try:
                return int(parts[1].split('.')[0])
            except ValueError:
                return float('inf')
        return float('inf')

    image_files = [
        name for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() in valid_extensions
    ]
    # Numeric frame order (frame_2 before frame_10); the name breaks ties so
    # un-numbered files still come out in a deterministic order.
    sorted_files = sorted(image_files, key=lambda name: (frame_number(name), name))

    # combine_masks is loop-invariant: compute it once rather than once per frame.
    masks = None
    if mask_folder is not None and os.path.isdir(mask_folder):
        masks = combine_masks(mask_folder)

    images = []
    for i, filename in enumerate(sorted_files):
        img = Image.open(os.path.join(folder, filename))
        if masks is not None:
            # NOTE(review): assumes mask i is (H, W) with values in {0, 1}
            # matching an RGB frame of the same size — masked-out pixels then
            # become white (255). Confirm against combine_masks.
            mask_3d = np.expand_dims(masks[i], axis=-1).repeat(3, axis=-1)
            image_array = np.array(img)
            masked = (image_array * mask_3d).astype(np.uint8)
            masked = masked + ((1 - mask_3d) * 255).astype(np.uint8)
            img = Image.fromarray(masked)
        if is_condition:
            img = convert_colors(img)
        images.append(resize_and_pad_image(img))
    return images
def export_to_gif(frames, output_gif_path, fps, duration_ms=None):
    """
    Export a sequence of frames to an animated, endlessly looping GIF.

    Args:
    - frames (sequence): Frames as numpy arrays or PIL Image objects.
    - output_gif_path (str): Path to save the output GIF. A trailing ``.mp4``
      extension is swapped for ``.gif``.
    - fps (float): Playback rate; each frame is shown for ``round(1000 / fps)`` ms.
    - duration_ms (int, optional): Explicit per-frame duration in milliseconds;
      takes precedence over ``fps`` when given.

    Returns:
    - str: The path the GIF was written to.

    Raises:
    - ValueError: If ``frames`` is empty, or if ``fps`` is not positive and no
      ``duration_ms`` was given.
    """
    # len() rather than truthiness so a numpy array of frames is accepted too.
    if len(frames) == 0:
        raise ValueError("export_to_gif: no frames to write")
    if duration_ms is None:
        if fps <= 0:
            raise ValueError(f"export_to_gif: fps must be positive, got {fps!r}")
        duration_ms = round(1000 / fps)

    # Convert numpy arrays to PIL Images if needed.
    pil_frames = [
        Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
        for frame in frames
    ]

    # Only rewrite the extension, not every '.mp4' substring in the path.
    if output_gif_path.endswith('.mp4'):
        output_gif_path = output_gif_path[:-len('.mp4')] + '.gif'

    pil_frames[0].save(
        output_gif_path,
        format='GIF',
        append_images=pil_frames[1:],
        save_all=True,
        duration=duration_ms,
        loop=0,
    )
    return output_gif_path
def generate(video_folder: str, num_frames: int = 4, height: int = 320, width: int = 512):
    """Relight the frames in ``video_folder`` and return the result's file path.

    video_folder: path to a folder of image frames (frame_0000.png, …) on the
        machine running this app.
    num_frames, height, width: forwarded to the pipeline; cast to ``int``
        defensively since they come from Gradio sliders.

    Returns the filesystem path of an animated GIF (7 fps) written to a new
    temporary file. Raises ``gr.Error`` (shown in the UI) when the folder does
    not exist or contains no frames.
    """
    if not video_folder or not os.path.isdir(video_folder):
        raise gr.Error(f"Not a directory: {video_folder!r}")
    frames = load_images_from_folder(video_folder, mask_folder=None, is_condition=False)
    if not frames:
        raise gr.Error(f"No image frames found in {video_folder!r}")

    # run the pipeline
    output = pipeline(
        frames,
        num_frames=int(num_frames),
        height=int(height),
        width=int(width),
    ).frames[0]

    # Write the frames to a temp GIF and hand Gradio the file path
    # (the previous code called an undefined `export_frames_to_gif`).
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name
    export_to_gif(output, gif_path, fps=7)
    return gif_path
# 2. Build the Gradio interface
# The folder is given as a path on the server, not uploaded from the browser.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Video-frame folder path"),
        gr.Slider(1, 16, value=4, step=1, label="Number of output frames"),
        gr.Slider(128, 1024, value=320, step=32, label="Height"),
        gr.Slider(128, 1024, value=512, step=32, label="Width"),
    ],
    # NOTE(review): generate() is written to export a GIF; confirm that the
    # installed Gradio version's gr.Video component accepts GIF files.
    outputs=gr.Video(label="Relit Video"),
    title="Stable Video Diffusion Demo",
    description="Enter the path of a folder of frames on the server and get back your relit video.",
)

if __name__ == "__main__":
    iface.launch()