---
license: mit
tags:
- classification
- image-classification
- face-classification
- age-prediction
- computer-vision
- swin-transformer
- race-prediction
- gender-prediction
---

## AgeRaceGenderNet

**AgeRaceGenderNet** is a lightweight, multi-task face classification model capable of predicting **age**, **gender**, and **race** from facial images. It is built on a **Swin Transformer V2 (Tiny)** backbone and designed for **fast inference** (~5.94 GFLOPs).

---

### Tasks & Outputs

This model simultaneously predicts:

* **age**: Integer from `0` to `116`, representing the estimated age.
* **gender**: `0` for **male**, `1` for **female**.
* **race**: Integer from `0` to `4`, representing:
  * `0`: White
  * `1`: Black
  * `2`: Asian
  * `3`: Indian
  * `4`: Others (e.g., Hispanic, Latino, Middle Eastern)

---

### Model Architecture

* **Backbone**: [Swin V2 Tiny](https://arxiv.org/abs/2111.09883) (pretrained and fine-tuned)
* **Head**: Multi-task architecture with a dedicated classification head for each demographic task.
* **Criterion**: Custom `MultiTaskLoss` function
* **Total Parameters**: **28.4M**
  * **Trainable**: 25.7M
  * **Non-trainable**: 2.7M
* **Model Size**: ~113.7 MB
* **Inference Cost**: ~5.94 GFLOPs

---

### Training Dataset

The model is trained on the **[UTKFace Dataset](https://susanqq.github.io/UTKFace/)**.

---

### Usage Example

**NOTE**: The input image is assumed to be cropped and aligned so that it contains only the face.

```python
import onnxruntime as ort
from PIL import Image
from torchvision import transforms
import numpy as np

# Load ONNX model
session = ort.InferenceSession("AgeRaceGenderNet_v1.onnx")

# Get input and output names
input_name = session.get_inputs()[0].name
output_names = [out.name for out in session.get_outputs()]

# Preprocessing: load and transform the image
# NOTE: The input image is assumed to be cropped and aligned to contain only the face
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to model input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

img = Image.open("path_to_image.jpg").convert("RGB")
img_tensor = transform(img).unsqueeze(0).numpy()  # Shape: (1, 3, 256, 256)

# Run inference
age_logits, gender_logits, race_logits = session.run(output_names, {input_name: img_tensor})

# Postprocessing: get predicted class indices
age_pred = int(np.argmax(age_logits, axis=1)[0])
gender_pred = int(np.argmax(gender_logits, axis=1)[0])
race_pred = int(np.argmax(race_logits, axis=1)[0])

# Convert predictions to labels
def get_gender_text(gender_idx):
    return 'Male' if gender_idx == 0 else 'Female'

def get_race_text(race_idx):
    race_map = {
        0: 'White',
        1: 'Black',
        2: 'Asian',
        3: 'Indian',
        4: 'Other'
    }
    return race_map.get(race_idx, 'Unknown')

# Display results
print(f"Predicted Age: {age_pred}")
print(f"Predicted Gender: {get_gender_text(gender_pred)}")
print(f"Predicted Race: {get_race_text(race_pred)}")
```
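The `session.run` call above returns raw logits for each task. If you also want a confidence score alongside each prediction, you can apply a softmax to the logits before taking the argmax. Below is a minimal NumPy sketch using hypothetical logit values in place of the model's actual outputs; the same function applies unchanged to the real `gender_logits` or `race_logits` arrays:

```python
import numpy as np

def softmax(logits):
    # Subtract the per-row max before exponentiating, for numerical stability
    e = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical gender logits for a batch of one image (2 classes: male, female)
gender_logits = np.array([[2.0, 0.5]])
gender_probs = softmax(gender_logits)

gender_pred = int(np.argmax(gender_probs, axis=1)[0])
gender_conf = float(gender_probs[0, gender_pred])

print(f"Gender index: {gender_pred}, confidence: {gender_conf:.2f}")
```

Note that softmax does not change the argmax, so the predicted indices are identical either way; the probabilities are only useful for thresholding or reporting confidence.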