File size: 5,061 Bytes
f0ab5ad
 
 
 
ba0a6a3
f0ab5ad
 
 
 
 
5177c9a
f0ab5ad
 
ba0a6a3
f0ab5ad
 
5177c9a
f0ab5ad
ba0a6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0ab5ad
 
 
ba0a6a3
 
5177c9a
 
ba0a6a3
 
5177c9a
ba0a6a3
 
 
5177c9a
ba0a6a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5177c9a
 
 
 
 
ba0a6a3
 
 
 
 
5177c9a
ba0a6a3
 
 
 
 
 
 
 
5177c9a
 
 
ba0a6a3
 
 
 
 
 
5177c9a
 
ba0a6a3
5177c9a
ba0a6a3
 
5177c9a
ba0a6a3
 
 
5177c9a
ba0a6a3
 
 
 
5177c9a
ba0a6a3
 
 
 
 
 
614a445
5177c9a
 
614a445
 
ba0a6a3
5177c9a
 
ba0a6a3
 
5177c9a
ba0a6a3
 
5177c9a
 
ba0a6a3
5177c9a
 
 
ba0a6a3
5177c9a
ba0a6a3
 
 
 
5177c9a
 
 
ba0a6a3
 
 
5177c9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# app.py
import io
import os
from typing import List

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import numpy as np

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model definition ---
class AgeGenderClassifier(nn.Module):
    """Two-head classifier operating on 2048-d flattened VGG features.

    The shared trunk compresses 2048 -> 512 -> 128 -> 64; two sigmoid heads
    then emit a normalized age score and a gender probability, each of
    shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        # NOTE: the layer layout (positional nn.Sequential, these attribute
        # names) is frozen by the trained checkpoint's state_dict keys.
        trunk_layers = [
            nn.Linear(2048, 512), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(128, 64), nn.ReLU(),
        ]
        self.intermediate = nn.Sequential(*trunk_layers)
        self.age_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())
        self.gender_classifier = nn.Sequential(nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x):
        """Return ``(age, gender)`` predictions for input ``x`` (batch, 2048)."""
        shared = self.intermediate(x)
        return self.age_classifier(shared), self.gender_classifier(shared)


def build_model(weights_path: str):
    """Rebuild the VGG16 backbone + custom avgpool/classifier and load weights.

    Args:
        weights_path: Path to a ``.pth`` file holding either a bare
            ``state_dict`` or a training checkpoint dict with a
            ``"model_state_dict"`` entry.

    Returns:
        The assembled model, moved to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
    """
    # Fail fast: check the weights file BEFORE downloading/building the
    # heavyweight pretrained backbone.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")

    model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

    # Freeze the entire backbone, then unfreeze the last conv block
    # (features[24:]) so only it remains trainable — mirrors fine-tuning.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.features[24:].parameters():
        p.requires_grad = True

    # Replace avgpool so the 512x7x7 feature map becomes a 2048-d vector:
    # conv(k=3) -> 512x5x5, maxpool(2) -> 512x2x2, flatten -> 2048,
    # matching AgeGenderClassifier's first Linear(2048, 512).
    model.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten()
    )
    model.classifier = AgeGenderClassifier()

    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True is passed — only load trusted checkpoint files.
    state = torch.load(weights_path, map_location=device)
    # Unwrap a training-checkpoint dict explicitly instead of attempting a
    # load that is guaranteed to fail and catching the exception.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model


# --- Preprocessing ---
# ImageNet-style preprocessing: VGG16 expects 224x224 RGB inputs normalized
# with the standard ImageNet channel statistics.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Multiplier that maps the model's normalized age output in [0, 1] back to
# (approximate) years.
INV_AGE_SCALE = 80  # training used age/80 normalization


# --- Prediction function ---
def predict_images_with_text(images: List[Image.Image], model):
    """Predict age and gender for a batch of PIL images.

    Returns a pair ``(output_images, captions)``: the originals converted to
    numpy arrays, and one formatted prediction string per image. Empty input
    yields ``([], [])``.
    """
    if not images:
        return [], []

    # Preprocess the whole batch in one pass (forcing RGB where needed).
    batch = torch.stack(
        [transform(im if im.mode == "RGB" else im.convert("RGB")) for im in images]
    ).to(device)

    with torch.no_grad():
        age_out, gender_out = model(batch)
    age_scores = age_out.squeeze(-1).cpu().numpy()
    gender_scores = gender_out.squeeze(-1).cpu().numpy()

    output_images, captions = [], []
    for img, age_score, gender_score in zip(images, age_scores, gender_scores):
        # Map the normalized age score back to (approximate) years.
        age_val = int(np.clip(age_score, 0.0, 1.0) * INV_AGE_SCALE)
        is_female = gender_score > 0.5
        gender_label = "Female" if is_female else "Male"
        gender_emoji = "πŸ‘©" if is_female else "πŸ‘¨"
        # Confidence is the sigmoid probability of the chosen class.
        conf = float(gender_score if is_female else 1 - gender_score)

        output_images.append(np.array(img))
        captions.append(f"{gender_emoji} {gender_label} ({conf:.2f})  β€’  πŸŽ‚ Age β‰ˆ {age_val}")

    return output_images, captions


# --- Load model ---
# Weights path is overridable via the MODEL_PATH environment variable;
# build_model raises FileNotFoundError if the file is missing.
MODEL_WEIGHTS = os.environ.get("MODEL_PATH", "age_gender_model.pth")
model = build_model(MODEL_WEIGHTS)


# --- Gradio UI ---
with gr.Blocks(title="FairFace Age & Gender β€” Multi-image Demo") as demo:
    gr.Markdown("""
    # 🧠 FairFace Multi-task Age & Gender Predictor
    Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results below the image.
    """)

    with gr.Row():
        img_input = gr.File(file_count="multiple", label="Upload images")
        run_btn = gr.Button("Run  ▢️")

    # Gallery shows the uploaded images; predictions are rendered separately
    # as a single HTML block below rather than as per-image gallery captions.
    gallery = gr.Gallery(
        label="Uploaded Images",
        columns=3,
        height="auto"
    )

    captions = gr.HTML(label="Predictions")

    def run_and_predict(files):
        """Click handler: load the uploaded files as RGB PIL images, run the
        model, and return (gallery images, captions HTML)."""
        if not files:
            return [], ""
        pil_imgs = []
        for f in files:
            # gr.File may deliver plain path strings or tempfile-like
            # objects exposing the path via .name — handle both.
            path = f if isinstance(f, str) else f.name
            pil_imgs.append(Image.open(path).convert("RGB"))

        imgs, texts = predict_images_with_text(pil_imgs, model)
        captions_html = "<br>".join([f"<h2>{t}</h2>" for t in texts])
        return imgs, captions_html

    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery, captions])

    gr.Markdown("""
    ---
    **Tips & Notes**
    - Age is normalized to 0–80 years (approx.).
    - For best results, upload clear frontal face images.
    - This is a demo β€” respect privacy when using photos. πŸ™
    """)

if __name__ == "__main__":
    demo.launch()