Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import ViTForImageClassification, AutoImageProcessor | |
# Hub repository that holds both the fine-tuned weights and the preprocessing config.
MODEL_ID = "hudaakram/FaceGuard-20ID-ViT"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load once at import time so every request reuses the same model instance.
model = ViTForImageClassification.from_pretrained(MODEL_ID)
model = model.to(device).eval()  # eval() switches the module to inference mode
processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# readable class names from config (id2label maps idx -> celeb_id string)
# Keys are normalised to int so they can be indexed with predicted class indices.
id2label = {int(k): v for k, v in model.config.id2label.items()}
def predict_top5(img: Image.Image):
    """Classify a face image and return its most probable identities.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when
            no image was provided.

    Returns:
        A dict mapping a display label of the form
        ``"label <idx> (celeb_id <id>)"`` to its softmax probability, in
        descending order of probability. At most five entries are returned;
        the dict is empty when ``img`` is ``None``.
    """
    if img is None:
        return {}

    # Force three channels so grayscale / RGBA uploads match the processor input.
    rgb = img.convert("RGB")

    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        enc = processor(images=rgb, return_tensors="pt").to(device)
        logits = model(pixel_values=enc["pixel_values"]).logits
        probs = torch.softmax(logits, dim=-1).squeeze(0).cpu()

    # Never ask for more classes than the model actually has.
    k = min(5, probs.numel())
    top_probs, top_idx = torch.topk(probs, k)  # sorted, highest probability first

    # .tolist() yields plain Python ints/floats for the id2label lookup and output.
    return {
        f"label {i} (celeb_id {id2label[i]})": float(p)
        for i, p in zip(top_idx.tolist(), top_probs.tolist())
    }
# Text shown under the title in the web UI.
desc = (
    "FaceGuard — Vision Transformer (google/vit-base-patch16-224) fine-tuned on a 20-identity subset of CelebA.\n"
    "Upload a face crop to see the Top-5 predicted identities. (IDs are CelebA celeb_id integers.)"
)

# Single-function UI: one image in, one label widget showing the top classes out.
demo = gr.Interface(
    fn=predict_top5,
    inputs=gr.Image(type="pil", label="Upload a face image"),  # hand the callback a PIL image
    outputs=gr.Label(num_top_classes=5, label="Top-5 identities"),
    title="FaceGuard – ViT (20 CelebA IDs)",
    description=desc,
    # NOTE(review): newer Gradio releases appear to rename this keyword to
    # `flagging_mode` — confirm against the Gradio version pinned for this app.
    allow_flagging="never",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()