sdafd commited on
Commit
9d086a7
·
1 Parent(s): fd8a9a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -3,6 +3,9 @@ import torch.nn as nn
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
 
 
 
6
 
7
  # Load your trained model
8
  with torch.no_grad():
@@ -22,7 +25,7 @@ def preprocess(image):
22
  # Define the predict function
23
  def predict(image):
24
  # Preprocess the image
25
- input_tensor = preprocess(image)
26
 
27
  # Make a prediction
28
  with torch.no_grad():
@@ -30,6 +33,7 @@ def predict(image):
30
 
31
  # Perform post-processing if needed (e.g., softmax for probabilities)
32
  # Replace this with your actual post-processing logic
 
33
  probabilities = torch.softmax(output, dim=1).squeeze().tolist()
34
 
35
  # Map the class indices to class labels
 
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
6
+ from transformers import ViTFeatureExtractor
7
+
8
+ transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
9
 
10
  # Load your trained model
11
  with torch.no_grad():
 
25
  # Define the predict function
26
  def predict(image):
27
  # Preprocess the image
28
+ input_tensor = transforms(image, return_tensors='pt')
29
 
30
  # Make a prediction
31
  with torch.no_grad():
 
33
 
34
  # Perform post-processing if needed (e.g., softmax for probabilities)
35
  # Replace this with your actual post-processing logic
36
+ print(output.logits.argmax(1).item())
37
  probabilities = torch.softmax(output, dim=1).squeeze().tolist()
38
 
39
  # Map the class indices to class labels