limitedonly41 committed on
Commit
a3dc540
·
verified ·
1 Parent(s): 4042db3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -7,6 +7,8 @@ import gradio as gr
7
  from transformers import ViTFeatureExtractor
8
  from huggingface_hub import hf_hub_download
9
  import spaces
 
 
10
 
11
  HF_TOKEN = os.environ.get("HF_TOKEN")
12
  model = None
@@ -17,6 +19,8 @@ VALID_DS_PATH = 'valid_ds.pth'
17
  valid_ds = torch.load(VALID_DS_PATH)
18
 
19
 
 
 
20
  from transformers import ViTModel
21
  from transformers.modeling_outputs import SequenceClassifierOutput
22
  import torch.nn as nn
@@ -68,6 +72,11 @@ def run_inference(image, device, valid_ds):
68
  model.eval()
69
  # feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale=False)
70
 
 
 
 
 
 
71
  image = Image.fromarray(image.astype('uint8'), 'RGB')
72
  input_tensor = transform(image)
73
  input_tensor = input_tensor.unsqueeze(0) # Add a batch dimension
 
7
  from transformers import ViTFeatureExtractor
8
  from huggingface_hub import hf_hub_download
9
  import spaces
10
+ from torchvision import transforms
11
+
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
  model = None
 
19
  valid_ds = torch.load(VALID_DS_PATH)
20
 
21
 
22
+
23
+
24
  from transformers import ViTModel
25
  from transformers.modeling_outputs import SequenceClassifierOutput
26
  import torch.nn as nn
 
72
  model.eval()
73
  # feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale=False)
74
 
75
+ transform = transforms.Compose([
76
+ transforms.Resize((224, 224)), # Resize to the model's input size
77
+ transforms.ToTensor(),
78
+ ])
79
+
80
  image = Image.fromarray(image.astype('uint8'), 'RGB')
81
  input_tensor = transform(image)
82
  input_tensor = input_tensor.unsqueeze(0) # Add a batch dimension