KFrimps commited on
Commit
a6e2c60
·
verified ·
1 Parent(s): 92bea80

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+
5
+ def predict(image):
6
+ """Predicts the class of the input image using the fine-tuned student model."""
7
+ # Convert the Gradio image to a PIL Image
8
+ image = Image.fromarray(image)
9
+ # Preprocess the image
10
+ inputs = processor(image, return_tensors="pt").to(device)
11
+ # Make prediction
12
+ with torch.no_grad():
13
+ outputs = student_model(**inputs)
14
+ predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
15
+ # Get predicted class label
16
+ predicted_class = id2label[predicted_class_idx]
17
+ return predicted_class
18
+
19
+ # Get a few examples from the test dataset
20
+ num_examples = 3 # Adjust the number of examples as needed
21
+ examples = []
22
+ for i in range(num_examples):
23
+ # Access the image data using 'bytes' key and convert to PIL Image
24
+ image = Image.open(io.BytesIO(train_test_valid_dataset["test"][i]["image"]['bytes']))
25
+
26
+ # Convert the PIL Image to a NumPy array
27
+ image_np = np.array(image)
28
+ examples.append(image_np)
29
+
30
+ iface = gr.Interface(
31
+ fn=predict,
32
+ inputs=gr.Image(type="numpy"),
33
+ outputs="text",
34
+ title="Pets Image Classification",
35
+ description="Upload an image of a cat or dog to get its breed prediction.",
36
+ examples=examples,
37
+ ).launch()