EarTestHelper / app.py
anasmkh's picture
Create app.py
dabea11 verified
raw
history blame contribute delete
960 Bytes
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hub repository holding both the image-preprocessing config and the
# fine-tuned classifier weights. The spelling is kept verbatim — presumably
# the actual repository name on the Hub; verify before "correcting" it.
MODEL_ID = "anasmkh/customied_vit"

# Loaded once at import time so every Gradio request reuses the same objects.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
def classify_image(image):
    """Classify one image with the module-level ViT processor and model.

    Args:
        image: The image delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping every class label (from ``model.config.id2label``) to
        its softmax probability, so the ``gr.Label`` output can rank and show
        its top-N classes. Returns ``None`` (clearing the output) when no
        image was provided.
    """
    # Gradio hands us None if Submit is pressed on an empty image box; the
    # processor cannot preprocess None, so bail out instead of erroring.
    if image is None:
        return None

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits

    # One image was submitted, so take row 0 after normalising over classes.
    probs = torch.softmax(logits, dim=-1)[0]

    # Return the whole distribution rather than only the argmax: the output
    # component is configured to display several top classes, which is only
    # possible if it receives more than one entry.
    id2label = model.config.id2label
    return {id2label[idx]: float(p) for idx, p in enumerate(probs.tolist())}
# Web UI wiring: a PIL image goes in, a label/score mapping comes out.
# gr.Label renders at most the three highest-scoring entries it receives.
demo = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=3),
    title="Custom Vision Transformer Classifier",
    description="Upload an image to get classification results from the custom ViT model.",
)
if __name__ == "__main__":
    import os

    # Hugging Face Spaces already serves the app publicly and sets SPACE_ID
    # in the environment. Gradio does not support share=True there (it is
    # ignored with a warning or rejected, depending on the Gradio version),
    # so only request a public share link when running somewhere else.
    demo.launch(share="SPACE_ID" not in os.environ)