# handpredic / app.py
# SuriRaja's picture
# Update app.py
# ae455b5 verified
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
# Hugging Face Hub checkpoint id; the same id is used to fetch both the
# preprocessing configuration and the classification model.
model_name = "nateraw/vit-base-patch16-224-in21k-finetuned-mnist"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
# Preprocessing applied to each input image before feature extraction:
# grayscale replicated across 3 channels, then a fixed 224x224 resize.
_preprocessing_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
]
transform = transforms.Compose(_preprocessing_steps)
def predict_image(img):
    """Predict the digit shown in an input image.

    Args:
        img: Image data accepted by ``PIL.Image.fromarray`` -- presumably the
            NumPy array Gradio hands over from the image input (TODO confirm
            against the interface definition).

    Returns:
        The index of the highest-scoring entry in the model's logits, as a
        Python ``int``.

    Raises:
        ValueError: If ``img`` is ``None`` (for example, the form was
            submitted without an image).
    """
    if img is None:
        # Fail with a clear message instead of an opaque error from PIL.
        raise ValueError("No image provided; please draw or upload a digit.")

    img = Image.fromarray(img).convert('RGB')  # Convert image to RGB
    img_transformed = transform(img)
    inputs = feature_extractor(images=[img_transformed], return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    prediction = logits.argmax(dim=1).item()
    return prediction
# Define Gradio interface.
#
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces were deprecated in
# Gradio 3 and removed in Gradio 4, so the components are created from the
# top-level classes instead. The old ``shape`` and ``invert_colors`` options
# are not accepted by ``gr.Image`` in Gradio 4; ``predict_image`` resizes the
# image itself via ``transform``, and ``invert_colors`` was already False.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(image_mode="L", type="numpy"),
    outputs=gr.Textbox(label="Predicted Digit"),
)

# Launch the interface
iface.launch()