# app.py
import io
import os
from typing import List

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import numpy as np

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Model definition ---
class AgeGenderClassifier(nn.Module):
    """Multi-task head mapping 2048-dim backbone features to age and gender.

    A shared MLP trunk (2048 -> 512 -> 128 -> 64) feeds two single-output
    heads, each squashed by a sigmoid into (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.intermediate = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        self.age_classifier = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )
        self.gender_classifier = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(age, gender)``, each with a trailing dimension of size 1.

        ``x`` is assumed to be (batch, 2048) flattened backbone features.
        """
        x = self.intermediate(x)
        age = self.age_classifier(x)
        gender = self.gender_classifier(x)
        return age, gender


def build_model(weights_path: str):
    """Rebuild the VGG16 backbone + custom avgpool/classifier and load weights.

    Args:
        weights_path: path to a file produced by ``torch.save`` that holds
            either a bare ``state_dict`` or a checkpoint dict containing it
            under the ``"model_state_dict"`` key.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: if ``weights_path`` does not exist.
        RuntimeError: if the stored parameters do not match this architecture
            (raised by the strict ``load_state_dict``).
    """
    # Check for the weights first so a bad path fails fast, before any model building.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights not found at {weights_path}")

    # No ImageNet weights: load_state_dict below is strict, so every backbone
    # parameter is overwritten by the checkpoint anyway. Skipping them avoids a
    # large download at startup (and makes offline startup possible).
    backbone = models.vgg16(weights=None)

    # freeze all layers
    for p in backbone.parameters():
        p.requires_grad = False
    # optionally allow last block to be trainable
    # (these flags do not affect inference, which runs under torch.no_grad())
    for p in backbone.features[24:].parameters():
        p.requires_grad = True

    # Replace avgpool. For a 224x224 input, VGG16 features give 512x7x7;
    # the unpadded 3x3 conv -> 512x5x5, MaxPool2d(2) -> 512x2x2, and
    # Flatten -> 2048, which is the input width AgeGenderClassifier expects.
    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
    )

    # attach classifier
    model = backbone
    model.classifier = AgeGenderClassifier()

    # NOTE(security): torch.load unpickles the file, so it can execute code.
    # Only point MODEL_PATH at trusted checkpoints.
    state = torch.load(weights_path, map_location=device)
    # Accept either a bare state_dict or a checkpoint wrapping it. Unwrapping
    # up front means a genuine mismatch surfaces as load_state_dict's own
    # error rather than a masking TypeError from a membership test.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)

    model.to(device)
    model.eval()
    return model


# --- Preprocessing ---
# Resize to the 224x224 input the avgpool replacement was sized for, then
# apply ImageNet mean/std normalization.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

INV_AGE_SCALE = 80  # training used age/80 normalization
# --- Prediction function ---
def predict_images_with_text(images: List[Image.Image], model, batch_size: int = 16):
    """Predict gender and age for each image and build one caption per image.

    Args:
        images: PIL images in any mode; non-RGB images are converted to RGB.
        model: network returned by ``build_model``. It is called on a
            normalized (n, 3, 224, 224) batch and must return ``(age, gender)``
            sigmoid scores.
        batch_size: maximum number of images sent through the model at once.
            This bounds peak memory when many images are uploaded.

    Returns:
        ``(output_images, captions)``: the RGB images as numpy arrays and one
        caption string per image, both in input order. Both lists are empty
        when ``images`` is empty.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if not images:
        return [], []
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Convert once and reuse the result for both inference and display.
    rgb_images = [im if im.mode == "RGB" else im.convert("RGB") for im in images]

    age_chunks = []
    gender_chunks = []
    with torch.no_grad():
        for start in range(0, len(rgb_images), batch_size):
            chunk = rgb_images[start:start + batch_size]
            batch = torch.stack([transform(im) for im in chunk]).to(device)
            pred_age, pred_gender = model(batch)
            # Each head ends in Linear(64, 1); flatten to one score per image.
            age_chunks.append(pred_age.reshape(-1).cpu().numpy())
            gender_chunks.append(pred_gender.reshape(-1).cpu().numpy())
    pred_ages = np.concatenate(age_chunks)
    pred_genders = np.concatenate(gender_chunks)

    output_images = []
    caption_texts = []
    for img, pa, pg in zip(rgb_images, pred_ages, pred_genders):
        # Training normalized age as age/80; undo that and round to the
        # nearest year (plain int() truncation biased estimates downward).
        age_val = round(float(np.clip(pa, 0.0, 1.0)) * INV_AGE_SCALE)
        # NOTE(review): assumes the gender head was trained with 1 = female;
        # confirm against the training labels.
        is_female = pg > 0.5
        gender_label = "Female" if is_female else "Male"
        gender_emoji = "👩" if is_female else "👨"
        conf = float(pg if is_female else 1 - pg)
        output_images.append(np.array(img))
        caption_texts.append(f"{gender_emoji} {gender_label} ({conf:.2f}) • 🎂 Age ≈ {age_val}")
    return output_images, caption_texts


# --- Load model ---
MODEL_WEIGHTS = os.environ.get("MODEL_PATH", "age_gender_model.pth")
model = build_model(MODEL_WEIGHTS)

# --- Gradio UI ---
with gr.Blocks(title="FairFace Age & Gender — Multi-image Demo") as demo:
    gr.Markdown("""
# 🧠 FairFace Multi-task Age & Gender Predictor
Upload **one or more** images (JPG/PNG).
The app will predict **gender** and **age** for each image and display results below the image.
""")

    with gr.Row():
        img_input = gr.File(file_count="multiple", label="Upload images")
        run_btn = gr.Button("Run ▶️")

    gallery = gr.Gallery(
        label="Uploaded Images",
        columns=3,
        height="auto",
    )
    captions = gr.HTML(label="Predictions")

    def run_and_predict(files):
        """Gradio callback: read the uploaded files and predict on each image.

        Args:
            files: uploaded files from ``gr.File``, given either as path
                strings or as objects exposing the path via ``.name``.

        Returns:
            ``(gallery_items, captions_html)``. ``gallery_items`` is a list of
            ``(image, caption)`` pairs, so each caption is shown under its
            image. ``captions_html`` lists every caption, numbered in the
            same order.

        Raises:
            gr.Error: if an uploaded file cannot be decoded as an image, so the
                user sees a message instead of a crashed handler.
        """
        if not files:
            return [], ""

        pil_imgs = []
        for f in files:
            path = f if isinstance(f, str) else f.name
            try:
                # Context manager closes the file handle; convert() returns a
                # fully-loaded copy that stays valid after the file is closed.
                with Image.open(path) as im:
                    pil_imgs.append(im.convert("RGB"))
            except (OSError, ValueError) as err:  # PIL's UnidentifiedImageError is an OSError
                raise gr.Error(f"Could not read '{os.path.basename(path)}' as an image.") from err

        imgs, texts = predict_images_with_text(pil_imgs, model)
        gallery_items = list(zip(imgs, texts))
        # Captions are model-generated (no user text), so no escaping is needed.
        captions_html = "<br>".join(
            f"<div><b>{i}.</b> {t}</div>" for i, t in enumerate(texts, start=1)
        )
        return gallery_items, captions_html

    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery, captions])

    gr.Markdown("""
---
**Tips & Notes**
- Age is normalized to 0–80 years (approx.).
- For best results, upload clear frontal face images.
- This is a demo — respect privacy when using photos. 🙏
""")

if __name__ == "__main__":
    demo.launch()