# app.py
import io
import os
from typing import List
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import numpy as np
# Device
# Use the GPU when available; the model and all input batches below are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Model definition ---
class AgeGenderClassifier(nn.Module):
    """Two-headed classifier: normalized age and gender probability.

    Consumes a flat 2048-dim feature vector and returns a tuple
    ``(age, gender)``, each of shape ``(batch, 1)`` squashed into
    ``[0, 1]`` by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: 2048 -> 512 -> 128 -> 64, with dropout for regularization.
        self.intermediate = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        # One single-unit sigmoid head per task.
        self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
        self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x):
        features = self.intermediate(x)
        return self.age_classifier(features), self.gender_classifier(features)
def build_model(weights_path: str):
    """Rebuild the VGG16 backbone + custom avgpool/classifier and load weights.

    The backbone is an ImageNet-pretrained VGG16; its ``avgpool`` is replaced
    by a small conv/pool/flatten stack and its ``classifier`` is swapped for
    :class:`AgeGenderClassifier`.

    Args:
        weights_path: Path to either a bare ``state_dict`` file or a training
            checkpoint dict containing a ``"model_state_dict"`` entry.

    Returns:
        The model, moved to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
    """
    backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

    # Freeze everything, then unfreeze the last conv block for fine-tuning.
    for p in backbone.parameters():
        p.requires_grad = False
    for p in backbone.features[24:].parameters():
        p.requires_grad = True

    # Replace avgpool so the feature map flattens to the 2048 features
    # expected by AgeGenderClassifier's first Linear layer.
    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
    )
    model = backbone
    model.classifier = AgeGenderClassifier()

    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load weight
    # files from a trusted source.
    state = torch.load(weights_path, map_location=device)
    # Accept either a bare state_dict or a checkpoint dict wrapping one.
    # Checking the key explicitly (instead of try/except Exception around
    # load_state_dict) avoids masking genuine load errors such as key or
    # shape mismatches.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model
# --- Preprocessing ---
# Standard ImageNet pipeline: resize to VGG16's 224x224 input, convert to a
# tensor, and normalize with the ImageNet channel mean/std.
transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
INV_AGE_SCALE = 80 # training used age/80 normalization
# --- Prediction function ---
def predict_images_with_text(images: List[Image.Image], model):
    """Predict age and gender for each image and build display captions.

    Returns a pair ``(output_images, captions)``: the original images as
    numpy arrays, and one caption string per image with the predicted
    gender (plus confidence) and approximate age.
    """
    if not images:
        return [], []

    # Ensure 3-channel input, preprocess, and batch for a single forward pass.
    rgb = [im if im.mode == "RGB" else im.convert("RGB") for im in images]
    batch = torch.stack([transform(im) for im in rgb]).to(device)

    with torch.no_grad():
        pred_age, pred_gender = model(batch)
    ages = pred_age.squeeze(-1).cpu().numpy()
    genders = pred_gender.squeeze(-1).cpu().numpy()

    output_images = []
    captions = []
    for img, age_score, gender_score in zip(images, ages, genders):
        is_female = gender_score > 0.5
        # Training normalized ages by dividing by 80; clip then invert that.
        age_val = int(np.clip(age_score, 0.0, 1.0) * INV_AGE_SCALE)
        gender_label = "Female" if is_female else "Male"
        gender_emoji = "👩" if is_female else "👨"
        conf = float(gender_score if is_female else 1 - gender_score)
        output_images.append(np.array(img))
        captions.append(f"{gender_emoji} {gender_label} ({conf:.2f}) • 🎂 Age ≈ {age_val}")
    return output_images, captions
# --- Load model ---
# Weights path is overridable via the MODEL_PATH environment variable;
# the model is loaded once at import time and shared by all requests.
MODEL_WEIGHTS = os.environ.get("MODEL_PATH", "age_gender_model.pth")
model = build_model(MODEL_WEIGHTS)
# --- Gradio UI ---
with gr.Blocks(title="FairFace Age & Gender — Multi-image Demo") as demo:
gr.Markdown("""
# 🧠 FairFace Multi-task Age & Gender Predictor
Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results below the image.
""")
with gr.Row():
img_input = gr.File(file_count="multiple", label="Upload images")
run_btn = gr.Button("Run ▶️")
gallery = gr.Gallery(
label="Uploaded Images",
columns=3,
height="auto"
)
captions = gr.HTML(label="Predictions")
def run_and_predict(files):
if not files:
return [], ""
pil_imgs = []
for f in files:
path = f if isinstance(f, str) else f.name
pil_imgs.append(Image.open(path).convert("RGB"))
imgs, texts = predict_images_with_text(pil_imgs, model)
captions_html = "
".join([f"