|
|
|
|
|
import io |
|
|
import os |
|
|
from typing import List |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
import torchvision.transforms as T |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
class AgeGenderClassifier(nn.Module):
    """Two-headed classifier for joint age and gender prediction.

    Consumes a 2048-d feature vector, passes it through a shared trunk
    (2048 -> 512 -> 128 -> 64 with dropout), then emits one sigmoid value
    per task: normalized age and gender probability, each in [0, 1].

    NOTE: attribute names (`intermediate`, `age_classifier`,
    `gender_classifier`) and Sequential layer order are part of the saved
    state_dict layout and must not change.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk layers, listed flat so Sequential indices stay stable.
        trunk = [
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(128, 64), nn.ReLU(),
        ]
        self.intermediate = nn.Sequential(*trunk)
        # Per-task heads: a single sigmoid-squashed scalar each.
        self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
        self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x):
        """Return (age, gender) predictions, each of shape (batch, 1)."""
        shared = self.intermediate(x)
        return self.age_classifier(shared), self.gender_classifier(shared)
|
|
|
|
|
|
|
|
def build_model(weights_path: str):
    """Rebuild the VGG16-based age/gender model and load trained weights.

    Architecture: an ImageNet-pretrained VGG16 backbone with only its last
    conv block left trainable, a custom conv+pool+flatten stage in place of
    the stock avgpool, and AgeGenderClassifier as the classifier head.

    Args:
        weights_path: Path to a checkpoint saved either as a raw state_dict
            or as a training checkpoint dict with a "model_state_dict" key.

    Returns:
        The model moved to the global `device` and set to eval mode.

    Raises:
        FileNotFoundError: If `weights_path` does not exist.
        RuntimeError: If the checkpoint does not match the architecture.
    """
    backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

    # Freeze everything, then unfreeze only the last conv block (layers 24+)
    # to mirror the fine-tuning setup used at training time.
    for p in backbone.parameters():
        p.requires_grad = False
    for p in backbone.features[24:].parameters():
        p.requires_grad = True

    # Replace the stock avgpool with a conv + maxpool + flatten stage that
    # reduces the final feature map to the 2048-d vector the classifier expects.
    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten()
    )

    model = backbone
    model.classifier = AgeGenderClassifier()

    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")

    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state = torch.load(weights_path, map_location=device)
    # Support both raw state_dicts and full training checkpoints. Checking
    # the structure up front replaces the previous try/except-Exception
    # retry, which could raise a confusing secondary TypeError when `state`
    # was not a dict, and masked unrelated load errors.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
|
|
# Per-channel normalization statistics matching the ImageNet-pretrained
# VGG16 backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to the 224x224 input VGG16 expects,
# convert to a float tensor, then normalize per channel.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# The age head emits a sigmoid value in [0, 1]; multiply by this factor to
# map back to years (training normalized ages by 80).
INV_AGE_SCALE = 80
|
|
|
|
|
|
|
|
|
|
|
def predict_images_with_text(images: List[Image.Image], model):
    """Run the model on a list of PIL images and format captions.

    Args:
        images: PIL images in any mode; non-RGB inputs are converted first.
        model: Callable returning (age, gender) sigmoid outputs of shape (N, 1).

    Returns:
        A pair (output_images, captions): the originals as numpy arrays and
        one formatted prediction string per image. Both empty when `images` is.
    """
    if not images:
        return [], []

    # Force RGB, preprocess, and stack everything into one batch tensor.
    batch = torch.stack([
        transform(im if im.mode == "RGB" else im.convert("RGB"))
        for im in images
    ]).to(device)

    with torch.no_grad():
        age_out, gender_out = model(batch)
        ages = age_out.squeeze(-1).cpu().numpy()
        genders = gender_out.squeeze(-1).cpu().numpy()

    output_images = []
    captions = []
    for img, age_score, gender_score in zip(images, ages, genders):
        # Sigmoid score maps linearly onto the 0-80 year range.
        age_val = int(np.clip(age_score, 0.0, 1.0) * INV_AGE_SCALE)
        is_female = gender_score > 0.5
        gender_label = "Female" if is_female else "Male"
        gender_emoji = "π©" if is_female else "π¨"
        # Confidence is the sigmoid score's distance from the losing class.
        conf = float(gender_score if is_female else 1 - gender_score)

        output_images.append(np.array(img))
        captions.append(f"{gender_emoji} {gender_label} ({conf:.2f}) β’ π Age β {age_val}")

    return output_images, captions
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint location, overridable via the MODEL_PATH environment variable.
MODEL_WEIGHTS = os.getenv("MODEL_PATH", "age_gender_model.pth")
# Build the model once at import time so all Gradio callbacks reuse it.
model = build_model(MODEL_WEIGHTS)
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI. Components are laid out in creation order inside the Blocks
# context, so the sequence below defines the page structure.
# ---------------------------------------------------------------------------
with gr.Blocks(title="FairFace Age & Gender β Multi-image Demo") as demo:
    gr.Markdown("""
    # π§ FairFace Multi-task Age & Gender Predictor
    Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results below the image.
    """)

    # File picker accepting multiple uploads, plus a button to run inference.
    with gr.Row():
        img_input = gr.File(file_count="multiple", label="Upload images")
        run_btn = gr.Button("Run βΆοΈ")

    # Gallery shows the uploaded images; predictions render as HTML below it.
    gallery = gr.Gallery(
        label="Uploaded Images",
        columns=3,
        height="auto"
    )

    captions = gr.HTML(label="Predictions")

    def run_and_predict(files):
        # `files` entries may be plain path strings or uploaded file objects
        # (which carry a `.name` path), depending on the Gradio version.
        if not files:
            return [], ""
        pil_imgs = []
        for f in files:
            path = f if isinstance(f, str) else f.name
            pil_imgs.append(Image.open(path).convert("RGB"))

        imgs, texts = predict_images_with_text(pil_imgs, model)
        # One <h2> line per image, in the same order as the gallery items.
        captions_html = "<br>".join([f"<h2>{t}</h2>" for t in texts])
        return imgs, captions_html

    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery, captions])

    gr.Markdown("""
    ---
    **Tips & Notes**
    - Age is normalized to 0β80 years (approx.).
    - For best results, upload clear frontal face images.
    - This is a demo β respect privacy when using photos. π
    """)

# Launch the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|