|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
FaceLift: Single Image 3D Face Reconstruction |
|
|
Generates 3D head models from single images using multi-view diffusion and GS-LRM. |
|
|
|
|
|
Note: To enable the interactive 3D viewer, this Space needs write access to wlyu/FaceLift_demo. |
|
|
Set the HF_TOKEN environment variable in Space settings with a token that has write access. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
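# Optional: hf_transfer speeds up large checkpoint downloads; fall back to the
# default downloader if the package is not installed.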
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") == "1": |
|
|
try: |
|
|
import hf_transfer |
|
|
except ImportError: |
|
|
print("⚠️ hf_transfer not available, disabling fast download") |
|
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" |
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import random |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
import yaml |
|
|
from easydict import EasyDict as edict |
|
|
from einops import rearrange |
|
|
from PIL import Image |
|
|
from huggingface_hub import snapshot_download, HfApi |
|
|
import spaces |
|
|
|
|
|
|
|
|
import subprocess |
|
|
import sys |
|
|
|
|
|
|
|
|
OUTPUTS_DIR = Path.cwd() / "outputs" |
|
|
OUTPUTS_DIR.mkdir(exist_ok=True) |
|
|
|
|
|
def _log_viewer_file(ply_path: Path):
    """Print a concise JSON line about the viewer file so users can debug from Space logs."""
    path = Path(ply_path).absolute()
    info = {
        "ply_path": str(path),
        "exists": path.exists(),
        "size_bytes": path.stat().st_size if path.exists() else None,
    }
    print("[VIEWER-RETURN]", json.dumps(info))
|
|
|
|
|
def upload_ply_to_hf(ply_path: Path, repo_id: str = "wlyu/FaceLift_demo") -> str | None:
    """Upload a PLY file to the Hugging Face Hub and return its public URL, or None on failure."""
|
|
try: |
|
|
|
|
|
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") |
|
|
|
|
|
if not hf_token: |
|
|
print("⚠️ No HF_TOKEN found in environment, skipping upload") |
|
|
return None |
|
|
|
|
|
api = HfApi(token=hf_token) |
|
|
ply_filename = ply_path.name |
|
|
|
|
|
|
|
|
path_in_repo = f"tmp_ply/{ply_filename}" |
|
|
|
|
|
print(f"Uploading {ply_filename} to HuggingFace...") |
|
|
api.upload_file( |
|
|
path_or_fileobj=str(ply_path), |
|
|
path_in_repo=path_in_repo, |
|
|
repo_id=repo_id, |
|
|
repo_type="model", |
|
|
token=hf_token, |
|
|
) |
|
|
|
|
|
|
|
|
hf_url = f"https://huggingface.co/{repo_id}/resolve/main/{path_in_repo}" |
|
|
print(f"✓ Uploaded to: {hf_url}") |
|
|
return hf_url |
|
|
|
|
|
except Exception as e: |
|
|
print(f"⚠️ Failed to upload to HuggingFace: {e}") |
|
|
print(" Make sure the Space has write access to the repository") |
|
|
return None |
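
# Hypothetical usage (assumes HF_TOKEN grants write access to the target repo):
#   url = upload_ply_to_hf(Path("outputs/20240101_000000/gaussians_042.ply"))
#   # -> "https://huggingface.co/wlyu/FaceLift_demo/resolve/main/tmp_ply/gaussians_042.ply"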
|
|
|
|
|
|
|
|
|
|
|
|
|
|
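# diff-gaussian-rasterization ships no prebuilt wheels, so compile it on first
# launch, targeting the detected GPU architecture (with a broad fallback list).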
try: |
|
|
import diff_gaussian_rasterization |
|
|
except ImportError: |
|
|
print("Installing diff-gaussian-rasterization (compiling for detected CUDA arch)...") |
|
|
env = os.environ.copy() |
|
|
    try:
        if torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability()
            env["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}+PTX"
        else:
            env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
    except Exception:
        env["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;8.9;9.0+PTX"
|
|
|
|
|
|
|
|
env.setdefault("PYTORCH_NO_CUDA_MEMORY_CACHING", "1") |
|
|
|
|
|
subprocess.check_call( |
|
|
[sys.executable, "-m", "pip", "install", |
|
|
"git+https://github.com/graphdeco-inria/diff-gaussian-rasterization"], |
|
|
env=env, |
|
|
) |
|
|
import diff_gaussian_rasterization |
|
|
|
|
|
|
|
|
from gslrm.model.gaussians_renderer import render_turntable, imageseq2video |
|
|
from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline |
|
|
from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping |
|
|
|
|
|
|
|
|
HF_REPO_ID = "wlyu/OpenFaceLift" |
|
|
|
|
|
def download_weights_from_hf() -> Path: |
|
|
"""Download model weights from HuggingFace if not already present. |
|
|
|
|
|
Returns: |
|
|
Path to the downloaded repository |
|
|
""" |
|
|
workspace_dir = Path(__file__).parent |
|
|
|
|
|
|
|
|
mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts" |
|
|
gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt" |
|
|
|
|
|
if mvdiffusion_path.exists() and gslrm_path.exists(): |
|
|
print("Using local model weights") |
|
|
return workspace_dir |
|
|
|
|
|
print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}") |
|
|
print("This may take a few minutes on first run...") |
|
|
|
|
|
|
|
|
snapshot_download( |
|
|
repo_id=HF_REPO_ID, |
|
|
local_dir=str(workspace_dir / "checkpoints"), |
|
|
local_dir_use_symlinks=False, |
|
|
) |
|
|
|
|
|
print("Model weights downloaded successfully!") |
|
|
return workspace_dir |
|
|
|
|
|
class FaceLiftPipeline: |
|
|
"""Pipeline for FaceLift 3D head generation from single images.""" |
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
workspace_dir = download_weights_from_hf() |
|
|
|
|
|
|
|
|
self.output_dir = workspace_dir / "outputs" |
|
|
self.examples_dir = workspace_dir / "examples" |
|
|
self.output_dir.mkdir(exist_ok=True) |
|
|
|
|
|
|
|
|
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
self.image_size = 512 |
|
|
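        # Reorders the six generated views to match the camera order in
        # utils_folder/opencv_cameras.json (mapping taken as-is from this demo's config).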
self.camera_indices = [2, 1, 0, 5, 4, 3] |
|
|
|
|
|
|
|
|
print("Loading models... (gradio", getattr(gr, "__version__", "unknown"), ")") |
|
|
try: |
|
|
self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained( |
|
|
str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"), |
|
|
torch_dtype=torch.float16, |
|
|
) |
|
|
|
|
|
self._models_on_gpu = False |
|
|
|
|
|
with open(workspace_dir / "configs/gslrm.yaml", "r") as f: |
|
|
config = edict(yaml.safe_load(f)) |
|
|
|
|
|
module_name, class_name = config.model.class_name.rsplit(".", 1) |
|
|
module = __import__(module_name, fromlist=[class_name]) |
|
|
ModelClass = getattr(module, class_name) |
|
|
|
|
|
self.gs_lrm_model = ModelClass(config) |
|
|
checkpoint = torch.load( |
|
|
workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt", |
|
|
map_location="cpu" |
|
|
) |
|
|
|
|
|
state_dict = {k: v for k, v in checkpoint["model"].items() |
|
|
if not k.startswith("loss_calculator.")} |
|
|
self.gs_lrm_model.load_state_dict(state_dict) |
|
|
|
|
|
|
|
|
self.color_prompt_embedding = torch.load( |
|
|
workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt", |
|
|
map_location="cpu" |
|
|
) |
|
|
|
|
|
with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f: |
|
|
self.cameras_data = json.load(f)["frames"] |
|
|
|
|
|
print("Models loaded successfully!") |
|
|
except Exception as e: |
|
|
print(f"Error loading models: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
raise |
|
|
|
|
|
def _move_models_to_gpu(self): |
|
|
"""Move models to GPU and enable optimizations. Called within @spaces.GPU context.""" |
|
|
if not self._models_on_gpu and torch.cuda.is_available(): |
|
|
print("Moving models to GPU...") |
|
|
self.device = torch.device("cuda:0") |
|
|
self.mvdiffusion_pipeline.to(self.device) |
|
|
self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention() |
|
|
self.gs_lrm_model.to(self.device) |
|
|
self.gs_lrm_model.eval() |
|
|
self.color_prompt_embedding = self.color_prompt_embedding.to(self.device) |
|
|
self._models_on_gpu = True |
|
|
torch.cuda.empty_cache() |
|
|
print("Models on GPU, xformers enabled!") |
|
|
|
|
|
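    # On ZeroGPU Spaces, this requests a GPU slice for up to 120 s per call;
    # models are moved to the device lazily inside the handler.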
@spaces.GPU(duration=120) |
|
|
def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0, |
|
|
random_seed=4, num_steps=50): |
|
|
"""Generate 3D head from single image.""" |
|
|
try: |
|
|
|
|
|
self._move_models_to_gpu() |
|
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
output_dir = self.output_dir / timestamp |
|
|
output_dir.mkdir(exist_ok=True) |
|
|
|
|
|
|
|
|
original_img = np.array(Image.open(image_path)) |
|
|
            input_image = (preprocess_image(original_img) if auto_crop
                           else preprocess_image_without_cropping(original_img))
|
|
|
|
|
if input_image.size != (self.image_size, self.image_size): |
|
|
input_image = input_image.resize((self.image_size, self.image_size)) |
|
|
|
|
|
input_path = output_dir / "input.png" |
|
|
input_image.save(input_path) |
|
|
|
|
|
|
|
|
generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device) |
|
|
            generator.manual_seed(int(random_seed))  # gr.Number may deliver a float; torch requires an int seed
|
|
|
|
|
result = self.mvdiffusion_pipeline( |
|
|
input_image, None, |
|
|
prompt_embeds=self.color_prompt_embedding, |
|
|
height=self.image_size, |
|
|
width=self.image_size, |
|
|
guidance_scale=guidance_scale, |
|
|
num_images_per_prompt=1, |
|
|
                num_inference_steps=int(num_steps),
|
|
generator=generator, |
|
|
eta=1.0, |
|
|
) |
|
|
|
|
|
selected_views = result.images[:6] |
|
|
|
|
|
|
|
|
multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size)) |
|
|
for i, view in enumerate(selected_views): |
|
|
multiview_image.paste(view, (self.image_size * i, 0)) |
|
|
|
|
|
multiview_path = output_dir / "multiview.png" |
|
|
multiview_image.save(multiview_path) |
|
|
|
|
|
|
|
|
print("Moving diffusion model to CPU to free memory...") |
|
|
self.mvdiffusion_pipeline.to("cpu") |
|
|
|
|
|
|
|
|
del result, generator |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
|
|
|
view_arrays = [np.array(view) for view in selected_views] |
|
|
lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float() |
|
|
lrm_input = lrm_input[None].to(self.device) / 255.0 |
|
|
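            # (B, V, H, W, C) -> (B, V, C, H, W): channels-first input for GS-LRM.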
lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w") |
|
|
|
|
|
|
|
|
selected_cameras = [self.cameras_data[i] for i in self.camera_indices] |
|
|
fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras] |
|
|
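            # The JSON stores world-to-camera ("w2c") matrices; invert to get camera-to-world poses.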
c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras] |
|
|
|
|
|
fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32)) |
|
|
c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32)) |
|
|
fxfycxcy = fxfycxcy[None].to(self.device) |
|
|
c2w = c2w[None].to(self.device) |
|
|
|
|
|
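            # Build per-view (batch id, view id) indices, shape (1, V, 2), for the "index" key below.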
batch_indices = torch.stack([ |
|
|
torch.zeros(lrm_input.size(1)).long(), |
|
|
torch.arange(lrm_input.size(1)).long(), |
|
|
], dim=-1)[None].to(self.device) |
|
|
|
|
|
batch = edict({ |
|
|
"image": lrm_input, |
|
|
"c2w": c2w, |
|
|
"fxfycxcy": fxfycxcy, |
|
|
"index": batch_indices, |
|
|
}) |
|
|
|
|
|
|
|
|
if next(self.gs_lrm_model.parameters()).device.type == "cpu": |
|
|
print("Moving GS-LRM model to GPU...") |
|
|
self.gs_lrm_model.to(self.device) |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
with torch.no_grad(), torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16): |
|
|
result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True) |
|
|
|
|
|
comp_image = result.render[0].unsqueeze(0).detach() |
|
|
gaussians = result.gaussians[0] |
|
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
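            # Clean up the raw Gaussians before export: drop near-transparent and oversized
            # splats, remove floaters, and crop to a fixed box around the head.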
filtered_gaussians = gaussians.apply_all_filters( |
|
|
cam_origins=None, |
|
|
opacity_thres=0.04, |
|
|
scaling_thres=0.2, |
|
|
floater_thres=0.75, |
|
|
crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0], |
|
|
nearfar_percent=(0.0001, 1.0), |
|
|
) |
|
|
|
|
|
|
|
|
random_id = random.randint(0, 999) |
|
|
ply_filename = f"gaussians_{random_id:03d}.ply" |
|
|
ply_path = output_dir / ply_filename |
|
|
filtered_gaussians.save_ply(str(ply_path)) |
|
|
|
|
|
|
|
|
comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c") |
|
|
comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8) |
|
|
output_path = output_dir / "output.png" |
|
|
Image.fromarray(comp_image).save(output_path) |
|
|
|
|
|
|
|
|
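            # Render a 180-frame turntable video; note it uses the unfiltered Gaussians,
            # while the exported PLY above is the filtered set.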
turntable_resolution = 512 |
|
|
num_turntable_views = 180 |
|
|
turntable_frames = render_turntable(gaussians, rendering_resolution=turntable_resolution, |
|
|
num_views=num_turntable_views) |
|
|
turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=num_turntable_views) |
|
|
turntable_frames = np.ascontiguousarray(turntable_frames) |
|
|
|
|
|
turntable_path = output_dir / "turntable.mp4" |
|
|
imageseq2video(turntable_frames, str(turntable_path), fps=30) |
|
|
|
|
|
|
|
|
_log_viewer_file(ply_path) |
|
|
|
|
|
|
|
|
hf_ply_url = upload_ply_to_hf(ply_path) |
|
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
if hf_ply_url: |
|
|
|
|
|
viewer_url = f"https://www.wlyu.me/FaceLift/splat/index.html?url={hf_ply_url}" |
|
|
|
|
|
viewer_html = f""" |
|
|
<div style="width:100%; height:600px; position:relative; border-radius:8px; overflow:hidden; border:1px solid #333; background:#000;"> |
|
|
<iframe |
|
|
src="{viewer_url}" |
|
|
style="width:100%; height:100%; border:none;" |
|
|
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" |
|
|
allowfullscreen> |
|
|
</iframe> |
|
|
</div> |
|
|
<div style="text-align:center; margin-top:10px; padding:10px;"> |
|
|
<a href="{viewer_url}" |
|
|
target="_blank" |
|
|
style="display:inline-block; color:#fff; background:#4CAF50; padding:10px 20px; text-decoration:none; font-size:14px; border-radius:6px; font-weight:500;"> |
|
|
🎮 Open Interactive Viewer in New Tab |
|
|
</a> |
|
|
<p style="color:#666; font-size:12px; margin-top:8px;"> |
|
|
Drag to rotate • Scroll to zoom • Right-click to pan |
|
|
</p> |
|
|
</div> |
|
|
""" |
|
|
else: |
|
|
|
|
|
viewer_base_url = "https://www.wlyu.me/FaceLift/splat/index.html" |
|
|
|
|
|
viewer_html = f""" |
|
|
<div style="padding:40px; text-align:center; background:#f5f5f5; border-radius:8px; border:1px solid #ddd;"> |
|
|
<div style="font-size:48px; margin-bottom:20px;">🎮</div> |
|
|
<h3 style="margin:0 0 15px 0; color:#333;">Interactive 3D Viewer</h3> |
|
|
<p style="color:#666; margin-bottom:25px; line-height:1.6;"> |
|
|
Download the PLY file below, then drag and drop it into the viewer<br> |
|
|
or use the viewer with a public URL |
|
|
</p> |
|
|
<a href="{viewer_base_url}" |
|
|
target="_blank" |
|
|
style="display:inline-block; color:#fff; background:#4CAF50; padding:12px 24px; text-decoration:none; font-size:15px; border-radius:6px; font-weight:500; margin-bottom:15px;"> |
|
|
🔗 Open Interactive Viewer |
|
|
</a> |
|
|
<p style="color:#888; font-size:13px; margin-top:15px;"> |
|
|
<strong>Controls:</strong> Drag to rotate • Scroll to zoom • Right-click to pan |
|
|
</p> |
|
|
</div> |
|
|
""" |
|
|
|
|
|
return ( |
|
|
viewer_html, |
|
|
str(output_path), |
|
|
str(turntable_path), |
|
|
str(ply_path), |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
error_details = traceback.format_exc() |
|
|
print(f"Error details:\n{error_details}") |
|
|
            raise gr.Error(f"Generation failed: {e}") from e
|
|
|
|
|
def main(): |
|
|
"""Run the FaceLift application.""" |
|
|
pipeline = FaceLiftPipeline() |
|
|
|
|
|
|
|
|
examples = [] |
|
|
if pipeline.examples_dir.exists(): |
|
|
        examples = [
            [str(f), True, 3.0, 4, 50]
            for f in sorted(pipeline.examples_dir.iterdir())
            if f.suffix.lower() in {".png", ".jpg", ".jpeg"}
        ]
|
|
|
|
|
with gr.Blocks(title="FaceLift: Single Image 3D Face Reconstruction") as demo: |
|
|
|
|
|
gr.Markdown("## [ICCV 2025] FaceLift: Learning Generalizable Single Image 3D Face Reconstruction from Synthetic Heads") |
|
|
|
|
|
gr.Markdown(""" |
|
|
### 💡 Tips for Best Results |
|
|
- Works best with near-frontal portrait images. |
|
|
- The provided checkpoints were not trained with accessories (glasses, hats, etc.). Portraits containing accessories may produce suboptimal results. |
|
|
- If face detection fails, try disabling auto-cropping and manually crop to square. |
|
|
- Inference complete when the turntable video is generated, the interactive 3D gaussian might take several seconds to load. |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
in_image = gr.Image(type="filepath", label="Input Portrait Image") |
|
|
auto_crop = gr.Checkbox(value=True, label="Auto Cropping") |
|
|
guidance = gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale") |
|
|
                seed = gr.Number(value=4, precision=0, label="Random Seed")
|
|
steps = gr.Slider(10, 100, 50, step=5, label="Generation Steps") |
|
|
run_btn = gr.Button("Generate 3D Head", variant="primary") |
|
|
|
|
|
|
|
|
if examples: |
|
|
gr.Examples( |
|
|
examples=examples, |
|
|
inputs=[in_image, auto_crop, guidance, seed, steps], |
|
|
examples_per_page=10, |
|
|
) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
out_viewer = gr.HTML(label="🎮 Interactive 3D Viewer") |
|
|
out_recon = gr.Image(label="3D Reconstruction Views") |
|
|
out_video = gr.PlayableVideo(label="Turntable Animation (360° View)", height=600) |
|
|
out_ply = gr.File(label="Download 3D Model (.ply)") |
|
|
|
|
|
|
|
|
        # Thin wrapper so the click handler exposes a plain function signature.
        def _generate(image_path, auto_crop, guidance_scale, random_seed, num_steps):
            return pipeline.generate_3d_head(image_path, auto_crop, guidance_scale, random_seed, num_steps)
|
|
|
|
|
|
|
|
run_btn.click( |
|
|
            fn=_generate,
|
|
inputs=[in_image, auto_crop, guidance, seed, steps], |
|
|
outputs=[out_viewer, out_recon, out_video, out_ply], |
|
|
) |
|
|
|
|
|
demo.queue(max_size=10) |
|
|
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)  # share=True is ignored on Spaces
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|