Spaces:
Runtime error
Runtime error
Commit
·
316f1d5
0
Parent(s):
Duplicate from anon-SGXT/echocardiogram-video-diffusion
Browse filesCo-authored-by: Anonymous <anon-SGXT@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +34 -0
- .gitignore +3 -0
- README.md +14 -0
- app.py +145 -0
- echo_images/0X10094BA0A028EAC3.png +0 -0
- echo_images/0X1013E8A4864781B.png +0 -0
- echo_images/0X12B890B1E2E14CC4.png +0 -0
- echo_images/0X13E043A35E3EB490.png +0 -0
- echo_images/0X159BDA520C61736A.png +0 -0
- echo_images/0X15DA8D60960ABB2B.png +0 -0
- echo_images/0X16AF26F9A372EEDE.png +0 -0
- echo_images/0X17BC4EF4BF83368B.png +0 -0
- echo_images/0X1B379931357428C0.png +0 -0
- echo_images/0X1CDD9C054D8FB60D.png +0 -0
- echo_images/0X1DF7163A74801695.png +0 -0
- echo_images/0X20C397F012441121.png +0 -0
- echo_images/0X22A1A8A656653343.png +0 -0
- echo_images/0X22D7FDCF2827269E.png +0 -0
- echo_images/0X230F00FD0DF5D71C.png +0 -0
- echo_images/0X244CAB3550320216.png +0 -0
- echo_images/0X24FEF7D294B35A5B.png +0 -0
- echo_images/0X25D970C75A57B3F2.png +0 -0
- echo_images/0X277FC348812C0E79.png +0 -0
- echo_images/0X27836E538BD008A.png +0 -0
- echo_images/0X2840438B29E95F1F.png +0 -0
- echo_images/0X29A336DCE20541A0.png +0 -0
- echo_images/0X29C81728B50A2E6C.png +0 -0
- echo_images/0X2A830BC4A3A36A93.png +0 -0
- echo_images/0X2AD994F98C491FA6.png +0 -0
- echo_images/0X2BB766EF1A13DECC.png +0 -0
- echo_images/0X2DA99F9FC1DAD8A9.png +0 -0
- echo_images/0X3545F8A008B34ED0.png +0 -0
- echo_images/0X36E4468C9E659B89.png +0 -0
- echo_images/0X39CA8CC96A5D5E8B.png +0 -0
- echo_images/0X3B01B7487E3D81EA.png +0 -0
- echo_images/0X3B0D2D527C387A0E.png +0 -0
- echo_images/0X3B54A5459841DCE8.png +0 -0
- echo_images/0X3B9FBD87EE113D62.png +0 -0
- echo_images/0X3BA9F7C9DB0CF55B.png +0 -0
- echo_images/0X3DA2B290B58A6540.png +0 -0
- echo_images/0X3E2F182038897EA5.png +0 -0
- echo_images/0X3F076329C702F768.png +0 -0
- echo_images/0X4130EB4CD7ED958B.png +0 -0
- echo_images/0X42E8226CA93B7BAC.png +0 -0
- echo_images/0X45418C574D97027A.png +0 -0
- echo_images/0X45CE057EC2EB577F.png +0 -0
- echo_images/0X463A7B7D46C6CA4.png +0 -0
- echo_images/0X463C296E8E65DA97.png +0 -0
- echo_images/0X46682D67FA3FE237.png +0 -0
- echo_images/0X487B52623BC14C25.png +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.mp4
|
| 2 |
+
*.ipynb
|
| 3 |
+
*__pycache__*
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: EchoNet Video Diffusion
|
| 3 |
+
emoji: 🖤
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.17.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
duplicated_from: anon-SGXT/echocardiogram-video-diffusion
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
from omegaconf import OmegaConf
|
| 4 |
+
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer, ElucidatedImagenConfig, NullUnet, Imagen
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Run inference on the GPU when torch can see one; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory holding the experiment artifacts (config.yaml and merged.pt weights).
exp_path = "model"
|
| 14 |
+
|
| 15 |
+
class BetterCenterCrop(T.CenterCrop):
    """Center crop to the largest possible square.

    Unlike ``T.CenterCrop``, the size passed to the constructor (e.g.
    ``(112, 112)``) is accepted only for API compatibility and never used:
    the crop side is always ``min(height, width)`` of the actual input, so
    no padding is ever introduced for small images.
    """

    def __call__(self, img):
        # Spatial dimensions are assumed to be the trailing two axes (..., H, W).
        side = min(img.shape[-2], img.shape[-1])
        return T.functional.center_crop(img, side)
|
| 22 |
+
|
| 23 |
+
class ImageLoader:
    """Serves random conditioning frames from a directory of still images."""

    def __init__(self, path) -> None:
        self.path = path
        # Snapshot the directory listing once at start-up; the pool is static.
        self.all_files = os.listdir(path)
        # PIL -> float tensor in [0, 1], square center crop, then resize to
        # the 112x112 resolution the diffusion model expects.
        self.transform = T.Compose([
            T.ToTensor(),
            BetterCenterCrop((112, 112)),
            T.Resize((112, 112)),
        ])

    def get_image(self):
        """Return one image from the pool, chosen uniformly at random."""
        chosen = self.all_files[np.random.randint(0, len(self.all_files))]
        return Image.open(os.path.join(self.path, chosen))
|
| 37 |
+
|
| 38 |
+
class Context:
    """Loads the trained video-diffusion model and exposes the operations
    driven by the Gradio UI: image preprocessing, random-frame loading and
    conditional video sampling/export.
    """

    def __init__(self, path, device):
        self.path = path
        self.config_path = os.path.join(path, "config.yaml")
        self.weight_path = os.path.join(path, "merged.pt")

        self.config = OmegaConf.load(self.config_path)

        # Number of frames to sample is derived from fps * duration.
        self.config.dataset.num_frames = int(self.config.dataset.fps * self.config.dataset.duration)

        self.im_load = ImageLoader("echo_images")

        # Build the U-Net cascade; every stage after the first is
        # conditioned on the previous (lower-resolution) stage's output.
        unets = []
        for i, (k, v) in enumerate(self.config.unets.items()):
            unets.append(Unet3D(**v, lowres_cond=(i>0)))  # type: ignore

        imagen_klass = ElucidatedImagen if self.config.imagen.elucidated == True else Imagen
        # `elucidated` only selects the class above; it is not a valid
        # constructor kwarg, so drop it before unpacking the config.
        del self.config.imagen.elucidated
        imagen = imagen_klass(
            unets = unets,
            **OmegaConf.to_container(self.config.imagen),  # type: ignore
        )

        self.trainer = ImagenTrainer(
            imagen = imagen,
            **self.config.trainer
        ).to(device)

        print("Loading weights from", self.weight_path)
        self.trainer.load(self.weight_path)
        print("Loaded weights from", self.weight_path)

    def reshape_image(self, image):
        """Normalize a user-supplied image to the model's input layout.

        Returns an H x W x C uint8 numpy array, or None when the input
        cannot be transformed (e.g. the image widget was cleared).
        """
        try:
            return self.im_load.transform(image).multiply(255).byte().permute(1, 2, 0).numpy()
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed; any transform failure still yields None.
            return None

    def load_random_image(self):
        """Pick a random conditioning frame from the image pool."""
        print("Loading random image")
        return self.im_load.get_image()

    def generate_video(self, image, lvef, cond_scale):
        """Sample an echocardiogram video conditioned on `image` and LVEF.

        Args:
            image: conditioning frame (as produced by `reshape_image`).
            lvef: target Left Ventricle Ejection Fraction, in percent.
            cond_scale: classifier-free guidance scale.

        Returns:
            Path of the mp4 written under the local ``tmp/`` directory.
        """
        print("Generating video")
        print(f"lvef: {lvef}, cond_scale: {cond_scale}")

        image = self.im_load.transform(image).unsqueeze(0)

        # LVEF is fed as a single-element "text" embedding scaled to [0, 1].
        sample_kwargs = {
            "text_embeds": torch.tensor([[[lvef/100.0]]]),
            "cond_scale": cond_scale,
            "cond_images": image,
        }

        self.trainer.eval()
        with torch.no_grad():
            video = self.trainer.sample(
                batch_size=1,
                video_frames=self.config.dataset.num_frames,
                **sample_kwargs,
                use_tqdm = True,
            ).detach().cpu()  # B x C x F x H x W (batch dim squeezed below)
        # Upsample to the canonical 64-frame 112x112 output if needed.
        if video.shape[-3:] != (64, 112, 112):
            video = torch.nn.functional.interpolate(video, size=(64, 112, 112), mode='trilinear', align_corners=False)
        video = video.repeat((1,1,5,1,1))  # make the video loop 5 times - easier to see

        # Unique-ish name so concurrent users do not clobber each other's file.
        # Widened from 10 possible ids to 1e6: collisions were near-certain.
        uid = np.random.randint(0, 1_000_000)
        # cv2.VideoWriter fails silently when the target directory is missing.
        os.makedirs("tmp", exist_ok=True)
        path = f"tmp/{uid}.mp4"
        video = video.multiply(255).byte().squeeze(0).permute(1, 2, 3, 0).numpy()
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), 32, (112, 112))
        for frame in video:
            out.write(frame)
        out.release()
        return path
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
# Application entry point: build the model context once, then wire up the UI.
# ---------------------------------------------------------------------------
context = Context(exp_path, device)

with gr.Blocks(css="style.css") as demo:

    # Title banner.
    with gr.Row():
        gr.Label("Cardiac Ultrasound Video Generation Demo (paper: 905)")

    # Main panel: description, conditioning image, generated video, sliders.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column(scale=3, variant="panel"):
                    gr.Markdown(value="This is a live demo of our work on cardiac ultrasound video generation. The model is trained on 4-chamber cardiac ultrasound videos and can generate realistic 4-chamber videos given a target Left Ventricle Ejection Fraction. Please, start by sampling a random frame from the pool of 100 images taken from the EchoNet-Dynamic dataset, which will act as the conditional image, representing the anatomy of the video. Then, set the target LVEF, and click the button to generate a video. The process takes 30s to 60s. The model running here corresponds to the 1SCM from the paper. **Click on the video to play it.** ")
                with gr.Column(scale=1, min_width="226"):
                    image = gr.Image(interactive=True)
                with gr.Column(scale=1, min_width="226"):
                    video = gr.Video(interactive=False)

            slider_ef = gr.Slider(minimum=10, maximum=90, step=1, label="Target LVEF", value=60, interactive=True)
            slider_cond = gr.Slider(minimum=0, maximum=20, step=1, label="Conditional scale (if set to more than 1, generation time is 60s)", value=1, interactive=True)

    # Action buttons.
    with gr.Row():
        img_btn = gr.Button(value="❶ Get a random cardiac ultrasound image (4Ch)")
        run_btn = gr.Button(value="❷ Generate a video (~30s) 🚀")

    # Event wiring: manual uploads are re-cropped in place; the two buttons
    # drive random-image loading and video generation respectively.
    image.change(context.reshape_image, inputs=[image], outputs=[image])
    img_btn.click(context.load_random_image, inputs=[], outputs=[image])
    run_btn.click(context.generate_video, inputs=[image, slider_ef, slider_cond], outputs=[video])

if __name__ == "__main__":
    demo.queue()
    demo.launch()
|
echo_images/0X10094BA0A028EAC3.png
ADDED
|
echo_images/0X1013E8A4864781B.png
ADDED
|
echo_images/0X12B890B1E2E14CC4.png
ADDED
|
echo_images/0X13E043A35E3EB490.png
ADDED
|
echo_images/0X159BDA520C61736A.png
ADDED
|
echo_images/0X15DA8D60960ABB2B.png
ADDED
|
echo_images/0X16AF26F9A372EEDE.png
ADDED
|
echo_images/0X17BC4EF4BF83368B.png
ADDED
|
echo_images/0X1B379931357428C0.png
ADDED
|
echo_images/0X1CDD9C054D8FB60D.png
ADDED
|
echo_images/0X1DF7163A74801695.png
ADDED
|
echo_images/0X20C397F012441121.png
ADDED
|
echo_images/0X22A1A8A656653343.png
ADDED
|
echo_images/0X22D7FDCF2827269E.png
ADDED
|
echo_images/0X230F00FD0DF5D71C.png
ADDED
|
echo_images/0X244CAB3550320216.png
ADDED
|
echo_images/0X24FEF7D294B35A5B.png
ADDED
|
echo_images/0X25D970C75A57B3F2.png
ADDED
|
echo_images/0X277FC348812C0E79.png
ADDED
|
echo_images/0X27836E538BD008A.png
ADDED
|
echo_images/0X2840438B29E95F1F.png
ADDED
|
echo_images/0X29A336DCE20541A0.png
ADDED
|
echo_images/0X29C81728B50A2E6C.png
ADDED
|
echo_images/0X2A830BC4A3A36A93.png
ADDED
|
echo_images/0X2AD994F98C491FA6.png
ADDED
|
echo_images/0X2BB766EF1A13DECC.png
ADDED
|
echo_images/0X2DA99F9FC1DAD8A9.png
ADDED
|
echo_images/0X3545F8A008B34ED0.png
ADDED
|
echo_images/0X36E4468C9E659B89.png
ADDED
|
echo_images/0X39CA8CC96A5D5E8B.png
ADDED
|
echo_images/0X3B01B7487E3D81EA.png
ADDED
|
echo_images/0X3B0D2D527C387A0E.png
ADDED
|
echo_images/0X3B54A5459841DCE8.png
ADDED
|
echo_images/0X3B9FBD87EE113D62.png
ADDED
|
echo_images/0X3BA9F7C9DB0CF55B.png
ADDED
|
echo_images/0X3DA2B290B58A6540.png
ADDED
|
echo_images/0X3E2F182038897EA5.png
ADDED
|
echo_images/0X3F076329C702F768.png
ADDED
|
echo_images/0X4130EB4CD7ED958B.png
ADDED
|
echo_images/0X42E8226CA93B7BAC.png
ADDED
|
echo_images/0X45418C574D97027A.png
ADDED
|
echo_images/0X45CE057EC2EB577F.png
ADDED
|
echo_images/0X463A7B7D46C6CA4.png
ADDED
|
echo_images/0X463C296E8E65DA97.png
ADDED
|
echo_images/0X46682D67FA3FE237.png
ADDED
|
echo_images/0X487B52623BC14C25.png
ADDED
|