"""FaceGuard demo: Gradio app serving a ViT fine-tuned on 20 CelebA identities."""

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, AutoImageProcessor

# Hugging Face checkpoint: google/vit-base-patch16-224 fine-tuned on a
# 20-identity subset of CelebA.
MODEL_ID = "hudaakram/FaceGuard-20ID-ViT"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTForImageClassification.from_pretrained(MODEL_ID).to(device).eval()
processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# Readable class names from config (id2label maps idx -> celeb_id string).
# Keys arrive as strings from the JSON config, so normalize them to int.
id2label = {int(k): v for k, v in model.config.id2label.items()}


def predict_top5(img: Image.Image) -> dict:
    """Return the top-5 predicted identities for an uploaded face image.

    Args:
        img: PIL image from the Gradio input, or None when the input is
            cleared (Gradio passes None in that case).

    Returns:
        A mapping of "label <idx> (celeb_id <id>)" -> probability (float),
        in the shape gr.Label expects. Empty dict when no image is given.
    """
    if img is None:
        return {}
    img = img.convert("RGB")  # model expects 3-channel input
    with torch.no_grad():
        enc = processor(images=img, return_tensors="pt").to(device)
        logits = model(pixel_values=enc["pixel_values"]).logits
        probs = torch.softmax(logits, dim=-1).squeeze(0)
        # torch.topk replaces the manual numpy argsort/slice/reverse and
        # naturally caps k at the number of classes the model actually has.
        k = min(5, probs.numel())
        top_probs, top_idx = torch.topk(probs, k)
    return {
        f"label {i} (celeb_id {id2label[i]})": float(p)
        for p, i in zip(top_probs.tolist(), top_idx.tolist())
    }


desc = (
    "FaceGuard — Vision Transformer (google/vit-base-patch16-224) fine-tuned on a 20-identity subset of CelebA.\n"
    "Upload a face crop to see the Top-5 predicted identities. (IDs are CelebA celeb_id integers.)"
)

demo = gr.Interface(
    fn=predict_top5,
    inputs=gr.Image(type="pil", label="Upload a face image"),
    outputs=gr.Label(num_top_classes=5, label="Top-5 identities"),
    title="FaceGuard – ViT (20 CelebA IDs)",
    description=desc,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()