liamsch commited on
Commit
887af40
·
1 Parent(s): 46a927f

Initial commit: SHeaP Gradio demo

Browse files
Files changed (47) hide show
  1. .gitattributes +3 -0
  2. FLAME2020/.gitattributes +3 -0
  3. FLAME2020/eyelids.pt +3 -0
  4. FLAME2020/flame_landmark_idxs_barys.pt +3 -0
  5. FLAME2020/generic_model.pkl +3 -0
  6. FLAME2020/generic_model.pt +3 -0
  7. LICENSE.txt +1 -0
  8. README.md +124 -14
  9. app.py +7 -6
  10. convert_flame.py +68 -0
  11. demo.py +99 -0
  12. example_images/00000200.jpg +3 -0
  13. example_images/00000201.jpg +3 -0
  14. example_images/00000202.jpg +3 -0
  15. example_images/00000203.jpg +3 -0
  16. example_images/00000204.jpg +3 -0
  17. example_images/00000205.jpg +3 -0
  18. example_images/00000206.jpg +3 -0
  19. example_images/00000207.jpg +3 -0
  20. example_images/00000208.jpg +3 -0
  21. example_images/00000209.jpg +3 -0
  22. example_videos/dafoe.mp4 +3 -0
  23. gradio_demo.py +284 -0
  24. models/model_expressive.pt +3 -0
  25. pyproject.toml +52 -0
  26. requirements.txt +14 -0
  27. requirements_hf.txt +15 -0
  28. sheap/__init__.py +21 -0
  29. sheap/__pycache__/__init__.cpython-311.pyc +0 -0
  30. sheap/__pycache__/eval_utils.cpython-311.pyc +0 -0
  31. sheap/__pycache__/fa_landmark_utils.cpython-311.pyc +0 -0
  32. sheap/__pycache__/landmark_utils.cpython-311.pyc +0 -0
  33. sheap/__pycache__/load_flame.cpython-311.pyc +0 -0
  34. sheap/__pycache__/load_flame_pkl.cpython-311.pyc +0 -0
  35. sheap/__pycache__/load_model.cpython-311.pyc +0 -0
  36. sheap/__pycache__/render.cpython-311.pyc +0 -0
  37. sheap/__pycache__/tiny_flame.cpython-311.pyc +0 -0
  38. sheap/eval_utils.py +270 -0
  39. sheap/fa_landmark_utils.py +96 -0
  40. sheap/landmark_utils.py +143 -0
  41. sheap/load_flame_pkl.py +35 -0
  42. sheap/load_model.py +85 -0
  43. sheap/py.typed +0 -0
  44. sheap/render.py +83 -0
  45. sheap/tiny_flame.py +168 -0
  46. teaser.jpg +3 -0
  47. video_demo.py +460 -0
.gitattributes CHANGED
@@ -33,6 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
36
  *.mp4 filter=lfs diff=lfs merge=lfs -text
37
  *.jpg filter=lfs diff=lfs merge=lfs -text
38
  *.jpeg filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ **/*.pt filter=lfs diff=lfs merge=lfs -text
37
+ **/*.pth filter=lfs diff=lfs merge=lfs -text
38
+ **/*.pkl filter=lfs diff=lfs merge=lfs -text
39
  *.mp4 filter=lfs diff=lfs merge=lfs -text
40
  *.jpg filter=lfs diff=lfs merge=lfs -text
41
  *.jpeg filter=lfs diff=lfs merge=lfs -text
FLAME2020/.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.pth filter=lfs diff=lfs merge=lfs -text
3
+ *.pkl filter=lfs diff=lfs merge=lfs -text
FLAME2020/eyelids.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5d5a2abbc71384b203451085337b0f9a581619bc839838a92b32a80d76ad9fa
3
+ size 121692
FLAME2020/flame_landmark_idxs_barys.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeda0176e330a0d69e8f2be29baf4bed62ecbed1ea04a0f5eb1ca0460023e398
3
+ size 2948
FLAME2020/generic_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efcd14cc4a69f3a3d9af8ded80146b5b6b50df3bd74cf69108213b144eba725b
3
+ size 53023716
FLAME2020/generic_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af8c7483c6135d26ccbee9a8a0ac64b39575d6afdf121f79b866ff4c7fbdcf19
3
+ size 26784481
LICENSE.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ SHeaP: Self-Supervised Head Geometry Predictor Learned via 2D Gaussians © 2025 by Liam Schoneveld is licensed under Creative Commons Attribution-NonCommercial 4.0 International. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc/4.0/
README.md CHANGED
@@ -1,14 +1,124 @@
1
- ---
2
- title: Sheap
3
- emoji: 📈
4
- colorFrom: blue
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 6.0.1
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-nc-4.0
11
- short_description: 'SHeaP: Self-Supervised Head Geometry Predictor'
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>🐑 SHeaP 🐑</h1>
3
+ <h2>Self-Supervised Head Geometry Predictor Learned via 2D Gaussians</h2>
4
+
5
+ <a href="https://nlml.github.io/sheap" target="_blank" rel="noopener noreferrer">
6
+ <img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page">
7
+ </a>
8
+ <a href="https://arxiv.org/abs/2504.12292"><img src="https://img.shields.io/badge/arXiv-2504.12292-b31b1b" alt="arXiv"></a>
9
+ <a href="https://www.youtube.com/watch?v=vhXsZJWCBMA"><img src="https://img.shields.io/badge/YouTube-Video-red" alt="YouTube"></a>
10
+
11
+ **Liam Schoneveld, Zhe Chen, Davide Davoli, Jiapeng Tang, Saimon Terazawa, Ko Nishino, Matthias Nießner**
12
+
13
+ <img src="teaser.jpg" alt="SHeaP Teaser" width="100%">
14
+
15
+ </div>
16
+
17
+ ## Overview
18
+
19
+ SHeaP learns to predict head geometry (FLAME parameters) from a single image, by predicting and rendering 2D Gaussians.
20
+
21
+ This repository contains code and models for the **FLAME parameter inference only**.
22
+
23
+ ## Example usage
24
+
25
+ **After setting up**, for a simple example, run `python demo.py`.
26
+
27
+ To run on a video you can use:
28
+
29
+ ```bash
30
+ python video_demo.py example_videos/dafoe.mp4
31
+ ```
32
+
33
+ The above command will produce the result in [example_videos/dafoe_rendered.mp4](https://github.com/nlml/SHeaP/blob/main/example_videos/dafoe_rendered.mp4).
34
+
35
+ Or, here is a minimal example script:
36
+
37
+ ```python
38
+ import torch, torchvision.io as io
39
+ from sheap import load_sheap_model
40
+ # Available model variants:
41
+ # sheap_model = load_sheap_model(model_type="paper")
42
+ sheap_model = load_sheap_model(model_type="expressive")
43
+ impath = "example_images/00000200.jpg"
44
+ # Input should be a head crop similar to those in example_images/
45
+ # shape (N,3,224,224) / pixel values from 0 to 1.
46
+ image_tensor = io.decode_image(impath).float() / 255
47
+ # flame_params_dict contains predicted FLAME parameters
48
+ flame_params_dict = sheap_model(image_tensor[None])
49
+ ```
50
+
51
+ **Note: `model_type`** can be one of two values:
52
+
53
+ - **`"paper"`**: used for paper results; gets best performance on NoW.
54
+ - **`"expressive"`**: perhaps better for real-world use; it was trained for longer with less regularisation and tends to be more expressive.
55
+
56
+ ## Setup
57
+
58
+ ### Step 1: Install dependencies
59
+
60
+ We just require `torch>=2.0.0` and a few other dependencies.
61
+
62
+ Just install the latest `torch` in a new venv, then `pip install .`
63
+
64
+ Or, if you use [`uv`](https://docs.astral.sh/uv/), you can just run `uv sync`.
65
+
66
+ ### Step 2: Download and convert FLAME
67
+
68
+ Only needed if you want to predict FLAME vertices or render a mesh.
69
+
70
+ Download [FLAME2020](https://flame.is.tue.mpg.de/).
71
+
72
+ Put it in the `FLAME2020/` dir. We only need `generic_model.pkl`. Your `FLAME2020/` directory should look like this:
73
+
74
+ ```bash
75
+ FLAME2020/
76
+ ├── eyelids.pt
77
+ ├── flame_landmark_idxs_barys.pt
78
+ └── generic_model.pkl
79
+ ```
80
+
81
+ Now convert FLAME to our format:
82
+
83
+ ```bash
84
+ python convert_flame.py
85
+ ```
86
+
87
+ ## Reproduce paper results on NoW dataset
88
+
89
+ To reproduce the validation results from the paper (median=0.93mm):
90
+
91
+ First, update submodules:
92
+
93
+ ```bash
94
+ git submodule update --init --recursive
95
+ ```
96
+
97
+ Then build the NoW Evaluation docker image:
98
+
99
+ ```bash
100
+ docker build -t noweval now/now_evaluation
101
+ ```
102
+
103
+ Then predict FLAME meshes for all images in NoW using SHeaP:
104
+
105
+ ```
106
+ cd now/
107
+ python now.py --now-dataset-root /path/to/NoW_Evaluation/dataset
108
+ ```
109
+
110
+ Upon finishing, the above command will print a command like the following:
111
+
112
+ ```
113
+ chmod 777 -R /home/user/sheap/now/now_eval_outputs/now_preds && docker run --ipc host --gpus all -it --rm -v /data/NoW_Evaluation/dataset:/dataset -v /home/user/sheap/now/now_eval_outputs/now_preds:/preds noweval
114
+ ```
115
+
116
+ Run that command. This will run NoW evaluation on the FLAME meshes we just predicted.
117
+
118
+ Finally, the results will be placed in `/home/user/sheap/now/now_eval_outputs/now_preds` (or equivalent). The mean and median are already calculated:
119
+
120
+ ```bash
121
+ ➜ cat /home/user/sheap/now/now_eval_outputs/now_preds/results/RECON_computed_distances.npy.meanmedian
122
+ 0.9327719333872148 # result in the paper
123
+ 1.1568168246248534
124
+ ```
app.py CHANGED
@@ -1,8 +1,9 @@
1
- import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
-
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
8
 
 
 
 
1
+ """
2
+ Hugging Face Space entry point for SHeaP demo.
3
+ This file imports and runs the gradio demo.
4
+ """
5
 
6
+ from gradio_demo import demo
 
 
 
 
7
 
8
if __name__ == "__main__":
    # Launch the Gradio UI only when this file is executed directly
    # (importers, e.g. the Spaces runtime, use the imported `demo` object).
    demo.launch()
convert_flame.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts FLAME pickle files to PyTorch .pt files.
3
+ """
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from sheap.load_flame_pkl import load_pkl_format_flame_model
12
+
13
+
14
def convert_flame(flame_base_dir: Union[str, Path], overwrite: bool) -> None:
    """Convert FLAME pickle files to PyTorch .pt files.

    Searches recursively for all .pkl files under the FLAME base directory and
    converts them to PyTorch .pt format, skipping known mask files. Existing
    outputs are detected *before* loading the pickle, so nothing expensive is
    done for files that would be skipped anyway.

    Args:
        flame_base_dir: Path to the FLAME model directory containing pickle files.
        overwrite: Whether to overwrite existing .pt files if they already exist.

    Raises:
        AssertionError: If flame_base_dir does not exist.
    """
    flame_base_dir = Path(flame_base_dir)
    assert flame_base_dir.exists(), (
        f"FLAME_BASE_DIR not found at {flame_base_dir}. "
        "Please set arg flame_base_dir to the FLAME model directory, "
        " or set the FLAME_BASE_DIR environment variable."
    )
    # "**/" already matches zero or more directory levels, so the original
    # "**/**/*.pkl" pattern was redundant; "**/*.pkl" covers the same files.
    pickle_files = list(flame_base_dir.glob("**/*.pkl"))
    skip_files = {"FLAME_masks.pkl"}
    for model_path in pickle_files:
        if model_path.name in skip_files:
            continue
        new_path = model_path.with_suffix(".pt")
        # Check for an existing output before the (expensive) pickle load.
        if new_path.exists() and not overwrite:
            print(f"Skipping {new_path} because it already exists.")
            continue
        print(f"Converting {model_path}...")
        data = load_pkl_format_flame_model(model_path)
        torch.save(data, new_path)
        print(f"Saved {new_path}")
47
+
48
+
49
def main() -> None:
    """CLI entry point: parse arguments and run the FLAME pickle conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--flame_base_dir",
        type=str,
        default="FLAME2020/",
        help=(
            "Path to the FLAME model directory. "
            "Defaults to the FLAME_BASE_DIR environment variable."
        ),
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing files if they already exist.",
    )
    args = parser.parse_args()
    # Forward the parsed namespace explicitly rather than via **vars(...).
    convert_flame(flame_base_dir=args.flame_base_dir, overwrite=args.overwrite)
65
+
66
+
67
if __name__ == "__main__":
    # Run the conversion CLI only when executed as a script.
    main()
demo.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+
8
+ from sheap import inference_images_list, load_sheap_model, render_mesh
9
+ from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
10
+
11
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
12
+
13
+
14
def create_rendering_image(
    original_image: Image.Image,
    verts: torch.Tensor,
    faces: torch.Tensor,
    c2w: torch.Tensor,
    output_size: int = 512,
) -> Image.Image:
    """
    Build a three-panel strip: original frame, rendered mesh, and blend.

    Args:
        original_image: PIL Image of the original frame
        verts: Vertices tensor for a single frame, shape (num_verts, 3)
        faces: Faces tensor, shape (num_faces, 3)
        c2w: Camera-to-world transformation matrix, shape (4, 4)
        output_size: Size of each sub-image in the combined output

    Returns:
        PIL Image with three views side-by-side (original, mesh, blended)
    """
    # Render the mesh to a color image and a depth map.
    color, depth = render_mesh(verts=verts, faces=faces, c2w=c2w)

    # Bring the source frame to the panel size.
    base = original_image.convert("RGB").resize((output_size, output_size))

    # Wherever the rendered depth is positive the mesh covers the pixel,
    # so the blend shows mesh there and the original everywhere else.
    alpha = (depth > 0).astype(np.float32)[..., None]
    overlay = (np.array(color) * alpha + np.array(base) * (1.0 - alpha)).astype(np.uint8)

    # Paste the three panels left-to-right into one strip.
    panels = [base, Image.fromarray(color), Image.fromarray(overlay)]
    strip = Image.new("RGB", (output_size * 3, output_size))
    for idx, panel in enumerate(panels):
        strip.paste(panel, (idx * output_size, 0))

    return strip
51
+
52
+
53
if __name__ == "__main__":
    # Load SHeaP model (the "expressive" variant; see README for variants).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sheap_model = load_sheap_model(model_type="expressive").to(device)

    # Inference on example images (sorted so output order is deterministic).
    folder_containing_images = Path("example_images/")
    image_paths = list(sorted(folder_containing_images.glob("*.jpg")))
    with torch.no_grad():
        predictions = inference_images_list(
            model=sheap_model,
            device=device,
            image_paths=image_paths,
        )

    # Load and infer FLAME with our predicted parameters.
    # generic_model.pt / eyelids.pt are produced by convert_flame.py.
    flame_dir = Path("FLAME2020/")
    flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt")
    verts = flame(
        shape=predictions["shape_from_facenet"],
        expression=predictions["expr"],
        pose=pose_components_to_rotmats(predictions),
        eyelids=predictions["eyelids"],
        translation=predictions["cam_trans"],
    )

    # Render the FLAME mesh for each input image.
    # Camera sits one unit along +z (identity rotation) — presumably facing
    # the head; confirm against render_mesh's camera conventions.
    c2w = torch.tensor(
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32
    )
    for i_frame in range(verts.shape[0]):
        # Output name keeps the full input name, e.g. "00000200.jpg_rendered.png".
        outpath = image_paths[i_frame].with_name(f"{image_paths[i_frame].name}_rendered.png")
        if outpath.exists():
            outpath.unlink()

        # Load original image
        original = Image.open(image_paths[i_frame])

        # Create combined rendering (original | mesh | blended)
        combined = create_rendering_image(
            original_image=original,
            verts=verts[i_frame],
            faces=flame.faces,
            c2w=c2w,
            output_size=512,
        )
        combined.save(outpath)
example_images/00000200.jpg ADDED

Git LFS Details

  • SHA256: 916d2843bb24bb71f7cfc586eda9b5834021d0e44bc6e4c4bc01f8ff91d8ac55
  • Pointer size: 130 Bytes
  • Size of remote file: 12.6 kB
example_images/00000201.jpg ADDED

Git LFS Details

  • SHA256: dde19f2c095bff0a262c35f8d807b840b1d2b11cd91ea15db84ad9f94b45eed1
  • Pointer size: 130 Bytes
  • Size of remote file: 12.5 kB
example_images/00000202.jpg ADDED

Git LFS Details

  • SHA256: 90bc8d8464675f0d83eeb67677297f86fec72590a7d6dffbda279e245ec6bbbb
  • Pointer size: 130 Bytes
  • Size of remote file: 12.6 kB
example_images/00000203.jpg ADDED

Git LFS Details

  • SHA256: a8880aad3eb9d541c8939b41812e1a680aa60b090030dbae92dec9d4b9046701
  • Pointer size: 130 Bytes
  • Size of remote file: 12.6 kB
example_images/00000204.jpg ADDED

Git LFS Details

  • SHA256: b16b78ff62da579f4216eb354f37957a6b86290f1c13e94b48100c45cd9e8747
  • Pointer size: 130 Bytes
  • Size of remote file: 12.5 kB
example_images/00000205.jpg ADDED

Git LFS Details

  • SHA256: b67f32bada48399a45769f80a677706d40f650af2e8de1f7faa0c504b58798e6
  • Pointer size: 130 Bytes
  • Size of remote file: 12.6 kB
example_images/00000206.jpg ADDED

Git LFS Details

  • SHA256: 317a2eb7578bcb77b703dcca5f650bcc0f1d7236b83000f92f8c7f9644e99290
  • Pointer size: 130 Bytes
  • Size of remote file: 12.6 kB
example_images/00000207.jpg ADDED

Git LFS Details

  • SHA256: 7abc69cb48790a643d5e1ff629bd0afc93be9b5b83217e4395dd42a0feeeb214
  • Pointer size: 130 Bytes
  • Size of remote file: 12.5 kB
example_images/00000208.jpg ADDED

Git LFS Details

  • SHA256: 332654d0d26fc6e3551bc5c4dc3b919e8673cee04a09d0d73de9d59a61a098f7
  • Pointer size: 130 Bytes
  • Size of remote file: 12.7 kB
example_images/00000209.jpg ADDED

Git LFS Details

  • SHA256: b70bdbbeeec905bef8b1a6053b2433aad8eb8a3ff203a0b62a2e6e72de097392
  • Pointer size: 130 Bytes
  • Size of remote file: 12.6 kB
example_videos/dafoe.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faab39e11cde3a3607039dc27b202472f31b39a757cb91a2fceedf67679b9e24
3
+ size 441906
gradio_demo.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio demo for SHeaP (Self-Supervised Head Geometry Predictor).
3
+ Accepts video or image input and renders the SHEAP output overlayed.
4
+ """
5
+
6
+ import os
7
+ import shutil
8
+ import subprocess
9
+ import tempfile
10
+ from pathlib import Path
11
+ from queue import Queue
12
+ from typing import Optional
13
+
14
+ import gradio as gr
15
+ import numpy as np
16
+ import torch
17
+ import torchvision.transforms.functional as TF
18
+ from PIL import Image
19
+ from torch.utils.data import DataLoader
20
+
21
+ from demo import create_rendering_image
22
+ from sheap import load_sheap_model
23
+ from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
24
+
25
+ try:
26
+ import face_alignment
27
+ except ImportError:
28
+ raise ImportError(
29
+ "The 'face_alignment' package is required. Please install it via 'pip install face-alignment'."
30
+ )
31
+ from sheap.fa_landmark_utils import detect_face_and_crop
32
+
33
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
34
+
35
+ # Global variables for models (load once)
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ sheap_model = None
38
+ flame = None
39
+ fa_model = None
40
+ c2w = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32)
41
+
42
+
43
def initialize_models():
    """Initialize all models (called once at startup).

    Populates the module-level globals ``sheap_model``, ``flame`` and
    ``fa_model`` so the Gradio handlers can reuse them across requests.
    """
    global sheap_model, flame, fa_model

    print("Loading SHeaP model...")
    sheap_model = load_sheap_model(model_type="expressive").to(device)
    sheap_model.eval()

    print("Loading FLAME model...")
    # Expects generic_model.pt to exist (produced by convert_flame.py).
    flame_dir = Path("FLAME2020/")
    flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt").to(
        device
    )

    print("Loading face alignment model...")
    # 2D landmark detector used for face detection / cropping.
    fa_model = face_alignment.FaceAlignment(
        face_alignment.LandmarksType.TWO_D, device=str(device), flip_input=False
    )

    print("Models loaded successfully!")
+ print("Models loaded successfully!")
63
+
64
+
65
def process_image(image: np.ndarray) -> Image.Image:
    """
    Process a single image and return the rendered output.

    Args:
        image: Input image as numpy array (RGB); assumes (H, W, 3) uint8 as
            delivered by Gradio's `type="numpy"` image component — TODO confirm.

    Returns:
        PIL Image with three views side-by-side (original, mesh, blended)
    """
    # Convert to torch tensor for face detection (C, H, W) format with values in [0, 1]
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

    # Detect face and get crop coordinates
    x0, y0, x1, y1 = detect_face_and_crop(image_tensor, fa_model, margin=0.9, shift_up=0.5)

    # Crop the image
    cropped_tensor = image_tensor[:, y0:y1, x0:x1]

    # Resize to 224x224 for SHEAP model
    cropped_resized = TF.resize(cropped_tensor, [224, 224], antialias=True)

    # Prepare input tensor for model (add batch dimension)
    img_tensor = cropped_resized.unsqueeze(0).to(device)

    # Also create a 512x512 version for rendering
    cropped_for_render = TF.resize(cropped_tensor, [512, 512], antialias=True)

    # Run inference
    with torch.no_grad():
        predictions = sheap_model(img_tensor)

    # Get FLAME vertices (predictions are already on device from model)
    verts = flame(
        shape=predictions["shape_from_facenet"],
        expression=predictions["expr"],
        pose=pose_components_to_rotmats(predictions),
        eyelids=predictions["eyelids"],
        translation=predictions["cam_trans"],
    )

    # Move vertices to CPU for rendering
    verts = verts.cpu()

    # Convert cropped_for_render back to PIL Image for rendering
    cropped_pil = TF.to_pil_image(cropped_for_render)

    # Create rendering (batch size is 1, hence verts[0])
    combined = create_rendering_image(
        original_image=cropped_pil,
        verts=verts[0],
        faces=flame.faces,
        c2w=c2w,
        output_size=512,
    )

    return combined
122
+
123
+
124
+ # --- Import video utilities from video_demo.py ---
125
+ from video_demo import RenderingThread, VideoFrameDataset, _tensor_to_numpy_image
126
+
127
+
128
def process_video(video_path: str, progress=gr.Progress()) -> str:
    """
    Process a video and return path to the rendered output video using background threads.

    Frames are decoded and cropped by VideoFrameDataset, FLAME parameters are
    predicted per frame, rendering is offloaded to worker thread(s) writing
    PNG frames to a temp dir, and ffmpeg assembles the final mp4.

    Returns:
        Path to the encoded mp4 inside the temp directory. NOTE(review): the
        temp dir is only removed on failure; on success it must survive so the
        returned file stays readable.
    """
    temp_dir = Path(tempfile.mkdtemp())
    render_size = 512
    try:
        # Prepare dataset and dataloader
        dataset = VideoFrameDataset(video_path, fa_model)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
        fps = dataset.fps
        num_frames = len(dataset)
        # Prepare rendering thread and queue (bounded to cap memory use)
        render_queue = Queue(maxsize=32)
        num_render_workers = 1
        rendering_threads = []
        for _ in range(num_render_workers):
            thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
            thread.start()
            rendering_threads.append(thread)
        progress(0, desc="Processing video frames...")
        frame_idx = 0
        with torch.no_grad():
            for batch in dataloader:
                images = batch["image"].to(device)
                cropped_frames = batch["cropped_frame"]
                # Run inference
                predictions = sheap_model(images)
                verts = flame(
                    shape=predictions["shape_from_facenet"],
                    expression=predictions["expr"],
                    pose=pose_components_to_rotmats(predictions),
                    eyelids=predictions["eyelids"],
                    translation=predictions["cam_trans"],
                )
                verts = verts.cpu()
                for i in range(images.shape[0]):
                    cropped_frame = _tensor_to_numpy_image(cropped_frames[i])
                    render_queue.put((frame_idx, cropped_frame, verts[i]))
                    frame_idx += 1
                    progress(
                        frame_idx / num_frames, desc=f"Processing frame {frame_idx}/{num_frames}"
                    )
        # Stop rendering threads: one None sentinel per worker, then join.
        # NOTE(review): if inference above raises, the sentinels are never
        # sent and the workers are left running — confirm RenderingThread is
        # a daemon thread or add a finally-based shutdown.
        for _ in range(num_render_workers):
            render_queue.put(None)
        for thread in rendering_threads:
            thread.join()
        if frame_idx == 0:
            raise ValueError("No frames were successfully processed!")
        # Create output video using ffmpeg from the PNG frames written by
        # the render workers (presumably named frame_%06d.png — confirm
        # against RenderingThread's output naming).
        progress(0.95, desc="Encoding video...")
        output_path = temp_dir / "output.mp4"
        ffmpeg_cmd = [
            "ffmpeg",
            "-y",
            "-framerate",
            str(fps),
            "-i",
            str(temp_dir / "frame_%06d.png"),
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-crf",
            "18",
            str(output_path),
        ]
        subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
        progress(1.0, desc="Done!")
        return str(output_path)
    except Exception as e:
        # Best-effort cleanup of partial output before re-raising.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise e
202
+
203
+
204
def process_input(image: Optional[np.ndarray], video: Optional[str]):
    """
    Dispatch to image or video processing depending on which input was given.

    Args:
        image: Input image (if provided)
        video: Input video path (if provided)

    Returns:
        A (image_result, video_result) pair; exactly one element is non-None.

    Raises:
        ValueError: If neither an image nor a video was provided.
    """
    # Guard clause first; image takes precedence when both are supplied.
    if image is None and video is None:
        raise ValueError("Please provide either an image or video!")
    if image is not None:
        return process_image(image), None
    return None, process_video(video)
221
+
222
+
223
+ # Initialize models on startup
224
+ initialize_models()
225
+
226
+ # Create Gradio interface
227
+ with gr.Blocks(title="SHeaP Demo") as demo:
228
+ gr.Markdown(
229
+ """
230
+ # 🐑 SHeaP: Self-Supervised Head Geometry Predictor 🐑
231
+
232
+ Upload an image or video to predict head geometry and render a 3D mesh overlay!
233
+
234
+ The output shows three views:
235
+ - **Left**: Original cropped face
236
+ - **Center**: Rendered FLAME mesh
237
+ - **Right**: Mesh overlaid on original
238
+
239
+ [Project Page](https://nlml.github.io/sheap) | [Paper](https://arxiv.org/abs/2504.12292) | [GitHub](https://github.com/nlml/sheap)
240
+ """
241
+ )
242
+
243
+ with gr.Row():
244
+ with gr.Column():
245
+ gr.Markdown("### Input")
246
+ image_input = gr.Image(label="Upload Image", type="numpy")
247
+ video_input = gr.Video(label="Upload Video")
248
+ process_btn = gr.Button("Process", variant="primary")
249
+
250
+ with gr.Column():
251
+ gr.Markdown("### Output")
252
+ image_output = gr.Image(label="Rendered Image", type="pil")
253
+ video_output = gr.Video(label="Rendered Video")
254
+
255
+ gr.Markdown(
256
+ """
257
+ ### Tips:
258
+ - For best results, use images/videos with clearly visible faces
259
+ - The model works best with frontal face views
260
+ - Video processing may take a few minutes depending on length
261
+ """
262
+ )
263
+
264
+ # Connect the button
265
+ process_btn.click(
266
+ fn=process_input,
267
+ inputs=[image_input, video_input],
268
+ outputs=[image_output, video_output],
269
+ )
270
+
271
+ # Add examples
272
+ gr.Examples(
273
+ examples=[
274
+ ["example_images/00000206.jpg", None],
275
+ [None, "example_videos/dafoe.mp4"],
276
+ ],
277
+ inputs=[image_input, video_input],
278
+ outputs=[image_output, video_output],
279
+ fn=process_input,
280
+ cache_examples=False,
281
+ )
282
+
283
+ if __name__ == "__main__":
284
+ demo.launch()
models/model_expressive.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d769f493072aa2e98770ed1b71db784bc3ee0a2132a0fd36aab841ee591c5e2
3
+ size 348292433
pyproject.toml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "sheap"
7
+ version = "0.1.0"
8
+ description = "SHeaP: Self-Supervised Head Geometry Predictor Learned via 2D Gaussians"
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ license = { file = "LICENSE.txt" }
12
+ authors = [
13
+ { name = "Liam Schoneveld" }
14
+ ]
15
+ keywords = ["3d", "face", "flame", "head", "mesh", "reconstruction"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Developers",
19
+ "Intended Audience :: Science/Research",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
24
+ ]
25
+ dependencies = [
26
+ "chumpy @ git+https://github.com/nlml/chumpy.git",
27
+ "numpy>=1.20.0",
28
+ "pillow>=9.0.0",
29
+ "pyrender>=0.1.45",
30
+ "roma>=1.5.4",
31
+ "scipy>=1.16.3",
32
+ "torch>=2.0.0",
33
+ "torchaudio>=2.0.0",
34
+ "torchvision>=0.15.1",
35
+ "tqdm>=4.67.1",
36
+ "trimesh>=4.9.0",
37
+ ]
38
+
39
+ [project.urls]
40
+ Homepage = "https://nlml.github.io/sheap"
41
+ Repository = "https://github.com/nlml/sheap"
42
+
43
+ [dependency-groups]
44
+ dev = [
45
+ "pre-commit>=4.3.0",
46
+ ]
47
+
48
+ [tool.hatch.metadata]
49
+ allow-direct-references = true
50
+
51
+ [tool.hatch.build.targets.wheel]
52
+ packages = ["sheap"]
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/nlml/chumpy.git
2
+ numpy>=1.20.0
3
+ pillow>=9.0.0
4
+ pyrender>=0.1.45
5
+ roma>=1.5.4
6
+ scipy>=1.16.3
7
+ torch>=2.0.0
8
+ torchaudio>=2.0.0
9
+ torchvision>=0.15.1
10
+ tqdm>=4.67.1
11
+ trimesh>=4.9.0
12
+ gradio>=4.0.0
13
+ face-alignment>=1.3.5
14
+ opencv-python>=4.5.0
requirements_hf.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requirements for Hugging Face Spaces deployment
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ numpy>=1.24.0
5
+ pillow>=9.5.0
6
+ opencv-python-headless>=4.8.0
7
+ gradio>=4.0.0
8
+ face-alignment>=1.4.1
9
+ pyrender>=0.1.45
10
+ trimesh>=4.0.0
11
+ scipy>=1.11.0
12
+ scikit-image>=0.21.0
13
+ networkx>=3.1
14
+ # For rendering
15
+ pyopengl>=3.1.0
sheap/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHeaP: Self-Supervised Head Geometry Predictor Learned via 2D Gaussians."""
2
+
3
+ from .eval_utils import ImsDataset, inference_images_list, save_result
4
+ from .landmark_utils import vertices_to_7_lmks, vertices_to_landmarks
5
+ from .load_flame_pkl import load_pkl_format_flame_model
6
+ from .load_model import load_sheap_model
7
+ from .render import render_mesh
8
+ from .tiny_flame import TinyFlame
9
+
10
+ __version__ = "0.1.0"
11
+ __all__ = [
12
+ "TinyFlame",
13
+ "load_pkl_format_flame_model",
14
+ "vertices_to_landmarks",
15
+ "vertices_to_7_lmks",
16
+ "inference_images_list",
17
+ "save_result",
18
+ "ImsDataset",
19
+ "render_mesh",
20
+ "load_sheap_model",
21
+ ]
sheap/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (858 Bytes). View file
 
sheap/__pycache__/eval_utils.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
sheap/__pycache__/fa_landmark_utils.cpython-311.pyc ADDED
Binary file (4.1 kB). View file
 
sheap/__pycache__/landmark_utils.cpython-311.pyc ADDED
Binary file (5.82 kB). View file
 
sheap/__pycache__/load_flame.cpython-311.pyc ADDED
Binary file (1.93 kB). View file
 
sheap/__pycache__/load_flame_pkl.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
sheap/__pycache__/load_model.cpython-311.pyc ADDED
Binary file (4.14 kB). View file
 
sheap/__pycache__/render.cpython-311.pyc ADDED
Binary file (4.34 kB). View file
 
sheap/__pycache__/tiny_flame.cpython-311.pyc ADDED
Binary file (7.83 kB). View file
 
sheap/eval_utils.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.utils.data as tud
7
+ import trimesh
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+
12
def _preproc_im_default(p: Union[str, Path]) -> Image.Image:
    """Load an image from disk with no extra preprocessing.

    Used as the fallback loader for ``ImsDataset`` when the caller does not
    supply a custom one.

    Args:
        p: Path to the image file.

    Returns:
        The opened PIL image (decoded lazily by PIL).
    """
    return Image.open(p)
22
+
23
+
24
class ImsDataset(tud.Dataset):
    """Map-style dataset that serves images as float tensors.

    Each item is loaded with a (possibly user-supplied) loader, converted to
    RGB, resized to ``img_wh``, and returned as a float32 tensor with values
    scaled to [0, 1].

    Args:
        image_paths: Paths of the images to serve.
        img_wh: Target (width, height) for resizing.
        load_and_preproc_im: Optional loader returning a PIL image; when
            ``None`` is passed, the default on-disk loader is used.
    """

    def __init__(
        self,
        image_paths: List[Union[str, Path]],
        img_wh: Tuple[int, int],
        load_and_preproc_im: Optional[
            Callable[[Union[str, Path]], Image.Image]
        ] = _preproc_im_default,
    ) -> None:
        self.image_paths = image_paths
        self.img_wh = img_wh
        # Fall back to the default loader when None is passed explicitly.
        self.load_and_preproc_im = (
            _preproc_im_default if load_and_preproc_im is None else load_and_preproc_im
        )

    def __len__(self) -> int:
        """Number of images served by this dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load, resize, and scale image ``idx``.

        Args:
            idx: Index of the image to load.

        Returns:
            Float tensor of shape (3, H, W) with values in [0, 1].
        """
        loaded = self.load_and_preproc_im(self.image_paths[idx])
        resized = loaded.convert("RGB").resize(self.img_wh)
        scaled = np.array(resized).astype("float64") / 255.0
        return torch.from_numpy(scaled).permute(2, 0, 1).float()
66
+
67
+
68
@torch.no_grad()
def inference_images_list(
    model: torch.nn.Module,
    device: torch.device,
    image_paths: List[Union[str, Path]],
    custom_pil_im_load_fn: Optional[Callable[[Union[str, Path]], Image.Image]] = None,
    img_wh: Tuple[int, int] = (224, 224),
    batch_size: int = 4,
    num_workers: int = 4,
    verbose: bool = False,
) -> Dict[str, torch.Tensor]:
    """Run inference on a list of images using a model.

    Args:
        model: PyTorch model to use for inference.
        device: Device to run inference on.
        image_paths: List of paths to image files.
        custom_pil_im_load_fn: Optional custom function to load and preprocess images.
        img_wh: Tuple of (width, height) to resize images to. Default is (224, 224).
        batch_size: Batch size for inference. Default is 4.
        num_workers: Number of workers for data loading. Default is 4.
        verbose: Whether to print output shapes. Default is False.

    Returns:
        Dictionary mapping output keys to tensors concatenated across all
        batches along the batch dimension. Non-tensor model outputs are skipped.
    """
    model = model.to(device)
    ds = ImsDataset(image_paths, img_wh=img_wh, load_and_preproc_im=custom_pil_im_load_fn)
    # Use the module's `tud` alias (was spelled out as torch.utils.data here).
    dl = tud.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
    )

    all_outs = {}
    for images in tqdm(dl, desc="Inferencing images through ViT model"):
        images = images.to(device)
        # NOTE: removed a dead `batch_size = images.shape[0]` assignment that
        # shadowed the function parameter and was never read.
        model_outs = model(images)
        for k in model_outs:
            # Only tensors can be concatenated; skip any other output values.
            if not isinstance(model_outs[k], torch.Tensor):
                continue
            if k not in all_outs:
                all_outs[k] = []
            # Move to CPU immediately so GPU memory is not held across batches.
            all_outs[k].append(model_outs[k].detach().cpu())

    if verbose:
        print("Concatenated output shapes:")
    for k in all_outs:
        all_outs[k] = torch.cat(all_outs[k], dim=0)
        if verbose:
            print(" --", k, all_outs[k].shape)
    return all_outs
124
+
125
+
126
def invert_4x4_cam_matrix(inp_cam: torch.Tensor) -> torch.Tensor:
    """Invert a rigid 4x4 camera transform analytically.

    For a rigid transform [R | t], the inverse is [R^T | -R^T t]; this avoids
    a general matrix inversion.

    Args:
        inp_cam: 4x4 camera transformation matrix.

    Returns:
        The inverted 4x4 camera transformation matrix.
    """
    rot_t = inp_cam[:3, :3].transpose(0, 1)
    out = torch.eye(4, device=inp_cam.device, dtype=inp_cam.dtype)
    out[:3, :3] = rot_t
    out[:3, 3] = rot_t @ -inp_cam[:3, 3]
    return out
141
+
142
+
143
def save_obj(outpath: Union[str, Path], verts: np.ndarray, faces: np.ndarray) -> None:
    """Write a triangle mesh to an OBJ file via trimesh.

    Args:
        outpath: Destination path for the OBJ file.
        verts: (N, 3) array of vertex positions.
        faces: (M, 3) array of vertex indices per triangle.
    """
    # process=False preserves the given vertex/face ordering exactly.
    trimesh.Trimesh(vertices=verts, faces=faces, process=False).export(outpath)
153
+
154
+
155
def save_result(
    flame_faces: np.ndarray,
    base_dir: Union[str, Path],
    verts_with_zero_exprn: np.ndarray,
    lmks7_3d: torch.Tensor,
    preds_outdir: Path,
    input_im_path: Union[str, Path],
    verbose: bool = False,
) -> None:
    """Persist a FLAME prediction (mesh plus 7 landmarks) to disk.

    An OBJ mesh and a sibling ``.npy`` landmark file are written under
    ``preds_outdir``, mirroring the input image's path relative to
    ``base_dir``. Vertices and landmarks are scaled by 1000 to match the
    MICA convention.

    Args:
        flame_faces: FLAME triangle indices.
        base_dir: Root directory used to compute the relative output path.
        verts_with_zero_exprn: Predicted vertices with expression zeroed out.
        lmks7_3d: Predicted 7-point 3D landmarks tensor.
        preds_outdir: Root directory for the saved predictions.
        input_im_path: Path of the source image.
        verbose: Print a confirmation line per saved file. Default is False.
    """
    # MICA scaled up by 1000, so let's try it too:
    scaled_verts = verts_with_zero_exprn * 1000.0
    scaled_lmks = lmks7_3d.numpy() * 1000.0

    rel_path = Path(input_im_path).relative_to(base_dir)
    outpath_obj = (preds_outdir / rel_path).with_suffix(".obj")
    outpath_obj.parent.mkdir(parents=True, exist_ok=True)

    save_obj(outpath_obj, verts=scaled_verts, faces=flame_faces)
    if verbose:
        print(f"Saved {outpath_obj}")

    outpath_lmk_npy = outpath_obj.with_suffix(".npy")
    np.save(outpath_lmk_npy, scaled_lmks)
    if verbose:
        print(f"Saved {outpath_lmk_npy}")

    assert outpath_obj.exists()
    assert outpath_lmk_npy.exists()
198
+
199
+
200
def add_pct_to_bbox(
    top: int,
    left: int,
    bottom: int,
    right: int,
    im_np_array: Union[np.ndarray, Image.Image],
    pct: float = 0.2,
) -> Tuple[int, int, int, int]:
    """Grow a (top, left, bottom, right) box by ``pct`` within image bounds.

    Half of the extra size extends the top/left side (clipped at 0) and the
    remainder extends the bottom/right side, clipped at the image edges.

    Args:
        top: Top edge of the box, in pixels.
        left: Left edge of the box, in pixels.
        bottom: Bottom edge of the box, in pixels.
        right: Right edge of the box, in pixels.
        im_np_array: Image (numpy array or PIL image) supplying the bounds.
        pct: Fractional growth of the box. Default is 0.2 (20%).

    Returns:
        The expanded (top, left, bottom, right) coordinates.
    """
    if isinstance(im_np_array, Image.Image):
        im_np_array = np.array(im_np_array)
    im_h, im_w, _ = im_np_array.shape

    height = bottom - top
    top = max(0, top - int(height * pct * 0.5))
    # Bottom is re-derived from the (possibly clipped) top, then clamped.
    bottom = min(im_h, top + int(height * (1 + pct)))

    width = right - left
    left = max(0, left - int(width * pct * 0.5))
    right = min(im_w, left + int(width * (1 + pct)))

    return top, left, bottom, right
236
+
237
+
238
def resize_to_max_size(
    im: Union[np.ndarray, Image.Image], max_size: int = 512, pad_smaller: bool = True
) -> Union[np.ndarray, Image.Image]:
    """Scale an image so its longest side equals ``max_size``.

    Aspect ratio is preserved; with ``pad_smaller`` the result is centred on a
    black ``max_size`` x ``max_size`` RGB canvas. The output format (numpy
    array vs PIL image) matches the input format.

    Args:
        im: Image to resize.
        max_size: Target length of the longest dimension. Default is 512.
        pad_smaller: Pad the shorter side to produce a square output.
            Default is True.

    Returns:
        The resized (and optionally padded) image.
    """
    input_was_array = isinstance(im, np.ndarray)
    pil_im = Image.fromarray(im) if input_was_array else im

    w, h = pil_im.size
    # Scale so the longer edge becomes exactly max_size.
    if h > w:
        new_h, new_w = max_size, int(w * (max_size / h))
    else:
        new_w, new_h = max_size, int(h * (max_size / w))
    pil_im = pil_im.resize((new_w, new_h))

    if pad_smaller:
        canvas = Image.new("RGB", (max_size, max_size))
        canvas.paste(pil_im, ((max_size - new_w) // 2, (max_size - new_h) // 2))
        pil_im = canvas

    return np.array(pil_im) if input_was_array else pil_im
sheap/fa_landmark_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import face_alignment
4
+ import numpy as np
5
+ import torch
6
+ from numpy.typing import NDArray
7
+
8
+ from sheap.landmark_utils import landmarks_2_face_bounding_box
9
+
10
+
11
def get_fa_landmarks(
    np_array_im_255_uint8: NDArray[np.uint8],
    fa: face_alignment.FaceAlignment,
    normalize: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Detect 68 facial landmarks in an image with the face_alignment library.

    Args:
        np_array_im_255_uint8: (H, W, 3) uint8 image with values in [0, 255].
        fa: FaceAlignment detector instance.
        normalize: Divide coordinates by the image size so they lie in [0, 1].

    Returns:
        Tuple of (landmarks, success):
        - landmarks: (68, 2) float tensor; all zeros when no face was found.
        - success: scalar bool tensor, True iff a face was detected.
    """
    detections = fa.get_landmarks(np_array_im_255_uint8)
    if detections is None:
        # No face: return a zero placeholder so callers get a fixed shape.
        coords = np.zeros((68, 2))
        found = False
    else:
        # Only the first detected face is used.
        coords = detections[0][:, :2]
        if normalize:
            h, w = np_array_im_255_uint8.shape[:2]
            coords = coords / np.array([w, h])
        found = True

    return torch.from_numpy(coords).float(), torch.tensor(found).bool()
44
+
45
+
46
def detect_face_and_crop(
    image: torch.Tensor,
    fa_model: face_alignment.FaceAlignment,
    margin: float = 0.6,
    shift_up: float = 0.2,
) -> Tuple[int, int, int, int]:
    """
    Detect face and compute bounding box coordinates.

    Args:
        image: torch.Tensor of shape (3, H, W) with values in [0, 1]
        fa_model: FaceAlignment model instance for landmark detection
        margin: Fractional margin added around the detected landmarks.
        shift_up: Fraction of the box half-size by which to shift the box up.

    Returns:
        tuple: (x0, y0, x1, y1) bounding box coordinates in pixels
    """
    _, h, w = image.shape

    # Convert image to numpy format for face_alignment (H, W, 3) with values [0, 255]
    image_np = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    # Get facial landmarks
    lmks, success = get_fa_landmarks(image_np, fa_model, normalize=True)

    if not success:
        # If face detection fails, return a center square from the image
        if h > w:
            y0 = (h - w) // 2
            y1 = y0 + w
            x0 = 0
            x1 = w
        else:
            x0 = (w - h) // 2
            x1 = x0 + h
            y0 = 0
            y1 = h
        # BUGFIX: this branch previously returned (x0, x1, y0, y1), which was
        # inconsistent with the (x0, y0, x1, y1) order of the detection path below.
        return x0, y0, x1, y1

    # Add batch dimension for landmarks_2_face_bounding_box
    lmks_batched = lmks.unsqueeze(0)  # Shape: (1, 68, 2)
    valid = torch.ones(1, dtype=torch.bool)

    # Compute bounding box in normalized coordinates
    bbox = landmarks_2_face_bounding_box(
        lmks_batched, valid, margin=margin, clamp=True, shift_up=shift_up, aspect_ratio=w / h
    )

    # Scale normalized [x_min, y_min, x_max, y_max] back to pixel coordinates.
    x0, y0, x1, y1 = bbox[0].tolist()
    x0, y0, x1, y1 = int(x0 * w), int(y0 * h), int(x1 * w), int(y1 * h)

    return x0, y0, x1, y1
sheap/landmark_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
def vertices_to_landmarks(
    vertices: Tensor,  # shape: (*batch, num_vertices, 3)
    faces: Tensor,  # shape: (num_faces, 3), indices of vertices
    face_indices_with_landmarks: Tensor,  # shape: (num_landmarks,), indices of faces
    barys: Tensor,  # shape: (num_landmarks, 3), barycentric coordinates
) -> Tensor:
    """
    Interpolate landmark positions on a mesh from barycentric coordinates.

    Args:
        vertices (Tensor): Mesh vertices of shape (*batch, num_vertices, 3).
        faces (Tensor): Mesh faces of shape (num_faces, 3), indexing into `vertices`.
        face_indices_with_landmarks (Tensor): Index of the face carrying each
            landmark, shape (num_landmarks,).
        barys (Tensor): Barycentric coordinates of each landmark within its
            face, shape (num_landmarks, 3); each row should sum to 1.0.

    Returns:
        Tensor: Landmark positions of shape (*batch, num_landmarks, 3).
    """
    # Support inputs without a batch dimension by adding one temporarily.
    squeeze_out = vertices.ndim == 2
    if squeeze_out:
        vertices = vertices[None]

    # (num_landmarks, 3) vertex indices of each landmark-carrying face.
    lmk_faces = faces[face_indices_with_landmarks]

    # Gather the three corner vertices of each such face:
    # (*batch, num_landmarks, 3 corners, 3 coords)
    corner_verts = vertices[..., lmk_faces, :]

    # Barycentric interpolation: weight each corner and sum over corners.
    lmk_positions = (corner_verts * barys[..., None]).sum(dim=-2)

    return lmk_positions[0] if squeeze_out else lmk_positions
48
+
49
+
50
def vertices_to_7_lmks(
    vertices: Tensor,
    flame_faces: Tensor,
    face_alignment_lmk_faces_idx: Tensor,
    face_alignment_lmk_bary_coords: Tensor,
) -> Tuple[Tensor, Tensor]:
    """
    Extract all facial landmarks plus a fixed 7-point subset from mesh vertices.

    Args:
        vertices (Tensor): Mesh vertices of shape (batch, num_vertices, 3).
        flame_faces (Tensor): Mesh faces of shape (num_faces, 3).
        face_alignment_lmk_faces_idx (Tensor): Face index carrying each landmark.
        face_alignment_lmk_bary_coords (Tensor): Barycentric coords of each landmark.

    Returns:
        Tuple[Tensor, Tensor]:
            - lmks7_3d: the 7-point subset, shape (batch, 7, 3).
            - lmks_3d: all landmarks, shape (batch, num_landmarks, 3).

    Note:
        The slicing below assumes a leading batch dimension is present —
        TODO confirm callers never pass unbatched vertices.
    """
    lmks_3d = vertices_to_landmarks(
        vertices,
        flame_faces,
        face_alignment_lmk_faces_idx,
        face_alignment_lmk_bary_coords,
    )

    # Drop the first 17 landmarks, leaving the 51 inner-face points, then pick
    # the fixed 7-point subset (absolute indices 36, 39, 42, 45, 33, 48, 54).
    inner_51 = lmks_3d[:, 17:]
    lmks7_3d = inner_51[:, [19, 22, 25, 28, 16, 31, 37]]

    return lmks7_3d, lmks_3d
84
+
85
+
86
def landmarks_2_face_bounding_box(
    landmarks: Tensor,
    valid: Tensor,
    margin: float = 0.1,
    clamp: bool = True,
    shift_up: float = 0.0,
    too_small_threshold: float = 0.02,
    aspect_ratio: float = 1.0,
) -> Tensor:
    """
    Build a square bounding box around face landmarks, batched.

    Parameters:
    - landmarks: torch.Tensor of shape [B1,...,BN,L,2], normalized face landmarks.
    - valid: torch.Tensor of shape [B1,...,BN], boolean validity per entry.
    - margin: float, fractional margin to grow the box around the landmarks.
    - clamp: bool, clamp the resulting box to [0, 1].
    - shift_up: float, fraction of the half-size by which to move the box up.
    - too_small_threshold: float, below this half-size an entry is treated as invalid.
    - aspect_ratio: float, width / height of the image the landmarks live on.
      The horizontal half-size is divided by this value, assuming the caller
      will later multiply these normalised coordinates by the image width.

    Returns:
    - bbox: torch.Tensor of shape [B1,...,BN,4] as [x_min, y_min, x_max, y_max];
      invalid entries get the full-image box [0, 0, 1, 1].
    """
    # Axis-aligned extent of the landmarks.
    lo = landmarks.min(dim=-2).values
    hi = landmarks.max(dim=-2).values

    center = 0.5 * (lo + hi)
    half = (hi - lo).max(dim=-1).values / 2

    # Degenerate (too small) landmark clouds are treated as invalid.
    valid = valid & (half > too_small_threshold)

    # Upward shift is proportional to the pre-margin half-size.
    up = shift_up * half
    half = half * (1 + margin)

    # Square box, with the x extent corrected for the image aspect ratio.
    x_half = half / aspect_ratio
    bbox = torch.stack(
        [
            center[..., 0] - x_half,
            center[..., 1] - half - up,
            center[..., 0] + x_half,
            center[..., 1] + half - up,
        ],
        dim=-1,
    )

    # Replace invalid entries with the full-image box.
    full_image_bbox = torch.tensor([0.0, 0.0, 1.0, 1.0], device=landmarks.device)
    bbox = torch.where(valid.unsqueeze(-1), bbox, full_image_bbox)

    return bbox.clamp(0, 1) if clamp else bbox
sheap/load_flame_pkl.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from pathlib import Path
4
+ from typing import Dict, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import Tensor
9
+
10
+
11
def load_pkl_format_flame_model(path: Union[str, os.PathLike, Path]) -> Dict[str, Tensor]:
    """Read FLAME model parameters from the original pickle distribution.

    Faces and the kinematic tree are converted to int64 tensors, the sparse
    joint regressor is densified, and the remaining blendshape/template
    arrays are cast to float32.

    Args:
        path: Location of the FLAME model pickle file.

    Returns:
        Dict with keys: faces, kintree, J_regressor, shapedirs, J, weights,
        posedirs, v_template.

    Note:
        ``pickle.load`` can execute arbitrary code — only load trusted files.
    """
    with open(path, "rb") as fh:
        raw = pickle.load(fh, encoding="latin1")

    params: Dict[str, Tensor] = {}
    params["faces"] = torch.from_numpy(raw["f"].astype("int64"))

    # Entries > 100 (presumably unsigned "no parent" sentinels) map to -1.
    kintree = torch.from_numpy(raw["kintree_table"].astype("int64"))
    kintree[kintree > 100] = -1
    params["kintree"] = kintree

    params["J_regressor"] = torch.from_numpy(raw["J_regressor"].toarray().astype("float32"))

    for key in ("shapedirs", "J", "weights", "posedirs", "v_template"):
        params[key] = torch.from_numpy(np.array(raw[key]).astype("float32"))
    return params
sheap/load_model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib.request
2
+ from pathlib import Path
3
+ from typing import Dict, Literal
4
+
5
+ import torch
6
+
7
# Map model types to filenames and (optional) download URLs
MODEL_INFO: Dict[str, Dict[str, str]] = {
    "paper": {
        "filename": "model_paper.pt",
        "url": "https://github.com/nlml/sheap/releases/download/v1.0.0/model_paper.pt",
    },
    "expressive": {
        "filename": "model_expressive.pt",
        "url": "https://github.com/nlml/sheap/releases/download/v1.0.0/model_expressive.pt",
    },
}


def ensure_model_downloaded(
    model_type: Literal["paper", "expressive"] = "paper", models_dir: Path = Path("./models")
) -> None:
    """Make sure the requested checkpoint exists locally, fetching it if missing.

    Args:
        model_type: Which model variant to use. Valid options are "paper" or
            "expressive". Default is "paper".
        models_dir: Directory where models are stored. Default is "./models".

    Raises:
        ValueError: If model_type is not recognized.
        FileNotFoundError: If the model file is missing and no download URL
            is configured.
    """
    if model_type not in MODEL_INFO:
        valid = ", ".join(MODEL_INFO.keys())
        raise ValueError(f"Unknown model_type '{model_type}'. Valid options: {valid}")

    info = MODEL_INFO[model_type]
    model_path = Path(models_dir) / info["filename"]

    # Checkpoint already on disk — nothing to do.
    if model_path.exists():
        return

    url = info["url"]
    if not url:
        raise FileNotFoundError(
            f"Model file '{model_path}' not found and no download URL is configured for "
            f"model_type='{model_type}'. Place the file manually or update MODEL_INFO with a valid URL."
        )

    print(f"Downloading '{model_type}' model to {model_path}...")
    model_path.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, model_path)
56
+
57
+
58
def load_sheap_model(
    model_type: Literal["paper", "expressive"] = "paper", models_dir: Path = Path("./models")
) -> torch.jit.ScriptModule:
    """Load a SHeaP checkpoint as a TorchScript module, downloading if needed.

    When a URL is configured for the selected variant and the file is missing
    locally, it is fetched first via ``ensure_model_downloaded``.

    Args:
        model_type: Which model variant to load. Valid options are "paper" or
            "expressive". Default is "paper" for backward compatibility.
        models_dir: Directory where models are stored. Default is "./models".

    Returns:
        The loaded SHeaP model as a PyTorch JIT ScriptModule.

    Raises:
        ValueError: If model_type is not recognized.
    """
    if model_type not in MODEL_INFO:
        valid = ", ".join(MODEL_INFO.keys())
        raise ValueError(f"Unknown model_type '{model_type}'. Valid options: {valid}")

    models_dir = Path(models_dir)
    ensure_model_downloaded(model_type=model_type, models_dir=models_dir)
    return torch.jit.load(models_dir / MODEL_INFO[model_type]["filename"])
sheap/py.typed ADDED
File without changes
sheap/render.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import numpy as np
4
+ import pyrender
5
+ import torch
6
+ import trimesh
7
+
8
+
9
def render_mesh(
    verts: Union[np.ndarray, torch.Tensor],
    faces: Union[np.ndarray, torch.Tensor],
    c2w: Union[np.ndarray, torch.Tensor],
    img_width: int = 512,
    img_height: int = 512,
    fov_degrees: Union[float, int] = 14.2539,
    render_normals: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Render a mesh using pyrender with a perspective camera defined by FOV.

    Args:
        verts: Mesh vertex positions of shape (N, 3).
        faces: Triangle vertex indices of shape (F, 3).
        c2w: Camera-to-world transform matrix (extrinsics) of shape (4, 4).
        img_width: Rendered image width in pixels. Default is 512.
        img_height: Rendered image height in pixels. Default is 512.
        fov_degrees: Vertical field of view in degrees. Default is 14.2539.
        render_normals: If True, render normals as RGB. If False, render with lighting. Default is True.

    Returns:
        Tuple containing:
        - color: RGB image from the render of shape (H, W, 3) as uint8.
        - depth: Depth map from the render of shape (H, W) as float32.
    """
    if isinstance(c2w, torch.Tensor):
        c2w = c2w.detach().cpu().numpy()
    if isinstance(verts, torch.Tensor):
        verts = verts.detach().cpu().numpy()
    if isinstance(faces, torch.Tensor):
        faces = faces.detach().cpu().numpy()
    if not isinstance(fov_degrees, (float, int)):
        fov_degrees = float(fov_degrees)

    # Convert degrees to radians
    yfov = np.deg2rad(fov_degrees)

    # Create trimesh mesh
    mesh = trimesh.Trimesh(vertices=verts, faces=faces)

    if render_normals:
        # Get vertex normals and map to RGB colors
        # Trimesh automatically computes normals when accessed
        normals = mesh.vertex_normals
        # Transform normals to camera space
        w2c = np.linalg.inv(c2w)
        normals_camera = normals @ w2c[:3, :3].T
        # Map from [-1, 1] to [0, 255] for RGB
        vertex_colors = ((normals_camera + 1.0) * 0.5 * 255).astype(np.uint8)
        mesh.visual.vertex_colors = vertex_colors

    # Convert to pyrender mesh
    # NOTE: renamed from `render_mesh`, which shadowed this function's own name.
    pr_mesh = pyrender.Mesh.from_trimesh(mesh)

    # Create scene
    if render_normals:
        scene = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0])
    else:
        scene = pyrender.Scene(ambient_light=[0.3, 0.3, 0.3])
        # Add directional light
        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
        scene.add(light, pose=c2w)
    scene.add(pr_mesh)

    # Perspective camera
    camera = pyrender.PerspectiveCamera(yfov=yfov, aspectRatio=img_width / img_height)

    # pyrender expects camera-to-world
    scene.add(camera, pose=c2w)

    # Offscreen render. BUGFIX: release the offscreen GL context afterwards —
    # it was previously leaked on every call (e.g. once per video frame).
    renderer = pyrender.OffscreenRenderer(viewport_width=img_width, viewport_height=img_height)
    try:
        color, depth = renderer.render(scene)
    finally:
        renderer.delete()

    return color, depth
sheap/tiny_flame.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from roma import rotvec_to_rotmat
6
+ from torch import nn
7
+
8
+
9
class TinyFlame(nn.Module):
    """Minimal FLAME head model: blendshapes + linear blend skinning (LBS).

    All parameters are registered as buffers from the checkpoint passed to
    ``__init__``; the forward pass maps FLAME shape/expression/pose/translation
    parameters to posed mesh vertices.
    """

    # Buffers populated from the checkpoint in __init__:
    v_template: torch.Tensor  # template mesh vertices
    J_regressor: torch.Tensor  # linear map from vertices to joint locations
    shapedirs: torch.Tensor  # combined shape + expression blendshape basis
    posedirs: torch.Tensor  # pose-corrective blendshape basis
    weights: torch.Tensor  # LBS blend weights (vertices x joints)
    faces: torch.Tensor  # triangle indices (not used in forward)
    kintree: torch.Tensor  # kinematic tree; row 0 holds parent joint indices

    def __init__(
        self,
        ckpt: Path | str,
        eyelids_ckpt: Path | str | None = None,
    ) -> None:
        """A tiny version of the FLAME model that is compatible with ONNX."""
        super().__init__()

        # Load the FLAME model weights
        ckpt = Path(ckpt).expanduser()
        data = torch.load(ckpt)

        # Every tensor in the checkpoint becomes a buffer (moves with .to()).
        for name, tensor in data.items():
            self.register_buffer(name, tensor)

        # Load the eyelids blendshapes if provided
        if eyelids_ckpt is not None:
            eyelids_ckpt = Path(eyelids_ckpt).expanduser()
            eyelids_data = torch.load(eyelids_ckpt)

            self.register_buffer("eyelids_dirs", eyelids_data)
        else:
            # Plain attribute (not a buffer) so forward() can test for None.
            self.eyelids_dirs = None

        # To work around the limitation of TorchDynamo, we need to convert kinematic tree to a list,
        # such that it is treated as a constant.
        self.parents = self.kintree[0].tolist()

    def forward(
        self,
        shape: torch.Tensor,
        expression: torch.Tensor,
        pose: torch.Tensor,
        translation: torch.Tensor,
        eyelids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Convert FLAME parameters to coordinates of FLAME vertices.

        Args:
            - shape (torch.Tensor): Shape parameters of the FLAME model with shape (N, 300).
            - expression (torch.Tensor): Expression parameters of the FLAME model with shape (N, 100).
            - pose (torch.Tensor): Pose parameters of the FLAME model as 3x3 matrices with shape (N, 5, 3, 3).
                It is the concatenation of torso pose (global rotation), neck pose, jaw pose,
                and left/right eye poses.
            - translation (torch.Tensor): Global translation parameters of the FLAME model with shape (N, 3).
            - eyelids (torch.Tensor): Eyelids blendshape parameters with shape (N, 2).

        Returns:
            - vertices (torch.Tensor): The vertices of the FLAME model with shape (N, V, 3).
        """
        # Some common variables
        batch_size = shape.shape[0]
        num_joints = len(self.parents)

        # Step1: compute T per equations (2)-(5) in the paper
        # Compute the shape offsets from the shape and the expression parameters
        shape_expr = torch.cat([shape, expression], -1)
        shape_expr_offsets = (self.shapedirs @ shape_expr.t()).permute(2, 0, 1)

        # Get the vertex offsets due to pose blendshapes
        # (pose[:, 0] is the global/torso rotation, which has no corrective.)
        pose_features = pose[:, 1:, :, :] - torch.eye(3, device=pose.device)
        pose_features = pose_features.view(batch_size, -1)
        pose_offsets = (self.posedirs @ pose_features.t()).permute(2, 0, 1)

        # Add offsets to the template mesh to get T
        shaped_vertices = self.v_template.expand_as(shape_expr_offsets) + shape_expr_offsets
        # Eyelid blendshapes are applied only when both the parameters and the
        # basis are available.
        if eyelids is not None and self.eyelids_dirs is not None:
            shaped_vertices = shaped_vertices + (self.eyelids_dirs @ eyelids.t()).permute(2, 0, 1)
        shaped_vertices_with_pose_correction = shaped_vertices + pose_offsets

        # Step2: compute the joint locations per equation (1) in the paper
        # Get the joint locations with the joint regressor
        joint_locations = self.J_regressor @ shaped_vertices

        # Step3: compute the final mesh vertices per equation (1) in the paper using standard LBS functions.
        # Find the transformation for: unposed FLAME -> joints' local coordinate systems -> posed FLAME
        relative_joint_locations = (
            joint_locations[:, 1:, :] - joint_locations[:, self.parents[1:], :]
        )
        relative_joint_locations = torch.cat(
            [joint_locations[:, :1, :], relative_joint_locations], dim=1
        )
        relative_joint_locations_homogeneous = F.pad(relative_joint_locations, (0, 1), value=1)

        # joint -> parent joint transformations
        # (4x4 matrices: rotation block from `pose`, translation from the
        # relative joint location.)
        joint_to_parent_transformations = torch.cat(
            [
                F.pad(pose, (0, 0, 0, 1), value=0),
                relative_joint_locations_homogeneous.unsqueeze(-1),
            ],
            dim=-1,
        )

        joint_to_posed_transformations_ = [joint_to_parent_transformations[:, 0, :, :]]

        # joint -> posed FLAME transformations
        # Walk the kinematic chain, composing each joint's transform with its
        # parent's accumulated transform.
        for i in range(1, num_joints):
            parent_joint = self.parents[i]

            current_joint_to_posed_transformation = (
                joint_to_posed_transformations_[parent_joint]
                @ joint_to_parent_transformations[:, i, :, :]
            )

            joint_to_posed_transformations_.append(current_joint_to_posed_transformation)

        joint_to_posed_transformations = torch.stack(joint_to_posed_transformations_, dim=1)

        # Unposed FLAME -> joints' local coordinate systems -> posed FLAME transformations
        unposed_to_posed_transformations = joint_to_posed_transformations - F.pad(
            joint_to_posed_transformations @ F.pad(joint_locations, (0, 1), value=0).unsqueeze(-1),
            (3, 0),
            value=0,
        )

        # Scale rotations and translations by the blend weights
        final_transformations = (self.weights @ unposed_to_posed_transformations.flatten(2)).view(
            batch_size, -1, 4, 4
        )

        # Apply the transformations to the posed vertices T
        shaped_vertices_with_pose_correction_homogeneous = F.pad(
            shaped_vertices_with_pose_correction, (0, 1), value=1
        )
        posed_vertices = (
            final_transformations @ shaped_vertices_with_pose_correction_homogeneous.unsqueeze(-1)
        )[..., :3, 0] + translation.unsqueeze(1)

        return posed_vertices
147
+
148
+
149
def pose_components_to_rotmats(predictions):
    """Stack the five predicted rotation vectors and convert them to matrices.

    Args:
        predictions: Mapping containing the keys 'torso_pose', 'neck_pose',
            'jaw_pose', 'eye_l_pose', 'eye_r_pose', each an (N, 3) tensor of
            rotation vectors.

    Returns:
        Rotation matrices of shape (N, 5, 3, 3), in the key order above.
    """
    component_order = (
        "torso_pose",
        "neck_pose",
        "jaw_pose",
        "eye_l_pose",
        "eye_r_pose",
    )
    rotvecs = torch.stack([predictions[key] for key in component_order], dim=1)
    rotmats = rotvec_to_rotmat(rotvecs.view(-1, 3))
    return rotmats.view(-1, 5, 3, 3)
teaser.jpg ADDED

Git LFS Details

  • SHA256: 3538cd3c11fd0da5f6422fdae283474a74929100424e4ad9a3b1da844edd4696
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
video_demo.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+ import threading
6
+ from pathlib import Path
7
+ from queue import Empty, Queue
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ import torchvision.transforms.functional as TF
14
+ from PIL import Image
15
+ from torch.utils.data import DataLoader, IterableDataset
16
+ from tqdm import tqdm
17
+
18
+ from demo import create_rendering_image
19
+ from sheap import load_sheap_model
20
+ from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
21
+
22
+ try:
23
+ import face_alignment
24
+ except ImportError:
25
+ raise ImportError(
26
+ "The 'face_alignment' package is required. Please install it via 'pip install face-alignment'."
27
+ )
28
+ from sheap.fa_landmark_utils import detect_face_and_crop
29
+
30
+
31
class RenderingThread(threading.Thread):
    """Background daemon thread that renders queued frames to PNG files.

    Consumes ``(frame_idx, cropped_frame, verts)`` tuples from a shared queue,
    renders each one with ``create_rendering_image``, and writes the result to
    ``temp_dir/frame_{idx:06d}.png`` (the zero-padded name matches ffmpeg's
    ``frame_%06d`` input pattern). The thread exits when it receives a ``None``
    sentinel from the queue or when ``stop()`` is called.
    """

    def __init__(
        self,
        render_queue: Queue,
        temp_dir: Path,
        faces: torch.Tensor,
        c2w: torch.Tensor,
        render_size: int,
    ):
        """
        Initialize rendering thread.

        Args:
            render_queue: Queue containing (frame_idx, cropped_frame, verts) tuples
            temp_dir: Directory to save rendered images
            faces: Face indices tensor from FLAME model
            c2w: Camera-to-world transformation matrix
            render_size: Size of each sub-image in the rendered output
        """
        super().__init__(daemon=True)
        self.render_queue = render_queue
        self.temp_dir = temp_dir
        self.faces = faces
        self.c2w = c2w
        self.render_size = render_size
        self.stop_event = threading.Event()
        # Number of frames this worker has successfully written to disk.
        self.frames_rendered = 0

    def run(self):
        """Process the rendering queue until a sentinel or stop signal arrives."""
        # Set PyOpenGL platform for this thread (headless EGL rendering).
        os.environ["PYOPENGL_PLATFORM"] = "egl"

        while not self.stop_event.is_set():
            try:
                # Poll with a timeout so stop_event is re-checked periodically.
                try:
                    item = self.render_queue.get(timeout=0.1)
                except Empty:  # Haven't finished, but nothing to render yet
                    continue
                if item is None:  # Sentinel value to stop
                    break

                # (Bug fix: this tuple was unpacked twice in a row before.)
                frame_idx, cropped_frame, verts = item

                # Render the frame next to the cropped input image.
                cropped_pil = Image.fromarray(cropped_frame)
                combined = create_rendering_image(
                    original_image=cropped_pil,
                    verts=verts,
                    faces=self.faces,
                    c2w=self.c2w,
                    output_size=self.render_size,
                )

                # Save to temp directory with zero-padded frame number
                output_path = self.temp_dir / f"frame_{frame_idx:06d}.png"
                combined.save(output_path)

                self.frames_rendered += 1
                self.render_queue.task_done()

            except Exception as e:
                # Keep the worker alive on per-frame failures; stay quiet only
                # if a shutdown is already in progress.
                if not self.stop_event.is_set():
                    print(f"Error rendering frame: {e}")
                    import traceback

                    traceback.print_exc()

    def stop(self):
        """Signal the thread to stop."""
        self.stop_event.set()
106
+
107
+
108
class VideoFrameDataset(IterableDataset):
    """Streams video frames sequentially with face detection and cropping.

    Each iteration pass reopens the video, detects the face on every frame,
    temporally smooths the crop box with an exponential moving average, and
    yields the cropped frame resized for both model input and rendering.
    """

    def __init__(
        self,
        video_path: str,
        fa_model: face_alignment.FaceAlignment,
        smoothing_alpha: float = 0.3,
    ):
        """
        Initialize video frame dataset.

        Args:
            video_path: Path to video file
            fa_model: FaceAlignment model instance for face detection
            smoothing_alpha: Smoothing factor for bounding box (0=no smoothing, 1=no change).
                Lower values = more smoothing
        """
        super().__init__()
        self.video_path = video_path
        self.fa_model = fa_model
        self.smoothing_alpha = smoothing_alpha
        self.prev_bbox: Optional[Tuple[int, int, int, int]] = None

        # Probe the file once for metadata; the capture is reopened per pass.
        probe = cv2.VideoCapture(video_path)
        if not probe.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")

        self.fps = probe.get(cv2.CAP_PROP_FPS)
        self.num_frames = int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
        self.width = int(probe.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
        probe.release()

        print(
            f"Video info: {self.num_frames} frames, {self.fps:.2f} fps, {self.width}x{self.height}"
        )

    def __iter__(self):
        """
        Iterate through video frames sequentially.

        Yields:
            Dictionary containing frame_idx, processed image, and bounding box
        """
        # Fresh smoothing state for every iteration pass.
        self.prev_bbox = None

        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video file: {self.video_path}")

        frame_idx = 0
        while True:
            ok, frame_bgr = cap.read()
            if not ok:
                break

            # OpenCV decodes BGR; convert once, then build a float CHW tensor.
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            image = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0

            # Detect the face crop, then temporally smooth it (EMA).
            raw_bbox = detect_face_and_crop(image, self.fa_model, margin=0.9, shift_up=0.5)
            bbox = self._smooth_bbox(raw_bbox)
            x0, y0, x1, y1 = bbox

            cropped = image[:, y0:y1, x0:x1]

            # 224x224 for the SHEAP model, 512x512 for rendering output.
            cropped_resized = TF.resize(cropped, [224, 224], antialias=True)
            cropped_for_render = TF.resize(cropped, [512, 512], antialias=True)

            yield {
                "frame_idx": frame_idx,
                "image": cropped_resized,
                "bbox": bbox,
                "original_frame": frame_rgb,  # Keep original for reference (as numpy array)
                "cropped_frame": cropped_for_render,  # Cropped region resized to 512x512
            }

            frame_idx += 1

        cap.release()

    def _smooth_bbox(self, bbox: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        """Apply exponential moving average smoothing to bounding box."""
        if self.prev_bbox is None:
            self.prev_bbox = bbox
            return bbox

        alpha = self.smoothing_alpha
        # new = alpha * detected + (1 - alpha) * previous, per coordinate.
        smoothed = tuple(
            int(alpha * new + (1 - alpha) * old)
            for new, old in zip(bbox, self.prev_bbox)
        )

        self.prev_bbox = smoothed
        return smoothed

    def __len__(self) -> int:
        return self.num_frames
222
+
223
+
224
def process_video(
    video_path: str,
    model_type: str = "expressive",
    batch_size: int = 8,
    num_workers: int = 0,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    output_video_path: Optional[str] = None,
    render_size: int = 512,
    num_render_workers: int = 16,
    max_queue_size: int = 128,
) -> List[Dict[str, Any]]:
    """
    Process video frames through SHEAP model and optionally render output video.

    Uses an IterableDataset for efficient sequential video processing without seeking overhead.
    Rendering is done in a background thread, and ffmpeg is used to create the final video.

    Args:
        video_path: Path to video file
        model_type: SHEAP model variant ("paper", "expressive", or "lightweight")
        batch_size: Batch size for processing
        num_workers: Number of workers (0 or 1 only). Will be clamped to max 1.
        device: Device to run model on ("cpu" or "cuda")
        output_video_path: If provided, render and save output video to this path
        render_size: Size of each sub-image in the rendered output
        num_render_workers: Number of background threads for rendering
        max_queue_size: Maximum size of the rendering queue

    Returns:
        List of dictionaries containing frame index, bounding box, and FLAME parameters
    """
    # Warn about unsupported worker counts BEFORE clamping. (Previously the
    # value was clamped first, which made this warning unreachable.)
    if num_workers > 1:
        print("Warning: num_workers > 1 not supported with IterableDataset. Using num_workers=1.")
    num_workers = min(num_workers, 1)

    # Load SHEAP model
    print(f"Loading SHEAP model (type: {model_type})...")
    sheap_model = load_sheap_model(model_type=model_type)
    sheap_model.eval()
    sheap_model = sheap_model.to(device)

    # Load face alignment model.
    # Force CPU for FA when using num_workers=1 (subprocess issues with GPU).
    fa_device = "cpu" if num_workers >= 1 else device
    print(f"Loading face alignment model on {fa_device}...")
    fa_model = face_alignment.FaceAlignment(
        face_alignment.LandmarksType.THREE_D, flip_input=False, device=fa_device
    )

    # Create dataset and dataloader
    dataset = VideoFrameDataset(video_path, fa_model)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    print(f"Processing {len(dataset)} frames from {video_path}")

    # Rendering machinery is only set up when an output video is requested.
    flame = None
    rendering_threads = []
    render_queue = None
    temp_dir = None
    c2w = None

    if output_video_path:
        print("Loading FLAME model for rendering...")
        flame_dir = Path("FLAME2020/")
        flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt")
        flame = flame.to(device)  # Move FLAME to GPU
        # Fixed camera looking at the head from z = 1.
        c2w = torch.tensor(
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32
        )

        # Create temporary directory for rendered frames
        temp_dir = Path("./temp_sheap_render/")
        temp_dir.mkdir(parents=True, exist_ok=True)
        print(f"Using temporary directory: {temp_dir}")

        # Start multiple background rendering threads sharing one bounded queue
        # (the bound applies backpressure if rendering falls behind inference).
        render_queue = Queue(maxsize=max_queue_size)
        for _ in range(num_render_workers):
            thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
            thread.start()
            rendering_threads.append(thread)
        print(f"Started {num_render_workers} background rendering threads")

    results = []
    frame_count = 0

    with torch.no_grad():
        progbar = tqdm(total=len(dataset), desc="Processing frames")
        for batch in dataloader:
            frame_indices = batch["frame_idx"]
            images = batch["image"].to(device)
            bboxes = batch["bbox"]

            # Process through SHEAP model
            flame_params_dict = sheap_model(images)

            # Generate vertices for this batch if rendering
            if output_video_path and flame is not None:
                verts = flame(
                    shape=flame_params_dict["shape_from_facenet"],
                    expression=flame_params_dict["expr"],
                    pose=pose_components_to_rotmats(flame_params_dict),
                    eyelids=flame_params_dict["eyelids"],
                    translation=flame_params_dict["cam_trans"],
                )

            # Store results and queue for rendering
            for i in range(len(frame_indices)):
                frame_idx = _extract_scalar(frame_indices[i])
                bbox = tuple(_extract_scalar(b[i]) for b in bboxes)

                result = {
                    "frame_idx": frame_idx,
                    "bbox": bbox,
                    "flame_params": {k: v[i].cpu() for k, v in flame_params_dict.items()},
                }
                results.append(result)

                # Queue frame for rendering (blocks when the queue is full)
                if output_video_path:
                    cropped_frame = _tensor_to_numpy_image(batch["cropped_frame"][i])
                    render_queue.put((frame_idx, cropped_frame, verts[i].cpu()))
                    frame_count += 1

            progbar.update(len(frame_indices))
        progbar.close()

    # Finalize rendering and create output video
    if output_video_path and render_queue is not None:
        _finalize_rendering(
            rendering_threads,
            render_queue,
            num_render_workers,
            temp_dir,
            dataset.fps,
            output_video_path,
        )

    return results
370
+
371
+
372
def _extract_scalar(value: Any) -> int:
    """Return a Python scalar from a 0-d tensor; pass any other value through."""
    if isinstance(value, torch.Tensor):
        return value.item()
    return value
375
+
376
+
377
def _tensor_to_numpy_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert a (C, H, W) float tensor in [0, 1] to a (H, W, C) uint8 array.

    Non-tensor inputs are returned unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    hwc = tensor.permute(1, 2, 0).cpu().numpy()
    return (hwc * 255).astype(np.uint8)
382
+
383
+
384
def _finalize_rendering(
    rendering_threads: List[RenderingThread],
    render_queue: Queue,
    num_render_workers: int,
    temp_dir: Path,
    fps: float,
    output_video_path: str,
) -> None:
    """Drain the rendering workers, encode the frames with ffmpeg, and clean up."""
    print("\nWaiting for rendering threads to complete...")

    # One sentinel per worker tells each thread to exit its loop.
    for _ in range(num_render_workers):
        render_queue.put(None)

    for worker in rendering_threads:
        worker.join()

    total_rendered = sum(worker.frames_rendered for worker in rendering_threads)
    print(f"Rendered {total_rendered} frames")

    print("Creating video with ffmpeg...")
    output_path = Path(output_video_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # libx264 + yuv420p for broad player compatibility; frames are picked up
    # in order via the zero-padded frame_%06d pattern.
    ffmpeg_cmd = [
        "ffmpeg",
        "-y",  # Overwrite output file if it exists
        "-framerate", str(fps),
        "-i", str(temp_dir / "frame_%06d.png"),
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        "-preset", "medium",
        "-crf", "23",
        str(output_path),
    ]

    subprocess.run(ffmpeg_cmd, check=True, capture_output=True)
    print(f"Video saved to: {output_video_path}")

    # Remove the intermediate PNG frames now that the video exists.
    if temp_dir.exists():
        print(f"Removing temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir)
    print("Cleanup complete")
437
+
438
+
439
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process and render video with SHEAP model.")
    parser.add_argument("in_path", type=str, help="Path to input video file.")
    parser.add_argument(
        "--out_path", type=str, help="Path to save rendered output video.", default=None
    )
    args = parser.parse_args()

    # Default output path: "<input stem>_rendered.mp4" next to the input file.
    if args.out_path is None:
        in_path = Path(args.in_path)
        args.out_path = str(in_path.with_name(f"{in_path.stem}_rendered.mp4"))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    results = process_video(
        video_path=args.in_path,
        model_type="expressive",
        device=device,
        output_video_path=args.out_path,
    )