File size: 6,023 Bytes
5d4d48e
 
 
6a9bf88
 
 
 
 
 
 
32d5b17
6a9bf88
 
 
 
 
 
1ed5e21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32d5b17
 
6a9bf88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854b743
6a9bf88
 
 
 
 
 
 
 
 
 
 
 
854b743
 
6a9bf88
 
 
 
 
 
 
 
 
854b743
 
6a9bf88
 
 
 
 
 
 
 
 
 
 
 
 
 
854b743
 
 
 
 
6a9bf88
 
854b743
6a9bf88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32d5b17
 
 
 
 
 
6a9bf88
 
 
 
854b743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a9bf88
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os

# Select PyOpenGL's OSMesa (headless) backend. This must be set before any
# module that imports PyOpenGL is loaded (presumably the local ``renderer``),
# so it stays ahead of the imports below.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'

import io
from pathlib import Path
from typing import Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field

from measurement_processor import process_measurements
from beta_regressor import predict_betas
from smpl_generator import generate_mesh
from renderer import render_avatar

SMPL_MODEL_PATH = os.getenv("SMPL_MODEL_PATH", "smpl")

# Candidate directories that may hold the SMPL *.pkl model files; used only
# for the startup diagnostics below. The directory derived from the configured
# SMPL_MODEL_PATH (the value the endpoints pass to generate_mesh) is checked
# first, so the log reflects the configured location when it exists.
# dict.fromkeys() drops the duplicate "smpl/models" entry that appears when
# SMPL_MODEL_PATH keeps its default value, while preserving order.
smpl_models_dirs = list(dict.fromkeys([
    Path(SMPL_MODEL_PATH) / "models",
    Path("smpl/smpl/models"),
    Path("smpl/models"),
]))
found_models = False
for models_dir in smpl_models_dirs:
    # is_dir() rather than exists(): a regular file named "models" must not
    # be reported as a model directory.
    if models_dir.is_dir():
        print(f"Found SMPL models in {models_dir}")
        # glob() order is filesystem-dependent; sort for a stable listing.
        for model_file in sorted(models_dir.glob("*.pkl")):
            print(f"  - {model_file.name}")
        found_models = True
        break

if not found_models:
    searched = ", ".join(str(candidate) for candidate in smpl_models_dirs)
    print(f"Warning: SMPL models not found in expected locations ({searched})")
    print(f"Configured SMPL_MODEL_PATH: {SMPL_MODEL_PATH}")

# The ASGI application; title and description feed FastAPI's OpenAPI metadata.
app = FastAPI(
    description="Generate 2D avatar images from body measurements using SMPL",
    title="Avatar Generation Service",
)


class MeasurementRequest(BaseModel):
    """Request body shared by the 2D and 3D avatar-generation endpoints.

    Lengths are in centimetres and weight in kilograms (per the field
    descriptions); every numeric field must be strictly positive (``gt=0``).
    The endpoints dump this model with ``exclude_none=True``, so optional
    fields left as ``None`` are not passed on to measurement processing.
    """

    # Pydantic v2 configuration, replacing the deprecated inner ``class Config``.
    # The example is attached to the model's generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "height": 178,
                "weight": 74,
                "chest": 96,
                "waist": 82,
                "hips": 94,
                "shoulder_width": 47,
                "arm_length": 60,
                "leg_length": 98,
                "inseam": 81,
                "gender": "male"
            }
        }
    )

    height: float = Field(..., gt=0, description="Height in cm")
    weight: float = Field(..., gt=0, description="Weight in kg")
    chest: float = Field(..., gt=0, description="Chest measurement in cm")
    waist: float = Field(..., gt=0, description="Waist measurement in cm")
    hips: float = Field(..., gt=0, description="Hips measurement in cm")
    shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
    arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
    leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
    inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
    # The allowed values are checked by the endpoints, not by this model (an
    # invalid value yields HTTP 400 there). An explicit null is dropped by
    # exclude_none and falls back to "male".
    gender: Optional[str] = Field("male", description="Gender: 'male', 'female', or 'neutral'")


@app.get("/")
async def root():
    """Return the service name and a short description of each endpoint."""
    endpoint_help = {
        "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
        "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
        "/health": "GET - Health check",
    }
    return {"service": "Avatar Generation Service", "endpoints": endpoint_help}


@app.get("/health")
async def health():
    """Health-check endpoint: returns a constant payload, checks no dependencies."""
    status_payload = {"status": "healthy"}
    return status_payload


def _to_uint8_image(img_np):
    """Convert a rendered image array to ``uint8`` pixels for Pillow.

    ``uint8`` input is returned unchanged. For any other dtype, an array whose
    maximum is <= 1.0 is treated as normalized intensities and scaled by 255;
    otherwise values are assumed to already be on a 0-255 scale. Values are
    rounded to the nearest integer and clipped to [0, 255], so out-of-range
    pixels saturate instead of wrapping around during the integer cast.
    """
    img_np = np.asarray(img_np)
    if img_np.dtype == np.uint8:
        return img_np
    pixels = img_np.astype(np.float64)
    if pixels.max() <= 1.0:
        pixels = pixels * 255.0
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)


@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Render a 2D avatar image (PNG) from body measurements.

    Pipeline: process_measurements -> predict_betas -> generate_mesh ->
    render_avatar, followed by PNG encoding of the rendered array.

    Raises:
        HTTPException(400): a ``ValueError`` was raised, either by the gender
            check below or by any step of the pipeline.
        HTTPException(500): the SMPL model files were not found
            (``FileNotFoundError``), or any other error occurred.
    """
    # NOTE(review): the pipeline below is synchronous and runs directly on the
    # event loop inside this ``async def``, so a slow render blocks other
    # requests. Consider offloading it to a worker thread once it is confirmed
    # that render_avatar is safe to call off the main thread.
    try:
        measurements_dict = measurements.model_dump(exclude_none=True)
        gender = measurements_dict.pop("gender", "male")

        if gender not in ["male", "female", "neutral"]:
            raise ValueError("Gender must be 'male', 'female', or 'neutral'")

        normalized = process_measurements(measurements_dict)
        betas = predict_betas(normalized)
        vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)
        img_np = _to_uint8_image(render_avatar(vertices, faces))

        # Let Pillow infer the mode from the array shape (HxWx3 -> RGB,
        # HxWx4 -> RGBA, HxW -> L). Forcing mode='RGB' garbles 4-channel
        # renders, and the ``mode`` argument is deprecated in recent Pillow.
        img = Image.fromarray(img_np)
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)

        return StreamingResponse(buf, media_type="image/png")

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail=f"SMPL model not found: {str(e)}. "
                   f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
                   f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable."
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}") from e


@app.post("/generate-avatar-3d")
async def generate_avatar_3d(measurements: MeasurementRequest):
    """Build a 3D avatar mesh from body measurements and return it as an OBJ file.

    Pipeline: process_measurements -> predict_betas -> generate_mesh, then
    the mesh is exported as OBJ and sent as an ``avatar.obj`` attachment.

    Raises:
        HTTPException(400): a ``ValueError`` was raised, either by the gender
            check below or by any step of the pipeline.
        HTTPException(500): the SMPL model files were not found
            (``FileNotFoundError``), or any other error occurred (including a
            missing ``trimesh`` installation).
    """
    # NOTE(review): this work is synchronous and runs on the event loop inside
    # an ``async def``; long mesh generation blocks other requests.
    try:
        measurements_dict = measurements.model_dump(exclude_none=True)
        gender = measurements_dict.pop("gender", "male")

        if gender not in ["male", "female", "neutral"]:
            raise ValueError("Gender must be 'male', 'female', or 'neutral'")

        normalized = process_measurements(measurements_dict)
        betas = predict_betas(normalized)
        vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)

        import trimesh  # imported lazily: only this endpoint uses it

        # process=False keeps vertices and faces exactly as generate_mesh
        # produced them. trimesh's default processing may merge or remove
        # vertices, which would renumber them in the exported OBJ.
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)

        buf = io.BytesIO()
        mesh.export(file_obj=buf, file_type='obj')
        buf.seek(0)

        return StreamingResponse(
            buf,
            media_type="model/obj",
            headers={"Content-Disposition": "attachment; filename=avatar.obj"}
        )

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail=f"SMPL model not found: {str(e)}. "
                   f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
                   f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable."
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating 3D avatar: {str(e)}") from e


if __name__ == "__main__":
    # Serve the app directly with uvicorn when this module runs as a script.
    from uvicorn import run as serve
    serve(app, port=7860, host="0.0.0.0")