Spaces:
Build error
Build error
Pie31415
committed on
Commit
·
4a38f47
1
Parent(s):
8160e04
updated app for video inference
Browse files
app.py
CHANGED
|
@@ -138,34 +138,35 @@ def image_inference(
|
|
| 138 |
out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
|
| 139 |
return res[..., ::-1]
|
| 140 |
|
| 141 |
-
def extract_frames(
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
image_frames = extract_frames(driver_vid)
|
| 158 |
|
| 159 |
resulted_imgs = defaultdict(list)
|
| 160 |
|
| 161 |
-
video_folder = 'jenya_driver/'
|
| 162 |
-
image_frames = sorted(glob(f"{video_folder}/*", recursive=True), key=lambda x: int(x.split('/')[-1][:-4]))
|
| 163 |
-
|
| 164 |
mask_hard_threshold = 0.5
|
| 165 |
-
N = len(image_frames)
|
| 166 |
-
for i in range(0, N, 4):
|
| 167 |
-
new_out = infer.evaluate(source_img,
|
| 168 |
-
source_information_for_reuse=out.get('source_information'))
|
| 169 |
|
| 170 |
mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
|
| 171 |
mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
|
|
@@ -192,34 +193,41 @@ def video_inference(source_img, driver_vid):
|
|
| 192 |
im.set_data(video[i,:,:,::-1])
|
| 193 |
return im
|
| 194 |
|
| 195 |
-
anim = animation.FuncAnimation(fig, animate, init_func=init,
|
| 196 |
-
|
| 197 |
|
| 198 |
-
return
|
| 199 |
|
| 200 |
with gr.Blocks() as demo:
|
| 201 |
gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
|
| 202 |
-
|
| 203 |
gr.Markdown(
|
| 204 |
"""
|
|
|
|
|
|
|
| 205 |
<p style='text-align: center'>
|
| 206 |
Create a personal avatar from just a single image using ROME.
|
| 207 |
<br> <a href='https://arxiv.org/abs/2206.08343' target='_blank'>Paper</a> | <a href='https://samsunglabs.github.io/rome' target='_blank'>Project Page</a> | <a href='https://github.com/SamsungLabs/rome' target='_blank'>Github</a>
|
| 208 |
</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
"""
|
| 210 |
)
|
| 211 |
|
| 212 |
with gr.Tab("Image Inference"):
|
| 213 |
with gr.Row():
|
| 214 |
-
source_img = gr.Image(type="pil", label="
|
| 215 |
-
driver_img = gr.Image(type="pil", label="
|
| 216 |
-
image_output = gr.Image()
|
| 217 |
image_button = gr.Button("Predict")
|
| 218 |
with gr.Tab("Video Inference"):
|
| 219 |
with gr.Row():
|
| 220 |
source_img2 = gr.Image(type="pil", label="source image", show_label=True)
|
| 221 |
driver_vid = gr.Video(label="driver video")
|
| 222 |
-
video_output = gr.Image()
|
| 223 |
video_button = gr.Button("Predict")
|
| 224 |
|
| 225 |
gr.Examples(
|
|
|
|
| 138 |
out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
|
| 139 |
return res[..., ::-1]
|
| 140 |
|
| 141 |
+
def extract_frames(
    driver_vid: gr.inputs.Video = None
):
    """Decode a video into a list of RGB images, one per frame.

    Args:
        driver_vid: Value handed directly to ``cv2.VideoCapture``; the
            caller is expected to supply a path to the driver video
            (e.g. an mp4 file).

    Returns:
        A list of images built with ``Image.fromarray``, in the order the
        frames were read, with channels converted from BGR to RGB. The list
        is empty when no frame could be read.
    """
    image_frames = []
    vid = cv2.VideoCapture(driver_vid)  # path to mp4
    try:
        while True:
            success, img = vid.read()
            if not success:
                # read() reports failure once no further frame is available.
                break
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image_frames.append(Image.fromarray(img))
    finally:
        # Always free the capture handle, even if a conversion step raises;
        # the original never released it.
        vid.release()
    return image_frames
|
| 157 |
+
|
| 158 |
+
def video_inference(
|
| 159 |
+
source_img: gr.inputs.Image = None,
|
| 160 |
+
driver_vid: gr.inputs.Video = None
|
| 161 |
+
):
|
| 162 |
image_frames = extract_frames(driver_vid)
|
| 163 |
|
| 164 |
resulted_imgs = defaultdict(list)
|
| 165 |
|
|
|
|
|
|
|
|
|
|
| 166 |
mask_hard_threshold = 0.5
|
| 167 |
+
N = len(image_frames)
|
| 168 |
+
for i in range(0, N, 4): # frame limits
|
| 169 |
+
new_out = infer.evaluate(source_img, image_frames[i])
|
|
|
|
| 170 |
|
| 171 |
mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
|
| 172 |
mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
|
|
|
|
| 193 |
im.set_data(video[i,:,:,::-1])
|
| 194 |
return im
|
| 195 |
|
| 196 |
+
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=30)
|
| 197 |
+
anim.save("avatar.gif", dpi=300, writer = animation.PillowWriter(fps=24))
|
| 198 |
|
| 199 |
+
return "avatar.gif"
|
| 200 |
|
| 201 |
with gr.Blocks() as demo:
|
| 202 |
gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
|
| 203 |
+
|
| 204 |
gr.Markdown(
|
| 205 |
"""
|
| 206 |
+
<img src='https://github.com/SamsungLabs/rome/blob/main/media/tease.gif'>
|
| 207 |
+
|
| 208 |
<p style='text-align: center'>
|
| 209 |
Create a personal avatar from just a single image using ROME.
|
| 210 |
<br> <a href='https://arxiv.org/abs/2206.08343' target='_blank'>Paper</a> | <a href='https://samsunglabs.github.io/rome' target='_blank'>Project Page</a> | <a href='https://github.com/SamsungLabs/rome' target='_blank'>Github</a>
|
| 211 |
</p>
|
| 212 |
+
|
| 213 |
+
<blockquote>
|
| 214 |
+
[The] system creates realistic mesh-based avatars from a single <strong>source</strong>
|
| 215 |
+
photo. These avatars are rigged, i.e., they can be driven by the animation parameters from a different <strong>driving</strong> frame.
|
| 216 |
+
</blockquote>
|
| 217 |
"""
|
| 218 |
)
|
| 219 |
|
| 220 |
with gr.Tab("Image Inference"):
|
| 221 |
with gr.Row():
|
| 222 |
+
source_img = gr.Image(type="pil", label="Source image", show_label=True)
|
| 223 |
+
driver_img = gr.Image(type="pil", label="Driver image", show_label=True)
|
| 224 |
+
image_output = gr.Image("Rendered avatar")
|
| 225 |
image_button = gr.Button("Predict")
|
| 226 |
with gr.Tab("Video Inference"):
|
| 227 |
with gr.Row():
|
| 228 |
source_img2 = gr.Image(type="pil", label="source image", show_label=True)
|
| 229 |
driver_vid = gr.Video(label="driver video")
|
| 230 |
+
video_output = gr.Image(label="Rendered GIF avatar")
|
| 231 |
video_button = gr.Button("Predict")
|
| 232 |
|
| 233 |
gr.Examples(
|