|
|
---
license: mit
tags:
- classification
- image-classification
- face-classification
- age-prediction
- computer-vision
- swin-transformer
- race-prediction
- gender-prediction
---
|
|
|
|
|
|
|
|
## AgeRaceGenderNet |
|
|
<img src="https://cdn-uploads.huggingface.co/production/uploads/663a886250daed366b657df4/7-yshHOeyZ4DzEDsjVuMA.png" alt="AgeRaceGenderNet" width="800"/> |
|
|
|
|
|
|
|
**AgeRaceGenderNet** is a lightweight multi-task face classification model that predicts **age**, **gender**, and **race** from a facial image. It is built on a **Swin Transformer V2 (Tiny)** backbone and designed for **fast inference** (\~5.94 GFLOPs).
|
|
|
|
|
--- |
|
|
|
|
|
### Tasks & Outputs |
|
|
|
|
|
This model simultaneously predicts: |
|
|
|
|
|
* **\[age]**: integer from `0` to `116`, the estimated age in years.
* **\[gender]**: `0` for **male**, `1` for **female**.
* **\[race]**: integer from `0` to `4`:
  * `0`: White
  * `1`: Black
  * `2`: Asian
  * `3`: Indian
  * `4`: Others (e.g., Hispanic, Latino, Middle Eastern)
|
|
|
|
|
--- |
|
|
|
|
|
### Model Architecture |
|
|
|
|
|
* **Backbone**: [Swin V2 Tiny](https://arxiv.org/abs/2111.09883), pretrained and fine-tuned
* **Head**: multi-task design with a dedicated classification head for each demographic attribute
* **Criterion**: custom `MultiTaskLoss` function
* **Total Parameters**: **28.4M**
  * Trainable: 25.7M
  * Non-trainable: 2.7M
* **Model Size**: \~113.7 MB
* **Inference Cost**: \~5.94 GFLOPs
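The exact layout of the classification heads is not published with this card; as a rough sketch, a multi-task head over the shared Swin V2 Tiny feature vector (768-dim after pooling) could attach one linear classifier per attribute. The class name `MultiTaskHead` and the `feat_dim` default below are illustrative assumptions, not the released implementation:

```python
import torch
import torch.nn as nn

class MultiTaskHead(nn.Module):
    """Hypothetical multi-task head: one linear classifier per attribute
    on top of a shared pooled backbone feature vector."""
    def __init__(self, feat_dim=768, num_ages=117, num_genders=2, num_races=5):
        super().__init__()
        self.age_head = nn.Linear(feat_dim, num_ages)
        self.gender_head = nn.Linear(feat_dim, num_genders)
        self.race_head = nn.Linear(feat_dim, num_races)

    def forward(self, features):
        # One forward pass of the backbone feeds all three heads.
        return (self.age_head(features),
                self.gender_head(features),
                self.race_head(features))

head = MultiTaskHead()
feats = torch.randn(1, 768)  # stand-in for pooled Swin V2 Tiny features
age_logits, gender_logits, race_logits = head(feats)
print(age_logits.shape, gender_logits.shape, race_logits.shape)
```

Sharing one backbone across all three heads is what keeps the parameter count and FLOPs close to those of a single-task Swin V2 Tiny classifier.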
|
|
|
|
|
--- |
|
|
|
|
|
### Training Dataset |
|
|
|
|
|
The model was trained on the **[UTKFace dataset](https://susanqq.github.io/UTKFace/)**, a large-scale face dataset annotated with age (0 to 116), gender, and race.
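UTKFace encodes its labels directly in each image's filename as `[age]_[gender]_[race]_[date&time].jpg`, using the same age/gender/race conventions as this model's outputs, so ground-truth labels can be recovered with a small parser:

```python
def parse_utkface_name(filename):
    """Parse (age, gender, race) from a UTKFace filename.

    UTKFace names follow the pattern [age]_[gender]_[race]_[date&time].jpg,
    e.g. '25_0_2_20170116174525125.jpg' -> age 25, male (0), Asian (2).
    """
    age, gender, race = filename.split("_")[:3]
    return int(age), int(gender), int(race)

print(parse_utkface_name("25_0_2_20170116174525125.jpg"))  # (25, 0, 2)
```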
|
|
|
|
|
--- |
|
|
|
|
|
|
|
|
|
|
|
### Usage example |
|
|
**Note:** the input image is assumed to be cropped and aligned to contain only the face.
|
|
```python
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms

# Load the ONNX model
session = ort.InferenceSession("AgeRaceGenderNet_v1.onnx")

# Get input and output names
input_name = session.get_inputs()[0].name
output_names = [out.name for out in session.get_outputs()]

# Preprocessing: resize, convert to tensor, normalize with ImageNet statistics.
# NOTE: the input image is assumed to be cropped and aligned to contain only the face.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # model input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open("path_to_image.jpg").convert("RGB")
img_tensor = transform(img).unsqueeze(0).numpy()  # shape: (1, 3, 256, 256)

# Run inference
age_logits, gender_logits, race_logits = session.run(output_names, {input_name: img_tensor})

# Postprocessing: take the argmax of each head's logits
age_pred = int(np.argmax(age_logits, axis=1)[0])
gender_pred = int(np.argmax(gender_logits, axis=1)[0])
race_pred = int(np.argmax(race_logits, axis=1)[0])

# Convert predictions to human-readable labels
def get_gender_text(gender_idx):
    return "Male" if gender_idx == 0 else "Female"

def get_race_text(race_idx):
    race_map = {0: "White", 1: "Black", 2: "Asian", 3: "Indian", 4: "Other"}
    return race_map.get(race_idx, "Unknown")

# Display results
print(f"Predicted Age: {age_pred}")
print(f"Predicted Gender: {get_gender_text(gender_pred)}")
print(f"Predicted Race: {get_race_text(race_pred)}")
```
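The `argmax` calls above return only the top class; when a confidence score is also useful, the logits can be converted to probabilities with a softmax. A minimal NumPy sketch (the `gender_logits` values below are made-up stand-ins for real model output):

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax over the given axis."""
    z = logits - logits.max(axis=axis, keepdims=True)  # shift for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Stand-in logits; in practice use the arrays returned by session.run(...)
gender_logits = np.array([[2.0, 0.5]])
probs = softmax(gender_logits)
gender_pred = int(np.argmax(probs, axis=1)[0])
confidence = float(probs[0, gender_pred])
print(gender_pred, round(confidence, 3))  # top class 0 ("Male"), confidence ~0.82
```

The same conversion applies to the age and race logits; note that softmax confidence is not calibrated probability and should be treated as a relative score only.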