"""Avatar Generation Service.

FastAPI app that turns body measurements into an SMPL body mesh and returns it
either as a rendered PNG (/generate-avatar) or as an OBJ mesh
(/generate-avatar-3d).
"""
import os

# Must be set before anything imports PyOpenGL (the `renderer` module below is
# the renderer for this service), so headless rendering uses OSMesa.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'

import asyncio
import io
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable, Optional, Tuple

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response, StreamingResponse
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field

from measurement_processor import process_measurements
from beta_regressor import predict_betas
from smpl_generator import generate_mesh
from renderer import render_avatar

logger = logging.getLogger(__name__)

SMPL_MODEL_PATH = os.getenv("SMPL_MODEL_PATH", "smpl")

# Candidate directories that may hold the SMPL .pkl model files, checked in
# order at startup purely for diagnostics; generation itself is always given
# SMPL_MODEL_PATH.
smpl_models_dirs = [
    Path("smpl/smpl/models"),
    Path("smpl/models"),
    Path(SMPL_MODEL_PATH) / "models",
]

found_models = False
for models_dir in smpl_models_dirs:
    if models_dir.is_dir():
        print(f"Found SMPL models in {models_dir}")
        for f in sorted(models_dir.glob("*.pkl")):
            print(f" - {f.name}")
        found_models = True
        break

if not found_models:
    print("Warning: SMPL models not found in expected locations")
    print(f"Looking in {SMPL_MODEL_PATH}...")

app = FastAPI(
    title="Avatar Generation Service",
    description="Generate 2D avatar images from body measurements using SMPL",
)

# Gender values this service accepts; anything else is rejected with HTTP 400.
_VALID_GENDERS = ("male", "female", "neutral")

# One dedicated worker thread for the CPU-heavy measurement -> mesh -> render
# pipeline. Running it off the event loop keeps the server responsive (e.g.
# /health) while a request is being processed. A single worker keeps requests
# serialized, as they were when the work ran inline in the async handlers, and
# keeps every OpenGL call on one consistent thread.
# NOTE(review): assumes `render_avatar` does not reuse a GL context created on
# the main thread at import time -- confirm against renderer.py.
_pipeline_executor = ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="avatar-pipeline"
)


class MeasurementRequest(BaseModel):
    """Body measurements used to fit an SMPL body shape."""

    height: float = Field(..., gt=0, description="Height in cm")
    weight: float = Field(..., gt=0, description="Weight in kg")
    chest: float = Field(..., gt=0, description="Chest measurement in cm")
    waist: float = Field(..., gt=0, description="Waist measurement in cm")
    hips: float = Field(..., gt=0, description="Hips measurement in cm")
    shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
    arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
    leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
    inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
    gender: Optional[str] = Field(
        "male", description="Gender: 'male', 'female', or 'neutral'"
    )

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`).
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "height": 178,
                "weight": 74,
                "chest": 96,
                "waist": 82,
                "hips": 94,
                "shoulder_width": 47,
                "arm_length": 60,
                "leg_length": 98,
                "inseam": 81,
                "gender": "male",
            }
        }
    )


@app.get("/")
async def root():
    """Describe the service and its endpoints."""
    return {
        "service": "Avatar Generation Service",
        "endpoints": {
            "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
            "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
            "/health": "GET - Health check",
        },
    }


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy"}


def _prepare_inputs(measurements: MeasurementRequest) -> Tuple[dict, str]:
    """Split a request into (measurement dict without gender, validated gender).

    Raises:
        ValueError: if the gender is not one of ``_VALID_GENDERS``.
    """
    # Unset optional measurements are dropped so downstream code only sees
    # values the client actually supplied.
    measurements_dict = measurements.model_dump(exclude_none=True)
    gender = measurements_dict.pop("gender", "male")
    if gender not in _VALID_GENDERS:
        raise ValueError("Gender must be 'male', 'female', or 'neutral'")
    return measurements_dict, gender


def _build_mesh(measurements: MeasurementRequest):
    """Run measurements -> SMPL betas -> mesh; return ``(vertices, faces)``."""
    measurements_dict, gender = _prepare_inputs(measurements)
    normalized = process_measurements(measurements_dict)
    betas = predict_betas(normalized)
    return generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)


def _to_uint8(img_np) -> np.ndarray:
    """Convert a rendered image array to uint8 for PNG encoding.

    Non-uint8 arrays whose maximum is <= 1.0 are treated as normalized and
    scaled by 255. Values are then clipped to [0, 255] so out-of-range pixels
    saturate instead of wrapping around modulo 256.
    """
    img_np = np.asarray(img_np)
    if img_np.dtype == np.uint8:
        return img_np
    if img_np.max() <= 1.0:
        img_np = img_np * 255
    return np.clip(img_np, 0, 255).astype(np.uint8)


def _render_png(measurements: MeasurementRequest) -> bytes:
    """Build the avatar mesh, render it, and return the PNG-encoded bytes."""
    vertices, faces = _build_mesh(measurements)
    # Let Pillow infer the mode from the array shape instead of forcing RGB,
    # so an RGBA or grayscale render is not misinterpreted.
    img = Image.fromarray(_to_uint8(render_avatar(vertices, faces)))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _export_obj(measurements: MeasurementRequest) -> bytes:
    """Build the avatar mesh and return it serialized as Wavefront OBJ bytes."""
    import trimesh  # imported lazily: only the 3D endpoint needs it

    vertices, faces = _build_mesh(measurements)
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    buf = io.BytesIO()
    mesh.export(file_obj=buf, file_type='obj')
    return buf.getvalue()


async def _run_pipeline(
    func: Callable[[MeasurementRequest], bytes], measurements: MeasurementRequest
) -> bytes:
    """Run a blocking pipeline function on the dedicated worker thread."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_pipeline_executor, func, measurements)


def _smpl_model_missing_error(e: FileNotFoundError) -> HTTPException:
    """Build the HTTP 500 error returned when SMPL model files are missing."""
    return HTTPException(
        status_code=500,
        detail=f"SMPL model not found: {str(e)}. "
        f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
        f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable.",
    )


@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Generate a 2D avatar image (PNG) from body measurements.

    Errors: ValueError -> 400; missing SMPL model files -> 500 with setup
    hints; any other failure -> 500 (logged with traceback).
    """
    try:
        png_bytes = await _run_pipeline(_render_png, measurements)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise _smpl_model_missing_error(e) from e
    except Exception as e:
        logger.exception("Error generating avatar")
        raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}") from e
    return Response(content=png_bytes, media_type="image/png")


@app.post("/generate-avatar-3d")
async def generate_avatar_3d(measurements: MeasurementRequest):
    """Generate a 3D avatar mesh (OBJ attachment) from body measurements.

    Errors are mapped exactly as in ``generate_avatar``.
    """
    try:
        obj_bytes = await _run_pipeline(_export_obj, measurements)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise _smpl_model_missing_error(e) from e
    except Exception as e:
        logger.exception("Error generating 3D avatar")
        raise HTTPException(status_code=500, detail=f"Error generating 3D avatar: {str(e)}") from e
    return Response(
        content=obj_bytes,
        media_type="model/obj",
        headers={"Content-Disposition": "attachment; filename=avatar.obj"},
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)