sdafd commited on
Commit
49dd760
·
1 Parent(s): 9d086a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -3,9 +3,6 @@ import torch.nn as nn
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
6
- from transformers import ViTFeatureExtractor
7
-
8
- transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
9
 
10
  # Load your trained model
11
  with torch.no_grad():
@@ -25,7 +22,7 @@ def preprocess(image):
25
  # Define the predict function
26
  def predict(image):
27
  # Preprocess the image
28
- input_tensor = transforms(image, return_tensors='pt')
29
 
30
  # Make a prediction
31
  with torch.no_grad():
@@ -44,10 +41,11 @@ def predict(image):
44
 
45
  return predictions
46
 
 
47
  # Create the Gradio interface
48
  iface = gr.Interface(
49
  fn=predict,
50
- inputs=gr.Image(),
51
  outputs=gr.Label(num_top_classes=4),
52
  live=True
53
  )
 
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
 
 
 
6
 
7
  # Load your trained model
8
  with torch.no_grad():
 
22
  # Define the predict function
23
  def predict(image):
24
  # Preprocess the image
25
+ input_tensor = preprocess(image)
26
 
27
  # Make a prediction
28
  with torch.no_grad():
 
41
 
42
  return predictions
43
 
44
+
45
  # Create the Gradio interface
46
  iface = gr.Interface(
47
  fn=predict,
48
+ inputs=gr.Image(preprocess),
49
  outputs=gr.Label(num_top_classes=4),
50
  live=True
51
  )