File size: 5,061 Bytes
f0ab5ad ba0a6a3 f0ab5ad 5177c9a f0ab5ad ba0a6a3 f0ab5ad 5177c9a f0ab5ad ba0a6a3 f0ab5ad ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 614a445 5177c9a 614a445 ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a ba0a6a3 5177c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# app.py
import io
import os
from typing import List
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import numpy as np
# Device selection: prefer CUDA when available; model and input batches
# are moved here before inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Model definition ---
class AgeGenderClassifier(nn.Module):
    """Two-headed classifier over 2048-dim backbone features.

    A shared MLP trunk reduces the feature vector to 64 dims; two
    single-unit sigmoid heads then emit a normalized age in [0, 1]
    and a gender probability.
    """

    def __init__(self):
        super().__init__()
        trunk = []
        # Shared trunk: 2048 -> 512 -> 128 with dropout, then 128 -> 64.
        for in_features, out_features in ((2048, 512), (512, 128)):
            trunk += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(0.4)]
        trunk += [nn.Linear(128, 64), nn.ReLU()]
        self.intermediate = nn.Sequential(*trunk)
        # One sigmoid unit per task head.
        self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
        self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x):
        shared = self.intermediate(x)
        return self.age_classifier(shared), self.gender_classifier(shared)
def build_model(weights_path: str):
    """Rebuild the VGG16 backbone + custom avgpool/classifier then load weights.

    Args:
        weights_path: Path to a ``torch.save``-ed state dict, or a training
            checkpoint dict containing a ``"model_state_dict"`` entry.

    Returns:
        The assembled model, moved to the module-level ``device`` and set
        to eval mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
    """
    # NOTE: downloads ImageNet weights on first use (needs network/cache).
    backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
    # Freeze the whole backbone, then re-enable the last conv block
    # (features[24:]) to mirror the fine-tuning setup used in training.
    for p in backbone.parameters():
        p.requires_grad = False
    for p in backbone.features[24:].parameters():
        p.requires_grad = True
    # Replace avgpool: 7x7x512 -> conv(k3) -> 5x5x512 -> maxpool(2) ->
    # 2x2x512 -> flatten = 2048 features, matching AgeGenderClassifier.
    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten()
    )
    model = backbone
    model.classifier = AgeGenderClassifier()
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source (consider weights_only=True on
    # torch >= 2.0 if the checkpoint format allows it).
    state = torch.load(weights_path, map_location=device)
    # Accept either a bare state dict or a checkpoint wrapping one. The
    # explicit check (instead of a broad try/except retry) avoids masking
    # genuine key/shape mismatch errors from load_state_dict.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
# --- Preprocessing ---
# Match the ImageNet-pretrained backbone's input pipeline: fixed 224x224
# resize, tensor conversion, ImageNet mean/std normalization.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
INV_AGE_SCALE = 80  # training used age/80 normalization; multiply to recover years
# --- Prediction function ---
def predict_images_with_text(images: List[Image.Image], model):
    """Run the model on a batch of PIL images.

    Returns a pair ``(arrays, captions)``: the original images as numpy
    arrays, and one formatted gender/age caption string per image.
    """
    if not images:
        return [], []
    # Force RGB before preprocessing; grayscale/RGBA inputs are converted.
    rgb = [im if im.mode == "RGB" else im.convert("RGB") for im in images]
    batch = torch.stack([transform(im) for im in rgb]).to(device)
    with torch.no_grad():
        age_out, gender_out = model(batch)
    age_scores = age_out.squeeze(-1).cpu().numpy()
    gender_scores = gender_out.squeeze(-1).cpu().numpy()
    arrays, texts = [], []
    for img, age_norm, female_prob in zip(images, age_scores, gender_scores):
        # Undo the age/INV_AGE_SCALE training normalization.
        years = int(np.clip(age_norm, 0.0, 1.0) * INV_AGE_SCALE)
        is_female = female_prob > 0.5
        label = "Female" if is_female else "Male"
        emoji = "π©" if is_female else "π¨"
        confidence = float(female_prob if is_female else 1 - female_prob)
        arrays.append(np.array(img))
        texts.append(f"{emoji} {label} ({confidence:.2f}) β’ π Age β {years}")
    return arrays, texts
# --- Load model ---
# Weights path is overridable via the MODEL_PATH env var; build_model
# raises FileNotFoundError at startup if the file is missing.
MODEL_WEIGHTS = os.environ.get("MODEL_PATH", "age_gender_model.pth")
model = build_model(MODEL_WEIGHTS)
# --- Gradio UI ---
# Layout: a multi-file upload + run button row, a gallery for the images,
# and an HTML panel for the per-image caption strings.
with gr.Blocks(title="FairFace Age & Gender β Multi-image Demo") as demo:
    gr.Markdown("""
# π§ FairFace Multi-task Age & Gender Predictor
Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results below the image.
""")
    with gr.Row():
        img_input = gr.File(file_count="multiple", label="Upload images")
        run_btn = gr.Button("Run βΆοΈ")
    gallery = gr.Gallery(
        label="Uploaded Images",
        columns=3,
        height="auto"
    )
    captions = gr.HTML(label="Predictions")

    def run_and_predict(files):
        """Open the uploaded files as RGB PIL images, predict, and format
        the captions as one <h2> line per image."""
        if not files:
            return [], ""
        pil_imgs = []
        for f in files:
            # gr.File may yield plain path strings or tempfile-like
            # objects exposing the path via .name.
            path = f if isinstance(f, str) else f.name
            pil_imgs.append(Image.open(path).convert("RGB"))
        imgs, texts = predict_images_with_text(pil_imgs, model)
        captions_html = "<br>".join([f"<h2>{t}</h2>" for t in texts])
        return imgs, captions_html

    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery, captions])
    gr.Markdown("""
---
**Tips & Notes**
- Age is normalized to 0β80 years (approx.).
- For best results, upload clear frontal face images.
- This is a demo β respect privacy when using photos. π
""")

if __name__ == "__main__":
    demo.launch()
|