faranbutt789 committed on
Commit
ba0a6a3
Β·
verified Β·
1 Parent(s): f0ab5ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -50
app.py CHANGED
@@ -2,6 +2,7 @@
2
  import io
3
  import os
4
  from typing import List
 
5
  import gradio as gr
6
  import torch
7
  import torch.nn as nn
@@ -10,59 +11,208 @@ import torchvision.transforms as T
10
  from PIL import Image, ImageDraw, ImageFont
11
  import numpy as np
12
 
13
-
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
-
17
  class AgeGenderClassifier(nn.Module):
18
- def __init__(self):
19
- super(AgeGenderClassifier, self).__init__()
20
-
21
- self.intermediate = nn.Sequential(
22
- nn.Linear(2048, 512),
23
- nn.ReLU(),
24
- nn.Dropout(0.4),
25
- nn.Linear(512, 128),
26
- nn.ReLU(),
27
- nn.Dropout(0.4),
28
- nn.Linear(128, 64),
29
- nn.ReLU(),
30
- )
31
- self.age_classifier = nn.Sequential(
32
- nn.Linear(64, 1),
33
- nn.Sigmoid()
34
- )
35
- self.gender_classifier = nn.Sequential(
36
- nn.Linear(64, 1),
37
- nn.Sigmoid()
38
- )
39
-
40
-
41
- def forward(self, x):
42
- x = self.intermediate(x)
43
- age = self.age_classifier(x)
44
- gender = self.gender_classifier(x)
45
- return age, gender
46
-
47
-
48
 
49
 
50
  def build_model(weights_path: str):
51
- """Rebuild VGG16 backbone + custom avgpool/classifier then load weights."""
52
- backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
53
-
54
- for p in backbone.parameters():
55
- p.requires_grad = False
56
-
57
-
58
-
59
- for p in backbone.features[24:].parameters():
60
- p.requires_grad = True
61
-
62
-
63
- backbone.avgpool = nn.Sequential(
64
- nn.Conv2d(512, 512, kernel_size=3),
65
- nn.MaxPool2d(2),
66
- nn.ReLU(),
67
- nn.Flatten()
68
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import io
3
  import os
4
  from typing import List
5
+
6
  import gradio as gr
7
  import torch
8
  import torch.nn as nn
 
11
  from PIL import Image, ImageDraw, ImageFont
12
  import numpy as np
13
 
14
+ # Device
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
# --- Model definition (must match your saved state) ---
class AgeGenderClassifier(nn.Module):
    """Two-headed MLP: a shared trunk feeding separate age and gender heads.

    Expects a flattened 2048-dim feature vector (the output of the custom
    VGG16 avgpool block used during training). Both heads end in a sigmoid,
    so outputs lie in (0, 1): age is normalized (age / 80 at training time),
    gender is read downstream as P(female).
    NOTE(review): the P(female) reading matches the `pg > 0.5 -> "Female"`
    decode in this file — confirm against the training label convention.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: 2048 -> 512 -> 128 -> 64, dropout for regularization.
        self.intermediate = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        # One scalar per task, squashed to (0, 1).
        self.age_classifier = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )
        self.gender_classifier = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return (age, gender) predictions, each of shape (N, 1)."""
        x = self.intermediate(x)
        age = self.age_classifier(x)
        gender = self.gender_classifier(x)
        return age, gender
 
 
 
46
 
47
 
48
def build_model(weights_path: str):
    """Rebuild VGG16 backbone + custom avgpool/classifier then load weights.

    The architecture must match the training run exactly: VGG16 features,
    then a conv->maxpool->relu->flatten "avgpool" producing 2048 features,
    then the AgeGenderClassifier heads.

    Args:
        weights_path: path to a saved ``state_dict``, or a checkpoint dict
            containing a ``"model_state_dict"`` key.

    Returns:
        The assembled model, moved to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist.
    """
    backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

    # Freeze everything, then re-enable the last conv block — mirrors the
    # training script. requires_grad is irrelevant for pure inference but is
    # kept so this stays a faithful reconstruction of the trained model.
    for p in backbone.parameters():
        p.requires_grad = False
    for p in backbone.features[24:].parameters():
        p.requires_grad = True

    # Replace avgpool with the block used during training
    # (conv -> maxpool -> relu -> flatten) so the classifier sees 2048 dims.
    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
    )

    # Attach the two-headed classifier.
    model = backbone
    model.classifier = AgeGenderClassifier()

    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")

    # NOTE(review): newer torch defaults torch.load to weights_only=True;
    # plain tensor state_dicts load fine either way.
    state = torch.load(weights_path, map_location=device)
    # Accept either a raw state_dict or a common checkpoint wrapper.
    try:
        model.load_state_dict(state)
    except Exception:
        # Guard the membership test: `"key" in state` raises TypeError when
        # the loaded object is not a Mapping (e.g. a pickled full model).
        if isinstance(state, dict) and "model_state_dict" in state:
            model.load_state_dict(state["model_state_dict"])
        else:
            raise

    model.to(device)
    model.eval()
    return model
89
+
90
+
91
# --- Preprocessing ---
# Standard ImageNet normalization at VGG16's expected 224x224 input size.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training normalized ages as age / 80, so multiply back to get years.
INV_AGE_SCALE = 80  # training used age/80 normalization
99
+
100
+
101
def draw_caption_on_image(pil_img: Image.Image, caption: str):
    """Draw caption at the top of the image with a semi-transparent background.

    Args:
        pil_img: source image (any mode; converted internally).
        caption: single-line text to render.

    Returns:
        A new RGB image with a dark banner and the caption composited on top.
    """
    img = pil_img.convert("RGBA")
    overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)

    # Font size scales with image width; fall back to PIL's builtin font
    # when the TTF is not installed on the host.
    fontsize = max(14, img.width // 20)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", fontsize)
    except Exception:
        font = ImageFont.load_default()

    # ImageDraw.textsize() was removed in Pillow 10; prefer textbbox() and
    # keep textsize() only as a fallback for older Pillow versions.
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), caption, font=font)
        text_h = bottom - top
    else:
        _, text_h = draw.textsize(caption, font=font)
    padding = 8
    rect_h = text_h + padding * 2

    # Translucent banner behind the text for readability on any background.
    draw.rectangle([(0, 0), (img.width, rect_h)], fill=(0, 0, 0, 160))
    draw.text((padding, padding), caption, font=font, fill=(255, 255, 255, 255))

    out = Image.alpha_composite(img, overlay).convert("RGB")
    return out
125
+
126
+
127
# --- Prediction function for multiple images ---

def predict_images(images: List[Image.Image], model) -> List[Image.Image]:
    """Takes a list of PIL images and returns them annotated with predictions.

    Empty or ``None`` input yields an empty list. All images are batched
    into a single forward pass. The gender head is decoded as
    P(female) > 0.5 -> "Female"; age is de-normalized via ``INV_AGE_SCALE``.
    """
    if not images:
        return []

    # Preprocess every image (forcing RGB) into one batch tensor.
    tensors = []
    for im in images:
        if im.mode != "RGB":
            im = im.convert("RGB")
        tensors.append(transform(im))
    batch = torch.stack(tensors).to(device)

    with torch.no_grad():
        pred_age, pred_gender = model(batch)
        # (N, 1) -> (N,) numpy arrays on the CPU for easy iteration.
        pred_age = pred_age.squeeze(-1).cpu().numpy()
        pred_gender = pred_gender.squeeze(-1).cpu().numpy()

    outputs = []
    for img, pa, pg in zip(images, pred_age, pred_gender):
        # Clamp the normalized age into [0, 1] before scaling to years.
        age_val = int(np.clip(pa, 0.0, 1.0) * INV_AGE_SCALE)
        is_female = pg > 0.5
        gender_label = "Female" if is_female else "Male"
        gender_emoji = "👩" if is_female else "👨"
        # Confidence is the probability of the chosen class.
        conf = float(pg if is_female else 1 - pg)

        caption = f"{gender_emoji} {gender_label} ({conf:.2f}) • 🎂 Age ≈ {age_val}"
        outputs.append(draw_caption_on_image(img, caption))

    return outputs
162
+
163
+
164
# --- Load model once on startup ---
MODEL_WEIGHTS = os.environ.get("MODEL_PATH", "age_gender_model.pth")
model = build_model(MODEL_WEIGHTS)

# --- Gradio UI ---
with gr.Blocks(title="FairFace Age & Gender — Multi-image Demo") as demo:
    gr.Markdown("""
    # 🧠 FairFace Multi-task Age & Gender Predictor
    Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results right on the picture.

    **How to use**
    1. Click **Browse** or drag & drop multiple images. ✅
    2. Click **Run**. The model processes images and shows results below. ⚡
    3. Use the download button on the output images if you want to save them.

    *Note:* Age is estimated (approx.). This model was trained on the FairFace dataset.
    """)

    with gr.Row():
        img_input = gr.File(file_count="multiple", label="Upload images")
        run_btn = gr.Button("Run ▶️")

    # Gradio 4 removed Component.style(); pass layout options directly
    # instead of .style(grid=[3], height="auto").
    gallery = gr.Gallery(label="Predictions", show_label=True,
                         elem_id="gallery", columns=3, height="auto")

    def run_and_predict(files):
        """Normalize the File component's output to PIL images, then predict."""
        if not files:
            return []

        pil_imgs = []
        for f in files:
            # Depending on the Gradio version/host, f may be a dict with raw
            # bytes, a plain path string, or a tempfile wrapper with .name.
            if isinstance(f, dict) and "name" in f and "data" in f:
                im = Image.open(io.BytesIO(f["data"]))
            else:
                path = f if isinstance(f, str) else f.name
                im = Image.open(path)
            pil_imgs.append(im.convert("RGB"))

        return predict_images(pil_imgs, model)

    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery])

    gr.Markdown("""
    ---
    **Tips & Notes**
    - The model outputs age normalized to 0–80 years (approx).
    - If results look odd, try a clearer, frontal face image.
    - This demo is for research / demo purposes only — be mindful of privacy. 🙏
    """)

if __name__ == "__main__":
    demo.launch()