nexusbert committed on
Commit
6a9bf88
·
1 Parent(s): ce3ed29
Files changed (8) hide show
  1. .gitignore +63 -0
  2. Dockerfile +50 -0
  3. app.py +93 -0
  4. beta_regressor.py +80 -0
  5. measurement_processor.py +88 -0
  6. renderer.py +89 -0
  7. requirements.txt +11 -0
  8. smpl_generator.py +105 -0
.gitignore ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .venv
28
+
29
+ # IDE
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+ # SMPL models (large files, download separately)
37
+ models/smpl/
38
+ *.pkl
39
+ *.npz
40
+
41
+ # Trained models
42
+ *.pth
43
+ *.pt
44
+ *.ckpt
45
+
46
+ # Generated avatars
47
+ *.png
48
+ *.jpg
49
+ *.jpeg
50
+ avatars/
51
+
52
+ # Logs
53
+ *.log
54
+ logs/
55
+
56
+ # Environment variables
57
+ .env
58
+ .env.local
59
+
60
+ # OS
61
+ .DS_Store
62
+ Thumbs.db
63
+
Dockerfile ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base Image
2
+ FROM python:3.10-slim
3
+
4
+ # Build argument for Hugging Face token
5
+ ARG HF_TOKEN
6
+
7
+ ENV DEBIAN_FRONTEND=noninteractive \
8
+ PYTHONUNBUFFERED=1 \
9
+ PYTHONDONTWRITEBYTECODE=1 \
10
+ HF_TOKEN=${HF_TOKEN}
11
+
12
+ WORKDIR /code
13
+
14
+ # System Dependencies
15
+ RUN apt-get update && apt-get install -y --no-install-recommends \
16
+ build-essential \
17
+ git \
18
+ curl \
19
+ libopenblas-dev \
20
+ libomp-dev \
21
+ libosmesa6-dev \
22
+ libgl1-mesa-glx \
23
+ libglib2.0-0 \
24
+ && rm -rf /var/lib/apt/lists/*
25
+
26
+ # Copy requirements and install Python dependencies
27
+ COPY requirements.txt .
28
+ RUN pip install --no-cache-dir -r requirements.txt
29
+
30
+ # Hugging Face + model tools
31
+ RUN pip install --no-cache-dir huggingface-hub sentencepiece accelerate
32
+
33
+ # Hugging Face cache environment
34
+ ENV HF_HOME=/models/huggingface \
35
+ HUGGINGFACE_HUB_CACHE=/models/huggingface \
36
+ HF_HUB_CACHE=/models/huggingface
37
+
38
+ # Created cache dir and set permissions
39
+ RUN mkdir -p /models/huggingface && chmod -R 777 /models/huggingface
40
+
41
+ # Create SMPL models directory
42
+ RUN mkdir -p /code/models/smpl && chmod -R 777 /code/models
43
+
44
+ # Copy project files
45
+ COPY . .
46
+
47
+ EXPOSE 7860
48
+
49
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
50
+
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ from pydantic import BaseModel, Field
4
+ from typing import Optional
5
+ import io
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from measurement_processor import process_measurements
10
+ from beta_regressor import predict_betas
11
+ from smpl_generator import generate_mesh
12
+ from renderer import render_avatar
13
+
14
+ app = FastAPI(
15
+ title="Avatar Generation Service",
16
+ description="Generate 2D avatar images from body measurements using SMPL"
17
+ )
18
+
19
+
20
+ class MeasurementRequest(BaseModel):
21
+ height: float = Field(..., gt=0, description="Height in cm")
22
+ weight: float = Field(..., gt=0, description="Weight in kg")
23
+ chest: float = Field(..., gt=0, description="Chest measurement in cm")
24
+ waist: float = Field(..., gt=0, description="Waist measurement in cm")
25
+ hips: float = Field(..., gt=0, description="Hips measurement in cm")
26
+ shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
27
+ arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
28
+ leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
29
+ inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
30
+
31
+ class Config:
32
+ json_schema_extra = {
33
+ "example": {
34
+ "height": 178,
35
+ "weight": 74,
36
+ "chest": 96,
37
+ "waist": 82,
38
+ "hips": 94,
39
+ "shoulder_width": 47,
40
+ "arm_length": 60,
41
+ "leg_length": 98,
42
+ "inseam": 81
43
+ }
44
+ }
45
+
46
+
47
+ @app.get("/")
48
+ async def root():
49
+ return {
50
+ "service": "Avatar Generation Service",
51
+ "endpoints": {
52
+ "/generate-avatar": "POST - Generate avatar from measurements",
53
+ "/health": "GET - Health check"
54
+ }
55
+ }
56
+
57
+
58
@app.get("/health")
async def health():
    """Liveness probe: always reports healthy while the process is up."""
    return dict(status="healthy")
61
+
62
+
63
@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Generate a rendered avatar PNG from body measurements.

    Pipeline: validate/normalize measurements -> regress SMPL shape
    coefficients (betas) -> build the SMPL mesh -> render it to an RGB
    image, which is streamed back as image/png.

    Raises:
        HTTPException 400: measurement validation failed (ValueError).
        HTTPException 500: SMPL model files missing, or any other failure.
    """
    try:
        measurements_dict = measurements.model_dump(exclude_none=True)
        normalized = process_measurements(measurements_dict)
        betas = predict_betas(normalized)
        vertices, faces = generate_mesh(betas)
        img_np = render_avatar(vertices, faces)

        # The renderer may hand back a float image scaled to [0, 1] or
        # [0, 255]; clip before casting so out-of-range values cannot
        # wrap around during the uint8 conversion.
        if img_np.dtype != np.uint8:
            if img_np.max() <= 1.0:
                img_np = img_np * 255.0
            img_np = np.clip(img_np, 0, 255).astype(np.uint8)

        img = Image.fromarray(img_np, mode='RGB')
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)

        return StreamingResponse(buf, media_type="image/png")

    except ValueError as e:
        # Invalid measurements are a client error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(status_code=500, detail=f"SMPL model not found: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}") from e
88
+
89
+
90
+ if __name__ == "__main__":
91
+ import uvicorn
92
+ uvicorn.run(app, host="0.0.0.0", port=7860)
93
+
beta_regressor.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+
8
class BetaRegressor(nn.Module):
    """MLP mapping normalized body measurements to SMPL shape coefficients.

    Input is a vector of 9 normalized measurements; output is 10 betas
    squashed to (-1, 1) by a final Tanh (callers rescale as needed).
    """

    def __init__(self, input_dim: int = 9, output_dim: int = 10, hidden_dims: tuple = (64, 32)):
        """Build the regressor.

        Args:
            input_dim: number of measurement features.
            output_dim: number of SMPL betas to predict.
            hidden_dims: widths of the hidden layers.  The default is an
                immutable tuple (was a mutable list — a shared mutable
                default argument) so it cannot be mutated across calls.
        """
        super().__init__()

        layers = []
        prev_dim = input_dim

        # Hidden stack: Linear -> ReLU -> Dropout per layer.
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            prev_dim = hidden_dim

        # Final projection; Tanh bounds the outputs to (-1, 1).
        layers.append(nn.Linear(prev_dim, output_dim))
        layers.append(nn.Tanh())

        self.network = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialize linear weights with a small gain; zero the biases.

        The small gain keeps an untrained network's outputs close to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, input_dim) tensor to (batch, output_dim) betas."""
        return self.network(x)
35
+
36
+
37
class MeasurementToBetaPredictor:
    """Thin inference wrapper around BetaRegressor.

    Optionally loads trained weights and exposes a numpy-in / numpy-out
    predict() for a single measurement vector.
    """

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        self.device = torch.device(device)
        self.model = BetaRegressor().to(self.device)
        # eval() disables dropout; this class never trains the model.
        self.model.eval()

        if model_path and Path(model_path).exists():
            self.load_model(model_path)
        else:
            # Fall back to random (Xavier-initialized) weights.
            print("Warning: Using untrained model. Results may not be optimal.")
            print("Consider training the model or loading pretrained weights.")

    def load_model(self, model_path: str):
        """Load weights from a checkpoint file.

        Accepts either a raw state_dict or a training checkpoint dict
        that carries a 'model_state_dict' entry.

        NOTE(review): torch.load unpickles arbitrary objects -- only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(model_path, map_location=self.device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        print(f"Loaded model from {model_path}")

    def predict(self, normalized_measurements: np.ndarray) -> np.ndarray:
        """Predict SMPL betas for one normalized measurement vector.

        Args:
            normalized_measurements: 1-D array whose length must match the
                regressor's input_dim (9 by default).

        Returns:
            1-D array of betas, rescaled from the Tanh range (-1, 1)
            to (-2, 2).
        """
        with torch.no_grad():
            measurements_tensor = torch.FloatTensor(normalized_measurements).unsqueeze(0).to(self.device)
            betas_tensor = self.model(measurements_tensor)
            betas = betas_tensor.squeeze(0).cpu().numpy()
            # Widen the Tanh output range from (-1, 1) to (-2, 2).
            betas = betas * 2.0

        return betas
65
+
66
+
67
# Process-wide singleton so model construction / weight loading happens once.
_predictor_instance = None


def get_predictor(model_path: Optional[str] = None, device: str = "cpu") -> MeasurementToBetaPredictor:
    """Return the shared predictor, creating it on first use.

    NOTE(review): model_path and device are honoured only on the first
    call; later calls return the cached instance regardless of arguments.
    """
    global _predictor_instance
    if _predictor_instance is None:
        _predictor_instance = MeasurementToBetaPredictor(model_path=model_path, device=device)
    return _predictor_instance
75
+
76
+
77
def predict_betas(normalized_measurements: np.ndarray, model_path: Optional[str] = None) -> np.ndarray:
    """Module-level convenience wrapper: predict betas via the singleton predictor."""
    predictor = get_predictor(model_path=model_path)
    return predictor.predict(normalized_measurements)
80
+
measurement_processor.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+ import numpy as np
3
+
4
+
5
class MeasurementProcessor:
    """Validates body measurements and normalizes them for the beta regressor.

    Measurements are metric (cm / kg).  Normalization divides each value by a
    fixed per-field factor so typical values land roughly in [0, 1].
    """

    # Per-field divisors used for normalization.
    NORMALIZATION_FACTORS = {
        "height": 200.0,
        "weight": 100.0,
        "chest": 150.0,
        "waist": 150.0,
        "hips": 150.0,
        "shoulder_width": 60.0,
        "arm_length": 80.0,
        "leg_length": 120.0,
        "inseam": 100.0,
    }

    # These five must be supplied by the caller.
    REQUIRED_MEASUREMENTS = [
        "height",
        "weight",
        "chest",
        "waist",
        "hips",
    ]

    # Optional fields with fallback values used when omitted.
    OPTIONAL_MEASUREMENTS = {
        "shoulder_width": 40.0,
        "arm_length": 60.0,
        "leg_length": 90.0,
        "inseam": 75.0,
    }

    @classmethod
    def _check_value(cls, field, value, errors: list):
        """Append an error message to *errors* if *value* is not a positive number."""
        # bool is a subclass of int, but True/False are never valid measurements.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            errors.append(f"Invalid type for {field}: must be number")
        elif value <= 0:
            errors.append(f"Invalid value for {field}: must be positive")

    @classmethod
    def validate_measurements(cls, measurements: Dict) -> Dict:
        """Validate a measurement dict.

        Returns:
            {"valid": bool, "errors": list[str]} — heterogeneous values,
            hence the loose Dict annotation (the previous `Dict[str, str]`
            annotation was incorrect).
        """
        errors = []

        for field in cls.REQUIRED_MEASUREMENTS:
            if field not in measurements:
                errors.append(f"Missing required measurement: {field}")
            else:
                cls._check_value(field, measurements[field], errors)

        # Optional fields are only checked when present.
        for field in cls.OPTIONAL_MEASUREMENTS:
            if field in measurements:
                cls._check_value(field, measurements[field], errors)

        return {
            "valid": len(errors) == 0,
            "errors": errors
        }

    @classmethod
    def normalize_measurements(cls, measurements: Dict) -> np.ndarray:
        """Return a float32 vector of the 9 normalized measurements.

        Missing optional fields are filled with their defaults.  The output
        order is fixed and must match the order the regressor was built for.
        """
        processed = measurements.copy()
        for field, default in cls.OPTIONAL_MEASUREMENTS.items():
            if field not in processed:
                processed[field] = default

        measurement_order = [
            "height", "weight", "chest", "waist", "hips",
            "shoulder_width", "arm_length", "leg_length", "inseam"
        ]

        normalized = [
            processed[field] / cls.NORMALIZATION_FACTORS[field]
            for field in measurement_order
        ]

        return np.array(normalized, dtype=np.float32)

    @classmethod
    def process(cls, measurements: Dict) -> np.ndarray:
        """Validate then normalize.

        Raises:
            ValueError: listing every validation error, if any.
        """
        validation = cls.validate_measurements(measurements)
        if not validation["valid"]:
            raise ValueError(f"Invalid measurements: {', '.join(validation['errors'])}")

        return cls.normalize_measurements(measurements)
84
+
85
+
86
def process_measurements(measurements: Dict) -> np.ndarray:
    """Module-level convenience wrapper around MeasurementProcessor.process."""
    return MeasurementProcessor.process(measurements)
88
+
renderer.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pyrender
3
+ import trimesh
4
+ from typing import Tuple, Optional
5
+
6
+
7
class AvatarRenderer:
    """Offscreen pyrender-based renderer for SMPL meshes.

    Requires a working headless GL backend (e.g. OSMesa/EGL -- the
    Dockerfile installs libosmesa6-dev for this).
    """

    def __init__(
        self,
        image_size: int = 512,
        camera_type: str = "orthographic",
        light_intensity: float = 2.0
    ):
        # image_size is used for both width and height (square output).
        self.image_size = image_size
        self.camera_type = camera_type
        self.light_intensity = light_intensity
        # The offscreen renderer holds a GL context; released in __del__.
        self.renderer = pyrender.OffscreenRenderer(image_size, image_size)

    def render(
        self,
        vertices: np.ndarray,
        faces: np.ndarray,
        camera_pose: Optional[np.ndarray] = None,
        light_pose: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Render a mesh to an RGB array of shape (image_size, image_size, 3).

        The mesh is centered at the origin and uniformly scaled so its
        largest extent is 0.8, which fits the default orthographic view.

        NOTE(review): the scale computation assumes a non-degenerate mesh
        (max extent > 0); a flat or empty mesh would divide by zero.
        """
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        mesh.vertices -= mesh.vertices.mean(axis=0)
        scale = 1.0 / (mesh.vertices.max(axis=0) - mesh.vertices.min(axis=0)).max()
        mesh.vertices *= scale * 0.8

        pyrender_mesh = pyrender.Mesh.from_trimesh(mesh)
        # White background.
        scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0])
        scene.add(pyrender_mesh)

        if self.camera_type == "orthographic":
            camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0)
        else:
            camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)

        # Default camera: on the +Z axis, 2.5 units from the mesh center.
        if camera_pose is None:
            camera_pose = np.eye(4)
            camera_pose[:3, 3] = [0, 0, 2.5]
        scene.add(camera, pose=camera_pose)

        # Default light shares the camera's position.
        if light_pose is None:
            light_pose = np.eye(4)
            light_pose[:3, 3] = [0, 0, 2.5]

        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=self.light_intensity)
        scene.add(light, pose=light_pose)

        # Secondary fill light (a spot light, despite the variable name).
        ambient_light = pyrender.SpotLight(
            color=[1.0, 1.0, 1.0],
            intensity=0.5,
            innerConeAngle=np.pi / 16.0,
            outerConeAngle=np.pi / 6.0
        )
        scene.add(ambient_light, pose=light_pose)

        color, depth = self.renderer.render(scene)

        # Drop the alpha channel if the backend returned RGBA.
        if color.shape[2] == 4:
            color = color[:, :, :3]

        return color

    def __del__(self):
        # Free the GL context; guard against partially-constructed instances.
        if hasattr(self, 'renderer'):
            self.renderer.delete()
71
+
72
# Cached renderer: an OffscreenRenderer allocates a GL context, so reuse
# one instance per configuration instead of creating one per request.
_renderer_instance = None


def get_renderer(image_size: int = 512, camera_type: str = "orthographic") -> AvatarRenderer:
    """Return a shared AvatarRenderer for the requested configuration.

    The renderer is rebuilt when called with a different image_size or
    camera_type; previously the first instance was cached forever, so a
    later call with a different image_size silently returned images of
    the original size.
    """
    global _renderer_instance
    if (
        _renderer_instance is None
        or _renderer_instance.image_size != image_size
        or _renderer_instance.camera_type != camera_type
    ):
        _renderer_instance = AvatarRenderer(image_size=image_size, camera_type=camera_type)
    return _renderer_instance
80
+
81
+
82
def render_avatar(
    vertices: np.ndarray,
    faces: np.ndarray,
    image_size: int = 512
) -> np.ndarray:
    """Module-level convenience wrapper: render a mesh with the shared renderer."""
    renderer = get_renderer(image_size=image_size)
    return renderer.render(vertices, faces)
89
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pydantic
4
+ numpy
5
+ torch
6
+ smplx
7
+ pyrender
8
+ trimesh
9
+ Pillow
10
+ scikit-learn
11
+
smpl_generator.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from typing import Tuple, Optional
5
+ import smplx
6
+
7
+
8
class SMPLGenerator:
    """Wraps an smplx SMPL body model to turn shape betas into meshes."""

    def __init__(self, model_path: str = "models/smpl", gender: str = "neutral", device: str = "cpu"):
        """Load the SMPL model from disk.

        Args:
            model_path: directory containing the SMPL model files.
            gender: model variant passed straight to smplx.create.
            device: torch device string used for model evaluation.

        Raises:
            FileNotFoundError: the model directory does not exist.
            RuntimeError: model files exist but could not be loaded.
        """
        self.device = torch.device(device)
        self.gender = gender
        self.model_path = Path(model_path)

        if not self.model_path.exists():
            raise FileNotFoundError(
                f"SMPL model not found at {model_path}. "
                f"Please download SMPL models from https://smpl.is.tue.mpg.de/"
            )

        # SMPL model files come in .npz and .pkl formats; try .npz first
        # and fall back to .pkl.
        try:
            self.smpl_model = smplx.create(
                str(self.model_path),
                model_type='smpl',
                gender=gender,
                batch_size=1,
                ext='npz'
            ).to(self.device)
        except Exception as e:
            try:
                self.smpl_model = smplx.create(
                    str(self.model_path),
                    model_type='smpl',
                    gender=gender,
                    batch_size=1,
                    ext='pkl'
                ).to(self.device)
            except Exception as e2:
                # NOTE(review): e2 (the .pkl failure) is neither reported
                # nor chained -- only the .npz error message survives.
                raise RuntimeError(
                    f"Failed to load SMPL model: {e}. "
                    f"Please ensure SMPL model files are in {model_path}"
                )

    def generate_mesh(
        self,
        betas: np.ndarray,
        body_pose: Optional[np.ndarray] = None,
        global_orient: Optional[np.ndarray] = None,
        transl: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the SMPL forward pass to produce a mesh for the given betas.

        Args:
            betas: shape coefficients, 1-D or (batch, n); numpy arrays and
                torch tensors are both accepted.
            body_pose: optional (batch, 69) pose; defaults to all zeros.
            global_orient: optional (batch, 3) root orientation; zeros default.
            transl: optional (batch, 3) translation; zeros default.

        Returns:
            (vertices, faces): vertices as a numpy array for the FIRST
            batch element only; faces taken from the SMPL model.
        """
        # Promote a single beta vector to a batch of one.
        if betas.ndim == 1:
            betas = betas.unsqueeze(0) if isinstance(betas, torch.Tensor) else betas[np.newaxis, :]

        if isinstance(betas, np.ndarray):
            betas = torch.FloatTensor(betas).to(self.device)

        batch_size = betas.shape[0]

        # Fill each missing pose/translation input with zeros, and convert
        # any numpy input to a tensor on the target device.
        if body_pose is None:
            body_pose = torch.zeros([batch_size, 69], device=self.device)
        elif isinstance(body_pose, np.ndarray):
            body_pose = torch.FloatTensor(body_pose).to(self.device)

        if global_orient is None:
            global_orient = torch.zeros([batch_size, 3], device=self.device)
        elif isinstance(global_orient, np.ndarray):
            global_orient = torch.FloatTensor(global_orient).to(self.device)

        if transl is None:
            transl = torch.zeros([batch_size, 3], device=self.device)
        elif isinstance(transl, np.ndarray):
            transl = torch.FloatTensor(transl).to(self.device)

        # Inference only; no gradients needed.
        with torch.no_grad():
            output = self.smpl_model(
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl
            )

        # Only the first batch element's vertices are returned.
        vertices = output.vertices[0].detach().cpu().numpy()
        faces = self.smpl_model.faces

        return vertices, faces
85
+
86
+
87
# Process-wide singleton so the SMPL model is loaded from disk only once.
_generator_instance = None


def get_generator(model_path: str = "models/smpl", gender: str = "neutral", device: str = "cpu") -> SMPLGenerator:
    """Return the shared SMPLGenerator, creating it on first use.

    NOTE(review): model_path/gender/device are honoured only on the first
    call; later calls return the cached instance regardless of arguments.
    """
    global _generator_instance
    if _generator_instance is None:
        _generator_instance = SMPLGenerator(model_path=model_path, gender=gender, device=device)
    return _generator_instance
95
+
96
+
97
def generate_mesh(
    betas: np.ndarray,
    model_path: str = "models/smpl",
    gender: str = "neutral",
    device: str = "cpu"
) -> Tuple[np.ndarray, np.ndarray]:
    """Module-level convenience wrapper: build a mesh via the singleton generator."""
    generator = get_generator(model_path=model_path, gender=gender, device=device)
    return generator.generate_mesh(betas)
105
+