|
|
from diffusers import DiffusionPipeline |
|
|
import gradio as gr |
|
|
import torch |
|
|
import cv2 |
|
|
import os |
|
|
|
|
|
# Hugging Face auth token for the gated Stable Diffusion v1.4 weights,
# read from the HF_TOKEN_SD environment variable (None if unset).
MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')

# Prefer GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Stable Diffusion v1.4 through the "interpolate_stable_diffusion"
# community pipeline, which adds a `walk()` method that interpolates between
# prompts in latent space.  The safety checker is explicitly disabled.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=MY_SECRET_TOKEN,
    safety_checker=None,
    custom_pipeline="interpolate_stable_diffusion",
).to(device)
# Compute attention in slices to lower peak VRAM usage at a small speed cost.
pipe.enable_attention_slicing()
|
|
|
|
|
def run(prompt1, seed1, prompt2, seed2, prompt3, seed3):
    """Interpolate between three prompts in latent space and render a video.

    Args:
        prompt1, prompt2, prompt3: Text prompts to interpolate between.
        seed1, seed2, seed3: RNG seeds, one per prompt.  Gradio sliders can
            deliver floats, so they are coerced to int before use.

    Returns:
        Tuple of ("out.mp4", list of generated frame file paths) — the video
        path and the gallery images for the Gradio outputs.

    Raises:
        RuntimeError: if the pipeline produced no readable frames.
    """
    frame_filepaths = pipe.walk(
        prompts=[prompt1, prompt2, prompt3],
        # torch generators need integer seeds; sliders may hand back floats.
        seeds=[int(seed1), int(seed2), int(seed3)],
        num_interpolation_steps=16,
        output_dir='./dreams',
        batch_size=4,
        height=512,
        width=512,
        guidance_scale=8.5,
        num_inference_steps=50,
    )
    print(frame_filepaths)

    _frames_to_video(frame_filepaths, "out.mp4")
    return "out.mp4", frame_filepaths


def _frames_to_video(frame_filepaths, out_path, fps=24):
    """Encode the images at *frame_filepaths* into an MP4 video at *out_path*.

    The video dimensions are taken from the first frame.  Unreadable frames
    are skipped rather than passed as None to the encoder.
    """
    if not frame_filepaths:
        raise RuntimeError("pipeline returned no frames to encode")

    first = cv2.imread(frame_filepaths[0])
    if first is None:
        raise RuntimeError(f"could not read first frame: {frame_filepaths[0]!r}")
    height, width, _channels = first.shape

    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    video = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    try:
        for image_path in frame_filepaths:
            frame = cv2.imread(image_path)
            # cv2.imread returns None on failure; writing None would corrupt
            # the stream, so skip such frames.
            if frame is not None:
                video.write(frame)
    finally:
        # Always finalize the container, even if a read/write above fails.
        video.release()
|
|
|
|
|
# Build the Gradio UI: three (prompt, seed) pairs in, a video plus an image
# gallery out.  Original indentation was lost in extraction; the layout below
# follows the standard Blocks pattern implied by the component order.
with gr.Blocks() as demo:
    with gr.Column():
        # Page header and usage notes.
        # Fixed a CSS typo from the original markup: "font-weigh" -> "font-weight".
        gr.HTML('''
            <h1 style='font-size: 2em;text-align:center;font-weight:900;'>
            Stable Diffusion Interpolation • Community pipeline
            </h1>
            <p style='text-align: center;'><br />
            This community pipeline returns a list of images saved under the folder as defined in output_dir. <br />
            You can use these images to create videos of stable diffusion.
            </p>

            <p style='text-align: center;'>
            This demo can be run on a GPU of at least 8GB VRAM and should take approximately 5 minutes.<br />
            —
            </p>

        ''')
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    # One row per prompt/seed pair; seeds randomize on page load.
                    with gr.Row():
                        intpol_prompt_1 = gr.Textbox(lines=1, label="prompt 1")
                        seed1 = gr.Slider(label="Seed 1", minimum=0, maximum=2147483647, step=1, randomize=True)
                    with gr.Row():
                        intpol_prompt_2 = gr.Textbox(lines=1, label="prompt 2")
                        seed2 = gr.Slider(label="Seed 2", minimum=0, maximum=2147483647, step=1, randomize=True)
                    with gr.Row():
                        intpol_prompt_3 = gr.Textbox(lines=1, label="prompt 3")
                        seed3 = gr.Slider(label="Seed 3", minimum=0, maximum=2147483647, step=1, randomize=True)
                intpol_run = gr.Button("Run Interpolation")

            with gr.Column():
                # Outputs: the rendered MP4 and the individual frames.
                video_output = gr.Video(label="Generated video", show_label=True)
                gallery_output = gr.Gallery(label="Generated images", show_label=False).style(grid=2, height="auto")

        # Wire the button to the generation function defined above.
        intpol_run.click(run, inputs=[intpol_prompt_1, seed1, intpol_prompt_2, seed2, intpol_prompt_3, seed3], outputs=[video_output, gallery_output])

demo.launch()