File size: 6,023 Bytes
5d4d48e 6a9bf88 32d5b17 6a9bf88 1ed5e21 32d5b17 6a9bf88 854b743 6a9bf88 854b743 6a9bf88 854b743 6a9bf88 854b743 6a9bf88 854b743 6a9bf88 32d5b17 6a9bf88 854b743 6a9bf88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import os
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import Optional
import io
import numpy as np
from PIL import Image
from pathlib import Path
from measurement_processor import process_measurements
from beta_regressor import predict_betas
from smpl_generator import generate_mesh
from renderer import render_avatar
# Root directory handed to generate_mesh; overridable via the SMPL_MODEL_PATH env var.
SMPL_MODEL_PATH = os.getenv("SMPL_MODEL_PATH", "smpl")
# Candidate "models" directories, checked in order; the first existing one wins.
smpl_models_dirs = [
    Path("smpl/smpl/models"),
    Path("smpl/models"),
    Path(SMPL_MODEL_PATH) / "models",
]
# Startup diagnostic only: nothing below changes how SMPL_MODEL_PATH is used.
found_models = False
for models_dir in smpl_models_dirs:
    # is_dir() rather than exists(): a plain file named "models" is not a model directory.
    if not models_dir.is_dir():
        continue
    print(f"Found SMPL models in {models_dir}")
    # Sorted so the startup listing is stable regardless of filesystem ordering.
    for model_file in sorted(models_dir.glob("*.pkl")):
        print(f" - {model_file.name}")
    found_models = True
    break
if not found_models:
    searched = ", ".join(str(candidate) for candidate in smpl_models_dirs)
    print(f"Warning: SMPL models not found in expected locations ({searched})")
    print(f"Looking in {SMPL_MODEL_PATH}...")
# The FastAPI application; the route decorators in this module register on it.
app = FastAPI(
    description="Generate 2D avatar images from body measurements using SMPL",
    title="Avatar Generation Service",
)
class MeasurementRequest(BaseModel):
    """Request body shared by the avatar endpoints.

    Lengths are in centimetres and weight in kilograms. Every numeric field that
    is supplied must be strictly positive (``gt=0``). The five core measurements
    are required; the others may be omitted. ``gender`` defaults to ``"male"``;
    its allowed values are checked by the endpoints, not by this model.
    """

    height: float = Field(..., gt=0, description="Height in cm")
    weight: float = Field(..., gt=0, description="Weight in kg")
    chest: float = Field(..., gt=0, description="Chest measurement in cm")
    waist: float = Field(..., gt=0, description="Waist measurement in cm")
    hips: float = Field(..., gt=0, description="Hips measurement in cm")
    shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
    arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
    leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
    inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
    gender: Optional[str] = Field(
        "male", description="Gender: 'male', 'female', or 'neutral'"
    )

    # Example payload for the generated OpenAPI docs. Pydantic v2 reads model
    # configuration from ``model_config`` (ConfigDict is a TypedDict, so a plain
    # dict is equivalent); the nested ``class Config`` form is deprecated in v2.
    model_config = {
        "json_schema_extra": {
            "example": {
                "height": 178,
                "weight": 74,
                "chest": 96,
                "waist": 82,
                "hips": 94,
                "shoulder_width": 47,
                "arm_length": 60,
                "leg_length": 98,
                "inseam": 81,
                "gender": "male",
            }
        }
    }
@app.get("/")
async def root():
    """Describe the service and list its public endpoints."""
    endpoint_summary = {
        "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
        "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
        "/health": "GET - Health check",
    }
    return {"service": "Avatar Generation Service", "endpoints": endpoint_summary}
@app.get("/health")
async def health():
    """Return a constant status payload (no model or renderer checks are made)."""
    payload = {"status": "healthy"}
    return payload
# Genders the avatar endpoints accept; anything else is rejected with HTTP 400.
_VALID_GENDERS = ("male", "female", "neutral")


def _mesh_from_request(measurements: MeasurementRequest):
    """Validate gender and run the measurements -> betas -> SMPL mesh pipeline.

    Returns whatever ``generate_mesh`` returns; callers unpack it as
    ``(vertices, faces)``.

    Raises:
        ValueError: if ``gender`` is not one of ``_VALID_GENDERS``. The pipeline
            helpers may raise other exceptions as well.
    """
    measurements_dict = measurements.model_dump(exclude_none=True)
    # Gender is passed to generate_mesh on its own rather than to the measurement processor.
    gender = measurements_dict.pop("gender", "male")
    if gender not in _VALID_GENDERS:
        raise ValueError("Gender must be 'male', 'female', or 'neutral'")
    normalized = process_measurements(measurements_dict)
    betas = predict_betas(normalized)
    return generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)


def _encode_png(img_np) -> io.BytesIO:
    """Encode a rendered image array as PNG into a rewound in-memory buffer.

    Non-uint8 input whose maximum is <= 1.0 is treated as normalized and scaled
    by 255. Other non-uint8 input is taken to already be on the 0-255 scale.
    Values are clipped to [0, 255] before the cast, so out-of-range pixels
    saturate instead of wrapping around. A 4th (alpha) channel is dropped so
    the PNG is RGB.
    """
    img_np = np.asarray(img_np)
    if img_np.dtype != np.uint8:
        scale = 255.0 if img_np.max() <= 1.0 else 1.0
        img_np = np.clip(img_np * scale, 0, 255).astype(np.uint8)
    if img_np.ndim == 3 and img_np.shape[-1] == 4:
        img_np = img_np[..., :3]
    # Let Pillow infer the mode from the array shape: forcing mode='RGB' misreads
    # RGBA/grayscale buffers, and the ``mode`` argument is deprecated in Pillow.
    img = Image.fromarray(np.ascontiguousarray(img_np))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return buf


def _to_http_exception(exc: Exception, action: str) -> HTTPException:
    """Map an exception from the avatar pipeline to the HTTP error sent to the client.

    - ``ValueError`` -> 400 with the exception message.
    - ``FileNotFoundError`` -> 500 with SMPL setup instructions.
    - Anything else -> 500 with ``"Error <action>: <message>"``.

    Both 500 cases are also logged with their traceback, so the cause is not
    lost when the error is converted into an HTTPException.
    """
    if isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=str(exc))
    import logging

    logging.getLogger(__name__).error("Error %s", action, exc_info=exc)
    if isinstance(exc, FileNotFoundError):
        return HTTPException(
            status_code=500,
            detail=f"SMPL model not found: {str(exc)}. "
            f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
            "Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable.",
        )
    return HTTPException(status_code=500, detail=f"Error {action}: {str(exc)}")


@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Render the avatar for ``measurements`` and return it as a PNG image.

    Responds 400 for invalid input (any ``ValueError``). Responds 500 when SMPL
    model files are missing or anything else fails.
    """
    # NOTE(review): mesh generation and rendering are synchronous calls inside an
    # ``async`` handler, so they block the event loop for the whole request.
    # Before moving them to a worker thread, confirm that render_avatar's OSMesa
    # rendering is safe to run concurrently.
    try:
        vertices, faces = _mesh_from_request(measurements)
        buf = _encode_png(render_avatar(vertices, faces))
    except Exception as exc:
        raise _to_http_exception(exc, "generating avatar") from exc
    return StreamingResponse(buf, media_type="image/png")


@app.post("/generate-avatar-3d")
async def generate_avatar_3d(measurements: MeasurementRequest):
    """Build the avatar mesh for ``measurements`` and return it as an OBJ attachment.

    Error mapping is the same as for ``/generate-avatar``: 400 for ``ValueError``,
    500 otherwise.
    """
    try:
        vertices, faces = _mesh_from_request(measurements)
        # Imported here, so trimesh is only needed when this endpoint actually runs.
        import trimesh

        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        buf = io.BytesIO()
        mesh.export(file_obj=buf, file_type='obj')
        buf.seek(0)
    except Exception as exc:
        raise _to_http_exception(exc, "generating 3D avatar") from exc
    return StreamingResponse(
        buf,
        media_type="model/obj",
        headers={"Content-Disposition": "attachment; filename=avatar.obj"},
    )
if __name__ == "__main__":
    # Imported only when the module is run as a script, to serve the app directly.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)
|