nexusbert commited on
Commit
854b743
·
1 Parent(s): 40ade6b
Files changed (1) hide show
  1. app.py +50 -3
app.py CHANGED
@@ -48,6 +48,7 @@ class MeasurementRequest(BaseModel):
48
  arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
49
  leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
50
  inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
 
51
 
52
  class Config:
53
  json_schema_extra = {
@@ -60,7 +61,8 @@ class MeasurementRequest(BaseModel):
60
  "shoulder_width": 47,
61
  "arm_length": 60,
62
  "leg_length": 98,
63
- "inseam": 81
 
64
  }
65
  }
66
 
@@ -70,7 +72,8 @@ async def root():
70
  return {
71
  "service": "Avatar Generation Service",
72
  "endpoints": {
73
- "/generate-avatar": "POST - Generate avatar from measurements",
 
74
  "/health": "GET - Health check"
75
  }
76
  }
@@ -85,9 +88,14 @@ async def health():
85
  async def generate_avatar(measurements: MeasurementRequest):
86
  try:
87
  measurements_dict = measurements.model_dump(exclude_none=True)
 
 
 
 
 
88
  normalized = process_measurements(measurements_dict)
89
  betas = predict_betas(normalized)
90
- vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH)
91
  img_np = render_avatar(vertices, faces)
92
 
93
  if img_np.dtype != np.uint8:
@@ -113,6 +121,45 @@ async def generate_avatar(measurements: MeasurementRequest):
113
  raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}")
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  if __name__ == "__main__":
117
  import uvicorn
118
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
48
  arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
49
  leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
50
  inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
51
+ gender: Optional[str] = Field("male", description="Gender: 'male' or 'female'")
52
 
53
  class Config:
54
  json_schema_extra = {
 
61
  "shoulder_width": 47,
62
  "arm_length": 60,
63
  "leg_length": 98,
64
+ "inseam": 81,
65
+ "gender": "male"
66
  }
67
  }
68
 
 
72
  return {
73
  "service": "Avatar Generation Service",
74
  "endpoints": {
75
+ "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
76
+ "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
77
  "/health": "GET - Health check"
78
  }
79
  }
 
88
  async def generate_avatar(measurements: MeasurementRequest):
89
  try:
90
  measurements_dict = measurements.model_dump(exclude_none=True)
91
+ gender = measurements_dict.pop("gender", "male")
92
+
93
+ if gender not in ["male", "female", "neutral"]:
94
+ raise ValueError("Gender must be 'male', 'female', or 'neutral'")
95
+
96
  normalized = process_measurements(measurements_dict)
97
  betas = predict_betas(normalized)
98
+ vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)
99
  img_np = render_avatar(vertices, faces)
100
 
101
  if img_np.dtype != np.uint8:
 
121
  raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}")
122
 
123
 
124
+ @app.post("/generate-avatar-3d")
125
+ async def generate_avatar_3d(measurements: MeasurementRequest):
126
+ try:
127
+ measurements_dict = measurements.model_dump(exclude_none=True)
128
+ gender = measurements_dict.pop("gender", "male")
129
+
130
+ if gender not in ["male", "female", "neutral"]:
131
+ raise ValueError("Gender must be 'male', 'female', or 'neutral'")
132
+
133
+ normalized = process_measurements(measurements_dict)
134
+ betas = predict_betas(normalized)
135
+ vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)
136
+
137
+ import trimesh
138
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
139
+
140
+ buf = io.BytesIO()
141
+ mesh.export(file_obj=buf, file_type='obj')
142
+ buf.seek(0)
143
+
144
+ return StreamingResponse(
145
+ buf,
146
+ media_type="model/obj",
147
+ headers={"Content-Disposition": "attachment; filename=avatar.obj"}
148
+ )
149
+
150
+ except ValueError as e:
151
+ raise HTTPException(status_code=400, detail=str(e))
152
+ except FileNotFoundError as e:
153
+ raise HTTPException(
154
+ status_code=500,
155
+ detail=f"SMPL model not found: {str(e)}. "
156
+ f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
157
+ f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable."
158
+ )
159
+ except Exception as e:
160
+ raise HTTPException(status_code=500, detail=f"Error generating 3D avatar: {str(e)}")
161
+
162
+
163
  if __name__ == "__main__":
164
  import uvicorn
165
  uvicorn.run(app, host="0.0.0.0", port=7860)