|
|
import os

# PyOpenGL picks its platform from PYOPENGL_PLATFORM when it is first imported,
# so this must run before anything that (directly or indirectly) imports
# OpenGL -- presumably the local ``renderer`` module below.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'

import io
import threading
from pathlib import Path
from typing import Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel, Field

from measurement_processor import process_measurements
from beta_regressor import predict_betas
from smpl_generator import generate_mesh
from renderer import render_avatar
|
|
|
|
|
# Root passed to ``generate_mesh`` by the endpoints; overridable via env var.
SMPL_MODEL_PATH = os.getenv("SMPL_MODEL_PATH", "smpl")

# Candidate directories that may hold the SMPL ``.pkl`` model files, checked in
# order. The scan below only reports what it finds at startup.
smpl_models_dirs = [Path("smpl/smpl/models"), Path("smpl/models"), Path(SMPL_MODEL_PATH) / "models"]


def locate_smpl_models(candidates):
    """Return ``(directory, pkl_files)`` for the first candidate holding ``.pkl`` files.

    Directories that are missing, are not directories, or contain no ``.pkl``
    files are skipped, so an empty directory no longer ends the search early.

    Args:
        candidates: Iterable of ``Path`` objects to check, in priority order.

    Returns:
        ``(Path, list[Path])`` with the files sorted by path, or ``(None, [])``
        when no candidate qualifies.
    """
    for models_dir in candidates:
        if not models_dir.is_dir():
            continue
        pkl_files = sorted(models_dir.glob("*.pkl"))
        if pkl_files:
            return models_dir, pkl_files
    return None, []


_smpl_dir, _smpl_files = locate_smpl_models(smpl_models_dirs)
found_models = _smpl_dir is not None

if found_models:
    print(f"Found SMPL models in {_smpl_dir}")
    for _smpl_file in _smpl_files:
        print(f" - {_smpl_file.name}")
else:
    print("Warning: SMPL models not found in expected locations")
    # Report the directories that were actually searched.
    print("Looked in: " + ", ".join(str(d) for d in smpl_models_dirs))
|
|
|
|
|
# Application instance; title/description populate the generated OpenAPI docs.
# The service returns PNG images (/generate-avatar) and OBJ meshes
# (/generate-avatar-3d), so the description covers both.
app = FastAPI(
    title="Avatar Generation Service",
    description=(
        "Generate 2D avatar images (PNG) and 3D avatar meshes (OBJ) "
        "from body measurements using SMPL"
    ),
)
|
|
|
|
|
|
|
|
class MeasurementRequest(BaseModel):
    """Body measurements accepted by the avatar-generation endpoints.

    Lengths are in centimetres and weight in kilograms; every numeric field
    must be strictly positive. The optional measurements may be omitted.
    """

    height: float = Field(..., gt=0, description="Height in cm")
    weight: float = Field(..., gt=0, description="Weight in kg")
    chest: float = Field(..., gt=0, description="Chest measurement in cm")
    waist: float = Field(..., gt=0, description="Waist measurement in cm")
    hips: float = Field(..., gt=0, description="Hips measurement in cm")
    shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
    arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
    leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
    inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
    # The endpoints also accept "neutral", so the schema description lists it.
    gender: Optional[str] = Field(
        "male", description="Gender: 'male', 'female', or 'neutral'"
    )

    # Pydantic v2 configuration (a plain dict is a valid ConfigDict); replaces
    # the deprecated inner ``class Config``. The example is shown in the docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "height": 178,
                "weight": 74,
                "chest": 96,
                "waist": 82,
                "hips": 94,
                "shoulder_width": 47,
                "arm_length": 60,
                "leg_length": 98,
                "inseam": 81,
                "gender": "male",
            }
        }
    }
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and list the HTTP endpoints it exposes."""
    endpoint_docs = {
        "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
        "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
        "/health": "GET - Health check",
    }
    return {"service": "Avatar Generation Service", "endpoints": endpoint_docs}
|
|
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: unconditionally reports the service as healthy."""
    status = "healthy"
    return dict(status=status)
|
|
|
|
|
|
|
|
# Genders this service accepts; checked up front so a bad value becomes HTTP 400.
_SUPPORTED_GENDERS = ("male", "female", "neutral")

# The measurement -> mesh -> render pipeline is synchronous and CPU-heavy, so
# the endpoints run it in a worker thread instead of on the event loop (where
# it would stall every other request, including /health). This lock keeps at
# most one pipeline run in flight, preserving the one-at-a-time execution the
# previous on-loop implementation had.
# NOTE(review): relax only after confirming renderer / smpl_generator are
# thread-safe (e.g. no shared OpenGL context or cached model state).
_PIPELINE_LOCK = threading.Lock()


async def _run_pipeline(func, *args):
    """Run ``func(*args)`` in the threadpool while holding ``_PIPELINE_LOCK``."""

    def _locked_call():
        with _PIPELINE_LOCK:
            return func(*args)

    return await run_in_threadpool(_locked_call)


def _mesh_from_measurements(measurements: MeasurementRequest):
    """Validate the request's gender and run it through the SMPL pipeline.

    Returns:
        The ``(vertices, faces)`` pair produced by ``generate_mesh``.

    Raises:
        ValueError: if the gender is not one of ``_SUPPORTED_GENDERS``.
            Exceptions raised by the pipeline helpers propagate unchanged.
    """
    measurements_dict = measurements.model_dump(exclude_none=True)
    # A null gender is dropped by exclude_none, so fall back to "male".
    gender = measurements_dict.pop("gender", "male")
    if gender not in _SUPPORTED_GENDERS:
        raise ValueError("Gender must be 'male', 'female', or 'neutral'")

    normalized = process_measurements(measurements_dict)
    betas = predict_betas(normalized)
    return generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)


def _to_uint8_image(img_np) -> np.ndarray:
    """Coerce a rendered frame to uint8 pixels.

    uint8 input is returned as-is. Otherwise a frame whose maximum is <= 1.0 is
    treated as normalized [0, 1] intensities and scaled by 255; anything else is
    assumed to be on the 0-255 scale already. Values are clipped to [0, 255]
    before the cast so out-of-range pixels saturate instead of wrapping.
    """
    img = np.asarray(img_np)
    if img.dtype == np.uint8:
        return img
    img = img.astype(np.float64)
    if img.size and img.max() <= 1.0:
        img = img * 255
    return np.clip(img, 0, 255).astype(np.uint8)


def _render_png(measurements: MeasurementRequest) -> io.BytesIO:
    """Build the mesh, render it, and return the image encoded as PNG."""
    vertices, faces = _mesh_from_measurements(measurements)
    pixels = _to_uint8_image(render_avatar(vertices, faces))
    buf = io.BytesIO()
    # Let Pillow infer the mode from the array shape (HxWx3 -> RGB,
    # HxWx4 -> RGBA, HxW -> L). Forcing mode='RGB' would misread a 4-channel
    # frame, and the ``mode`` argument is deprecated in recent Pillow releases.
    Image.fromarray(pixels).save(buf, format="PNG")
    buf.seek(0)
    return buf


def _export_obj(measurements: MeasurementRequest) -> io.BytesIO:
    """Build the mesh and return it serialized as a Wavefront OBJ file."""
    import trimesh  # imported lazily: only the 3D endpoint needs it

    vertices, faces = _mesh_from_measurements(measurements)
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    buf = io.BytesIO()
    mesh.export(file_obj=buf, file_type="obj")
    buf.seek(0)
    return buf


def _pipeline_http_error(exc: Exception, what: str) -> HTTPException:
    """Map a pipeline failure onto the HTTPException returned to the client.

    ValueError -> 400 with the message; FileNotFoundError -> 500 with SMPL
    setup hints; anything else -> 500 "Error generating <what>: ...".
    """
    if isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=str(exc))
    if isinstance(exc, FileNotFoundError):
        return HTTPException(
            status_code=500,
            detail=f"SMPL model not found: {str(exc)}. "
            f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
            f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable.",
        )
    return HTTPException(status_code=500, detail=f"Error generating {what}: {str(exc)}")


@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Render the avatar described by ``measurements`` and return it as a PNG.

    Errors: 400 for any ``ValueError`` (e.g. unsupported gender), 500 with
    setup hints when an SMPL model file is missing, 500 otherwise.
    """
    try:
        png = await _run_pipeline(_render_png, measurements)
    except Exception as exc:
        # Chain explicitly so server logs show the pipeline error as the cause.
        raise _pipeline_http_error(exc, "avatar") from exc
    return StreamingResponse(png, media_type="image/png")


@app.post("/generate-avatar-3d")
async def generate_avatar_3d(measurements: MeasurementRequest):
    """Build the avatar mesh for ``measurements`` and return it as an OBJ download.

    Errors are mapped exactly as for ``/generate-avatar``.
    """
    try:
        obj = await _run_pipeline(_export_obj, measurements)
    except Exception as exc:
        raise _pipeline_http_error(exc, "3D avatar") from exc
    return StreamingResponse(
        obj,
        media_type="model/obj",
        headers={"Content-Disposition": "attachment; filename=avatar.obj"},
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Defaults match the previously hard-coded values; HOST/PORT let a
    # container or hosting platform override them without code changes.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
    )
|
|
|
|
|
|