# FaceGuard-demo / app.py
# (Hugging Face Space page header — commit: "Deploy FaceGuard demo with
#  proper README config", 94f3b39, verified, by hudaakram)
import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, AutoImageProcessor
MODEL_ID = "hudaakram/FaceGuard-20ID-ViT"

# Run on GPU when one is available; all tensors below live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned ViT classifier and its matching preprocessor once at
# import time, and freeze the model in eval mode for inference.
model = ViTForImageClassification.from_pretrained(MODEL_ID)
model = model.to(device)
model.eval()
processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# config.id2label arrives with string keys; normalize them to ints so the
# predictor can index with raw class indices (id2label maps idx -> celeb_id).
id2label = {int(idx): name for idx, name in model.config.id2label.items()}
def predict_top5(img: Image.Image, k: int = 5):
    """Return the top-k predicted identities for a face image.

    Args:
        img: Uploaded image in any PIL mode; Gradio passes ``None`` when the
            input is empty or cleared.
        k: Number of top classes to report. Defaults to 5 to match the
            ``gr.Label(num_top_classes=5)`` output widget.

    Returns:
        dict[str, float]: Readable label string -> class probability, in a
        shape ``gr.Label`` accepts. Empty dict when no image was provided.
    """
    if img is None:  # nothing uploaded — render an empty label widget
        return {}
    img = img.convert("RGB")
    enc = processor(images=img, return_tensors="pt").to(device)
    # Only the forward pass needs gradient tracking disabled; preprocessing
    # above is plain PIL/numpy work.
    with torch.no_grad():
        logits = model(pixel_values=enc["pixel_values"]).logits
    probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    topk = probs.argsort()[-k:][::-1]  # indices of the k highest probabilities
    return {f"label {i} (celeb_id {id2label[i]})": float(probs[i]) for i in topk}
desc = (
    "FaceGuard — Vision Transformer (google/vit-base-patch16-224) fine-tuned on a 20-identity subset of CelebA.\n"
    "Upload a face crop to see the Top-5 predicted identities. (IDs are CelebA celeb_id integers.)"
)

# NOTE(review): `allow_flagging` is deprecated in newer Gradio releases in
# favor of `flagging_mode` — confirm against the Gradio version this Space pins.
demo = gr.Interface(
    fn=predict_top5,
    title="FaceGuard – ViT (20 CelebA IDs)",
    description=desc,
    inputs=gr.Image(type="pil", label="Upload a face image"),
    outputs=gr.Label(num_top_classes=5, label="Top-5 identities"),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()