3D / smpl_generator.py
nexusbert's picture
push natural pose
3d2acee
import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
import smplx
import os
class SMPLGenerator:
    """Load an SMPL body model through ``smplx`` and turn shape/pose parameters into meshes.

    Construction has filesystem side effects: it may create the ``model_path``
    directory, and it copies ``*.pkl`` files it finds into ``smpl/smpl/`` and
    ``smpl/smpl/models/`` (relative to the current working directory), renaming
    recognised files to ``SMPL_MALE.pkl`` / ``SMPL_FEMALE.pkl`` / ``SMPL_NEUTRAL.pkl``.
    """

    # Default pose used by generate_mesh() when no body_pose is given:
    # (start index into the 69-value body_pose vector, (x, y, z) axis-angle in degrees).
    # Slices not listed here stay zero.
    # NOTE(review): the original code named these values "shoulder", "upper arm",
    # "elbow", "palm", "hip", "knee" and "foot", but under the standard SMPL joint
    # order (body_pose joint k == SMPL joint k + 1: left_hip, right_hip, spine1,
    # left_knee, ...) slice 6:9 would be spine1 and 45:51 the shoulders. Values are
    # kept unchanged; confirm which joints were intended before retuning.
    _NATURAL_POSE_DEGREES = (
        (6, (-12.5, 0.0, 7.5)),
        (9, (-12.5, 0.0, -7.5)),
        (12, (12.5, 7.5, 0.0)),
        (15, (-12.5, 7.5, 0.0)),
        (18, (0.0, 12.5, 0.0)),
        (21, (0.0, 12.5, 0.0)),
        (24, (0.0, 0.0, 15.0)),
        (27, (0.0, 0.0, -15.0)),
        (30, (5.0, 0.0, 0.0)),
        (33, (3.0, 0.0, 0.0)),
        (39, (2.0, 0.0, 0.0)),
        (45, (3.5, 7.5, 0.0)),
        (48, (3.5, -7.5, 0.0)),
        (51, (0.0, 4.0, 0.0)),
        (54, (0.0, 4.0, 0.0)),
        (57, (0.0, 11.5, 0.0)),
        (60, (0.0, -11.5, 0.0)),
    )

    def __init__(self, model_path: str = "smpl", gender: str = "neutral", device: str = "cpu"):
        """Locate, stage and load the SMPL model.

        Args:
            model_path: Directory to search for model files. If it does not exist,
                ``smpl`` and then ``smpl/smpl`` are tried; if neither exists the
                directory is created.
            gender: Requested model gender. ``"neutral"`` is loaded as the male
                model; ``self.gender`` still records the requested value.
            device: Torch device string for the model and all generated tensors.

        Raises:
            RuntimeError: if no candidate directory / file extension combination
                can be loaded by ``smplx.create``.
        """
        self.device = torch.device(device)
        # Stores the gender the caller asked for, not necessarily the one loaded.
        self.gender = gender
        model_path_obj = self._resolve_model_dir(Path(model_path))

        # Prefer the canonical staging directory as the copy source when present.
        models_source = Path("smpl/smpl/models")
        if not models_source.exists():
            models_source = model_path_obj / "models"
        self.model_path = model_path_obj

        if gender == "neutral":
            gender = "male"
            print("Note: Neutral gender not available, using male model")

        if models_source.exists():
            self._stage_model_files(models_source)

        search_paths = self._candidate_search_paths(model_path_obj)
        self.smpl_model = self._load_smpl(search_paths, gender, self.device)

    @staticmethod
    def _resolve_model_dir(model_path_obj: Path) -> Path:
        """Return *model_path_obj* if it exists, else the first existing fallback, else create it."""
        if model_path_obj.exists():
            return model_path_obj
        for alt_path in (Path("smpl"), Path("smpl/smpl")):
            if alt_path.exists():
                print(f"Using alternative model path: {alt_path}")
                return alt_path
        model_path_obj.mkdir(parents=True, exist_ok=True)
        return model_path_obj

    @staticmethod
    def _canonical_model_name(filename: str) -> Optional[str]:
        """Map a model filename to its canonical ``SMPL_*.pkl`` name, or None if unrecognised."""
        name = filename.lower()
        # "female" must be tested before "male": "female" contains "male" as a substring.
        if "basicmodel_f" in name or "female" in name:
            return "SMPL_FEMALE.pkl"
        if "basicmodel_m" in name or "male" in name:
            return "SMPL_MALE.pkl"
        if "neutral" in name:
            return "SMPL_NEUTRAL.pkl"
        return None

    @classmethod
    def _stage_model_files(cls, models_source: Path) -> None:
        """Copy ``*.pkl`` files from *models_source* into ``smpl/smpl`` and ``smpl/smpl/models``.

        Existing targets are never overwritten.
        """
        import shutil

        model_files = list(models_source.glob("*.pkl"))
        print(f"Found {len(model_files)} model files in {models_source}: {[f.name for f in model_files]}")
        expected_smpl_dir = Path("smpl") / "smpl"
        expected_models_dir = expected_smpl_dir / "models"
        expected_models_dir.mkdir(parents=True, exist_ok=True)
        for model_file in model_files:
            target_name = cls._canonical_model_name(model_file.name)
            if target_name:
                # Recognised models go to both locations under their canonical name.
                for target in (expected_models_dir / target_name, expected_smpl_dir / target_name):
                    if not target.exists():
                        shutil.copy2(model_file, target)
                        print(f"Copied {model_file.name} -> {target}")
            else:
                target_file = expected_models_dir / model_file.name
                if not target_file.exists():
                    shutil.copy2(model_file, target_file)
                    print(f"Copied {model_file.name} to {target_file}")

    @staticmethod
    def _candidate_search_paths(model_path_obj: Path) -> list:
        """Build the de-duplicated, ordered list of directories handed to ``smplx.create``."""
        models_dir = model_path_obj / "smpl" / "models"
        if not models_dir.exists():
            models_dir = model_path_obj / "models"
        candidates = [
            str(Path(".").absolute()),
            ".",
            "smpl",
            str(model_path_obj),
        ]
        if models_dir.exists():
            parent_of_smpl = models_dir.parent.parent
            if parent_of_smpl.exists():
                candidates.append(str(parent_of_smpl))
        # dict.fromkeys keeps first occurrence order while dropping duplicates.
        return list(dict.fromkeys(candidates))

    @staticmethod
    def _load_smpl(model_paths: list, gender: str, device: torch.device):
        """Try each path with npz, pkl, then smplx's default extension; return the first model that loads.

        Raises:
            RuntimeError: if every attempt fails; chained to the last underlying error.
        """
        last_error = None
        for try_path in model_paths:
            print(f"Trying model path: {try_path}")
            for ext in ("npz", "pkl", None):
                kwargs = dict(model_path=try_path, model_type='smpl', gender=gender, batch_size=1)
                if ext is not None:
                    kwargs["ext"] = ext
                try:
                    model = smplx.create(**kwargs).to(device)
                except Exception as err:  # any loader failure means "try the next option"
                    last_error = err
                    continue
                print(f"Successfully loaded model from: {try_path}")
                return model
        error_msg = str(last_error) if last_error else "Unknown error"
        print(f"Error details: {error_msg}")
        raise RuntimeError(
            f"Failed to load SMPL model after trying paths: {model_paths}. "
            f"Error: {error_msg}. "
            f"Models should be in a 'models' subdirectory. "
            f"Expected files: basicModel_f_lbs_*.pkl (female) or basicmodel_m_lbs_*.pkl (male)"
        ) from last_error

    def _as_tensor(self, value) -> torch.Tensor:
        """Convert *value* to a float32 tensor on ``self.device``; add a batch dim to 1-D input."""
        tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
        return tensor.unsqueeze(0) if tensor.ndim == 1 else tensor

    @classmethod
    def _natural_body_pose(cls, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a ``(batch_size, 69)`` body pose with the default pose in every row."""
        pose = torch.zeros([batch_size, 69], device=device)
        for start, degrees in cls._NATURAL_POSE_DEGREES:
            pose[:, start:start + 3] = torch.tensor(np.radians(degrees), dtype=pose.dtype, device=device)
        return pose

    def generate_mesh(
        self,
        betas: np.ndarray,
        body_pose: Optional[np.ndarray] = None,
        global_orient: Optional[np.ndarray] = None,
        transl: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the SMPL model and return the mesh for the first batch element.

        Each argument may be a NumPy array, a torch tensor, or anything
        ``torch.as_tensor`` accepts. Inputs are converted to float32 on
        ``self.device``, and 1-D inputs gain a leading batch dimension.

        Args:
            betas: Shape coefficients, ``(n,)`` or ``(batch, n)``.
            body_pose: Axis-angle body pose, ``(69,)`` or ``(batch, 69)``.
                Defaults to the class's default pose in every batch row.
            global_orient: Root orientation, ``(3,)`` or ``(batch, 3)``.
                Defaults to 2 degrees on the first component for every row.
            transl: Root translation, ``(3,)`` or ``(batch, 3)``. Defaults to zero.

        Returns:
            ``(vertices, faces)``: ``vertices`` is ``output.vertices[0]`` as a
            NumPy array and ``faces`` is ``self.smpl_model.faces``.
        """
        betas = self._as_tensor(betas)
        batch_size = betas.shape[0]

        if global_orient is None:
            global_orient = torch.zeros([batch_size, 3], device=self.device)
            # Small tilt applied to every row so batched calls stay consistent.
            global_orient[:, 0] = np.radians(2)
        else:
            global_orient = self._as_tensor(global_orient)

        if body_pose is None:
            body_pose = self._natural_body_pose(batch_size, self.device)
        else:
            body_pose = self._as_tensor(body_pose)

        if transl is None:
            transl = torch.zeros([batch_size, 3], device=self.device)
        else:
            transl = self._as_tensor(transl)

        with torch.no_grad():
            output = self.smpl_model(
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl
            )
        vertices = output.vertices[0].detach().cpu().numpy()
        faces = self.smpl_model.faces
        return vertices, faces
# Most recently returned generator; still maintained for code that reads this name.
_generator_instance = None
# One generator per (model_path, gender, device), so a later call with different
# arguments does not silently receive a model built for an earlier call's arguments.
_generator_cache = {}


def get_generator(model_path: str = "smpl", gender: str = "neutral", device: str = "cpu") -> SMPLGenerator:
    """Return a shared SMPLGenerator for the given settings, creating it on first use.

    Generators are cached per ``(model_path, gender, device)``, so repeated calls
    with the same arguments reuse the already-loaded model.

    Args:
        model_path: Passed to ``SMPLGenerator``.
        gender: Passed to ``SMPLGenerator``.
        device: Passed to ``SMPLGenerator``.

    Returns:
        The cached (or newly created) generator.
    """
    global _generator_instance
    key = (model_path, gender, device)
    generator = _generator_cache.get(key)
    if generator is None:
        generator = SMPLGenerator(model_path=model_path, gender=gender, device=device)
        _generator_cache[key] = generator
    _generator_instance = generator
    return generator
def generate_mesh(
    betas: np.ndarray,
    model_path: str = "smpl",
    gender: str = "neutral",
    device: str = "cpu",
    *,
    body_pose: Optional[np.ndarray] = None,
    global_orient: Optional[np.ndarray] = None,
    transl: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generate an SMPL mesh using the generator returned by ``get_generator``.

    Args:
        betas: Shape coefficients forwarded to ``SMPLGenerator.generate_mesh``.
        model_path: Forwarded to ``get_generator``.
        gender: Forwarded to ``get_generator``.
        device: Forwarded to ``get_generator``.
        body_pose: Optional pose override (keyword-only); None uses the generator's default pose.
        global_orient: Optional root orientation override (keyword-only).
        transl: Optional root translation override (keyword-only).

    Returns:
        ``(vertices, faces)`` as returned by ``SMPLGenerator.generate_mesh``.
    """
    generator = get_generator(model_path=model_path, gender=gender, device=device)
    return generator.generate_mesh(
        betas,
        body_pose=body_pose,
        global_orient=global_orient,
        transl=transl,
    )