File size: 1,224 Bytes
6127b85
b5ba006
f4116c6
ef6b925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23cb710
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!pip install transformers gradio torch torchvision

import gradio as gr
import torch
import transformers
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Step 1: Load the pretrained ViT checkpoint once at import time.
# Both the preprocessing pipeline and the classifier come from the same
# checkpoint id, kept in one place so the two can never drift apart.
_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)  # image preprocessing
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)   # classification head + backbone

# Step 2: Define the function for prediction
def recognize_image(image):
    """Classify a single image and return the model's top-1 label.

    Args:
        image: Image pixel data, passed to ``PIL.Image.fromarray``. The app's
            input component is ``gr.Image(type="numpy")``, so this is a NumPy
            array, or ``None`` when the user submits without uploading.

    Returns:
        The label string that ``model.config.id2label`` maps to the arg-max
        logit.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a crash from ``Image.fromarray(None)``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Normalize to 3-channel RGB (e.g. drops alpha / expands grayscale).
    pil_image = Image.fromarray(image).convert("RGB")

    # Preprocess into PyTorch tensors the model accepts.
    inputs = feature_extractor(images=pil_image, return_tensors="pt")

    # Pure inference: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]

# Step 3: Create a Gradio interface
app = gr.Interface(
    fn=recognize_image,                 # Prediction function
    inputs=gr.Image(type="numpy"),      # Input: image delivered as a NumPy array
    outputs="text",                     # Output: predicted label
    title="Image Recognition App",      # App title
)

if __name__ == "__main__":
    # Print the library version *before* launching: when run as a script,
    # app.launch() typically blocks until the server shuts down, so a print
    # placed after it would not show up while the app is running.
    # (Requires a plain `import transformers`; the `from transformers import ...`
    # line alone does not bind the name `transformers`.)
    print(transformers.__version__)

    # Launch the app only when executed directly, not on import.
    app.launch()