# 3D / app.py
# Author: nexusbert — commit 854b743 ("push")
import io
import os
import threading
from pathlib import Path
from typing import Optional

# Select the headless OSMesa backend. This must run before anything imports
# PyOpenGL (presumably pulled in via the local ``renderer`` module), so it stays
# above all third-party and local imports.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field

from measurement_processor import process_measurements
from beta_regressor import predict_betas
from smpl_generator import generate_mesh
from renderer import render_avatar
# Directory passed to ``generate_mesh`` as ``model_path``; override it with the
# SMPL_MODEL_PATH environment variable.
SMPL_MODEL_PATH = os.getenv("SMPL_MODEL_PATH", "smpl")
# Locations probed at import time, in priority order, for SMPL ``*.pkl`` files.
smpl_models_dirs = [Path("smpl/smpl/models"), Path("smpl/models"), Path(SMPL_MODEL_PATH) / "models"]


def find_smpl_models_dir(candidates):
    """Return the first directory in *candidates* that contains ``*.pkl`` files.

    A bare existence check would also accept an empty directory or a regular
    file at that path. Neither counts as a hit here. Candidates may be
    ``Path`` objects or strings.

    Returns:
        The matching directory as a ``Path``, or ``None`` if no candidate
        qualifies (including when *candidates* is empty).
    """
    for candidate in candidates:
        candidate = Path(candidate)
        if candidate.is_dir() and any(candidate.glob("*.pkl")):
            return candidate
    return None


# Startup diagnostic only: in this file the result is printed, while
# ``generate_mesh`` is still called with ``SMPL_MODEL_PATH``.
_smpl_models_dir = find_smpl_models_dir(smpl_models_dirs)
found_models = _smpl_models_dir is not None
if found_models:
    print(f"Found SMPL models in {_smpl_models_dir}")
    for model_file in sorted(_smpl_models_dir.glob("*.pkl")):
        print(f" - {model_file.name}")
else:
    print("Warning: SMPL models not found in expected locations")
    print(f"Looked in: {', '.join(str(d) for d in smpl_models_dirs)}")
# The ASGI application. Its API metadata describes both outputs the service
# offers: a 2D PNG render and a 3D OBJ mesh.
app = FastAPI(
    title="Avatar Generation Service",
    description="Generate 2D avatar images and 3D avatar meshes from body measurements using SMPL",
)
class MeasurementRequest(BaseModel):
    """Request body for the avatar endpoints: body measurements plus gender.

    ``height``, ``weight``, ``chest``, ``waist`` and ``hips`` are required.
    The remaining measurements are optional. Every measurement that is
    supplied must be strictly positive. ``gender`` is accepted as any string
    at this layer; the endpoints reject values other than 'male', 'female'
    or 'neutral'.
    """

    height: float = Field(..., gt=0, description="Height in cm")
    weight: float = Field(..., gt=0, description="Weight in kg")
    chest: float = Field(..., gt=0, description="Chest measurement in cm")
    waist: float = Field(..., gt=0, description="Waist measurement in cm")
    hips: float = Field(..., gt=0, description="Hips measurement in cm")
    shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
    arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
    leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
    inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
    gender: Optional[str] = Field("male", description="Gender: 'male' or 'female'")

    # Pydantic v2 model configuration. It replaces the nested ``class Config``,
    # which is deprecated in v2. A plain dict is valid here because
    # ``ConfigDict`` is a TypedDict.
    model_config = {
        "json_schema_extra": {
            "example": {
                "height": 178,
                "weight": 74,
                "chest": 96,
                "waist": 82,
                "hips": 94,
                "shoulder_width": 47,
                "arm_length": 60,
                "leg_length": 98,
                "inseam": 81,
                "gender": "male",
            }
        }
    }
@app.get("/")
async def root():
    """Describe the service and list its public endpoints with their HTTP methods."""
    endpoint_help = {
        "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
        "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
        "/health": "GET - Health check",
    }
    return {"service": "Avatar Generation Service", "endpoints": endpoint_help}
@app.get("/health")
async def health():
    """Health-check endpoint; it unconditionally reports a healthy status."""
    payload = {"status": "healthy"}
    return payload
# Genders accepted by both generation endpoints (forwarded to ``generate_mesh``).
_VALID_GENDERS = ("male", "female", "neutral")

# The endpoints are ``async``, but the SMPL / rendering pipeline is synchronous
# and CPU-bound. It therefore runs in FastAPI's worker threadpool, which keeps the
# event loop (and /health) responsive. Previously it ran inline on the event-loop
# thread, which implicitly serialized every request. This lock keeps that
# guarantee, because nothing here shows whether the OSMesa renderer or the mesh
# generator is thread-safe.
_PIPELINE_LOCK = threading.Lock()


def _mesh_from_request(measurements):
    """Validate a ``MeasurementRequest`` and build its SMPL mesh.

    Returns whatever ``generate_mesh`` returns; callers unpack it as
    ``(vertices, faces)``.

    Raises:
        ValueError: if the gender is not one of ``_VALID_GENDERS``. The
            downstream helpers may raise it too; that is not visible here.
    """
    measurements_dict = measurements.model_dump(exclude_none=True)
    # Gender is removed from the numeric measurements passed downstream. An
    # explicit null was already dropped by exclude_none, so fall back to "male".
    gender = measurements_dict.pop("gender", "male")
    if gender not in _VALID_GENDERS:
        raise ValueError("Gender must be 'male', 'female', or 'neutral'")
    normalized = process_measurements(measurements_dict)
    betas = predict_betas(normalized)
    return generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)


def _to_uint8_image(img_np):
    """Coerce a rendered image array to ``uint8`` for ``Image.fromarray``.

    Non-uint8 input whose maximum is <= 1.0 is treated as normalized and
    scaled by 255. All non-uint8 input is clipped to [0, 255] before the cast,
    so out-of-range values saturate instead of wrapping around. A fourth
    (alpha) channel, if present, is dropped so the PNG stays RGB.
    """
    img = np.asarray(img_np)
    if img.dtype != np.uint8:
        if img.size and img.max() <= 1.0:
            img = img * 255
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[..., :3]
    return np.ascontiguousarray(img)


def _render_png(measurements):
    """Run the full 2D pipeline and return a rewound in-memory PNG buffer."""
    vertices, faces = _mesh_from_request(measurements)
    # Let Pillow infer the mode from the array shape; the ``mode`` argument of
    # ``fromarray`` is deprecated, and (H, W, 3) uint8 input infers to "RGB".
    img = Image.fromarray(_to_uint8_image(render_avatar(vertices, faces)))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return buf


def _export_obj(measurements):
    """Run the mesh pipeline and return a rewound in-memory Wavefront OBJ buffer."""
    # Kept as a lazy import, as in the original endpoint.
    import trimesh

    vertices, faces = _mesh_from_request(measurements)
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    buf = io.BytesIO()
    mesh.export(file_obj=buf, file_type='obj')
    buf.seek(0)
    return buf


def _run_serialized(fn, measurements):
    """Call ``fn(measurements)`` while holding the pipeline lock."""
    with _PIPELINE_LOCK:
        return fn(measurements)


def _smpl_not_found(exc):
    """Build the HTTP 500 error used when SMPL model files are missing."""
    return HTTPException(
        status_code=500,
        detail=f"SMPL model not found: {exc}. "
        f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
        f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable."
    )


@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Generate a 2D avatar image (PNG) from body measurements.

    Error responses:
        400 for any ``ValueError`` (e.g. an unsupported gender).
        500 when SMPL model files are missing (``FileNotFoundError``).
        500 for any other failure.
    """
    try:
        buf = await run_in_threadpool(_run_serialized, _render_png, measurements)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise _smpl_not_found(e) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating avatar: {e}") from e
    return StreamingResponse(buf, media_type="image/png")


@app.post("/generate-avatar-3d")
async def generate_avatar_3d(measurements: MeasurementRequest):
    """Generate a 3D avatar mesh, returned as a downloadable OBJ file.

    Error responses follow the same rules as ``generate_avatar``: 400 for
    ``ValueError``, 500 for missing SMPL models, and 500 otherwise.
    """
    try:
        buf = await run_in_threadpool(_run_serialized, _export_obj, measurements)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise _smpl_not_found(e) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating 3D avatar: {e}") from e
    return StreamingResponse(
        buf,
        media_type="model/obj",
        headers={"Content-Disposition": "attachment; filename=avatar.obj"}
    )
if __name__ == "__main__":
    # Direct-run entry point. The port defaults to 7860 (the previously
    # hard-coded value) and can be overridden with the PORT environment variable.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))