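# Gradio demo for "Feature-Conditioned Cascaded Video Diffusion Models for
# Precise Echocardiogram Synthesis" (EchoDiffusion): generates 4-chamber
# cardiac ultrasound videos conditioned on an anatomy frame and a target LVEF.
# Code: https://github.com/HReynaud/EchoDiffusion
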
import gradio as gr
import os
from omegaconf import OmegaConf
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, ElucidatedImagenConfig, NullUnet, Imagen
import torch
import numpy as np
import cv2
from PIL import Image
import torchvision.transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"
exp_path = "model"
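
# BetterCenterCrop ignores the size passed to the constructor and instead
# crops the largest centered square from the input, so frames of any aspect
# ratio survive; the T.Resize in ImageLoader then scales the square to 112x112.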
class BetterCenterCrop(T.CenterCrop):
    def __call__(self, img):
        h = img.shape[-2]
        w = img.shape[-1]
        dim = min(h, w)
        return T.functional.center_crop(img, dim)

class ImageLoader:
    def __init__(self, path) -> None:
        self.path = path
        self.all_files = os.listdir(path)
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])
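
    # Draw a random conditioning frame from the local pool (the demo ships
    # 100 frames from the EchoNet-Dynamic dataset in ./echo_images).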
    def get_image(self):
        idx = np.random.randint(0, len(self.all_files))
        img = Image.open(os.path.join(self.path, self.all_files[idx]))
        return img
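
# Context loads the experiment config, builds the cascaded Unet3D stack and the
# (Elucidated)Imagen model, and wraps everything in an ImagenTrainer for sampling.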
class Context:
    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")

        self.config = OmegaConf.load(self.config_path)
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i > 0)))  # type: ignore
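        # In the cascade, the first Unet3D is the base generator; every later
        # stage gets lowres_cond=True, i.e. it is a super-resolution unet
        # conditioned on the upsampled output of the previous stage.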

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated else Imagen
        del self.config.imagen.elucidated  # not a constructor argument, so drop it before unpacking
        imagen = imagen_klass(
            unets=unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen=imagen,
            **self.config.trainer,
        ).to(device)

        print("Loading weights from", self.weight_path)
        additional_data = self.trainer.load(self.weight_path)  # any extra state saved in the checkpoint (unused here)
        print("Loaded weights from", self.weight_path)

    def reshape_image(self, image):
        # Re-crop/resize whatever the user provides; return None on failure so
        # the gr.Image component is cleared instead of the app crashing.
        try:
            image = self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
            return image
        except Exception:
            return None

    def load_random_image(self):
        print("Loading random image")
        image = self.im_load.get_image()
        return image

    def generate_video(self, image, lvef, cond_scale):
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")
        image = self.im_load.transform(image).unsqueeze(0)
        sample_kwargs = {
            "text_embeds": torch.tensor([[[lvef / 100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }
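        # The target LVEF is rescaled to [0, 1] and passed through the
        # text-embedding pathway as a single scalar token; cond_images carries
        # the anatomy frame, and cond_scale acts as the classifier-free
        # guidance weight (values above 1 trade diversity for conditioning fidelity).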
        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                use_tqdm=True,
                **sample_kwargs,
            ).detach().cpu()  # B x C x F x H x W

        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)

        video = video.repeat((1, 1, 5, 1, 1))  # tile the frame axis so the clip loops 5 times - easier to see

        uid = np.random.randint(0, 10)  # random id to reduce (not eliminate) clashes between concurrent users
        path = f"tmp/{uid}.mp4"
        os.makedirs("tmp", exist_ok=True)  # cv2.VideoWriter fails silently if the directory is missing

        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()  # F x H x W x C, uint8
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
        for frame in video:
            out.write(frame)
        out.release()
        return path
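
# Build the context once at import time so the config and weights are loaded a
# single time, before the Gradio UI starts serving requests.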
context = Context(exp_path, device)

with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        gr.Label("Feature-Conditioned Cascaded Video Diffusion Models for Precise Echocardiogram Synthesis")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    text = gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction. Please start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset; it will act as the conditioning image, representing the anatomy of the video. Then set the target LVEF and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** [Code is available here](https://github.com/HReynaud/EchoDiffusion)")
                with gr.Column(scale=1, min_width=226):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width=226):
                    video = gr.Video(interactive=False)
            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditioning scale (values above 1 raise generation time to ~60s)", value=1, interactive=True)
            with gr.Row():
                img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
                run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")
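
    # Event wiring: any image change (upload or sampled frame) is re-cropped via
    # reshape_image; button ❶ loads a random frame and button ❷ runs the sampler.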
    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])

if __name__ == "__main__":
    demo.queue()  # queue requests so concurrent users are served one at a time on the single GPU
    demo.launch()
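
# A minimal programmatic sketch (hypothetical usage, bypassing the UI):
#
#   ctx = Context("model", device)
#   frame = ctx.load_random_image()
#   mp4_path = ctx.generate_video(frame, lvef=60, cond_scale=1)
#
# generate_video accepts any image ToTensor can handle (e.g. a PIL image), a
# target LVEF in percent, and a conditioning scale, and returns the path of
# the written mp4.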