Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,8 @@ import gradio as gr
|
|
| 7 |
from transformers import ViTFeatureExtractor
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
import spaces
|
|
|
|
|
|
|
| 10 |
|
| 11 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 12 |
model = None
|
|
@@ -17,6 +19,8 @@ VALID_DS_PATH = 'valid_ds.pth'
|
|
| 17 |
valid_ds = torch.load(VALID_DS_PATH)
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
| 20 |
from transformers import ViTModel
|
| 21 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 22 |
import torch.nn as nn
|
|
@@ -68,6 +72,11 @@ def run_inference(image, device, valid_ds):
|
|
| 68 |
model.eval()
|
| 69 |
# feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale=False)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
| 72 |
input_tensor = transform(image)
|
| 73 |
input_tensor = input_tensor.unsqueeze(0) # Add a batch dimension
|
|
|
|
| 7 |
from transformers import ViTFeatureExtractor
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
import spaces
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
|
| 12 |
|
| 13 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 14 |
model = None
|
|
|
|
| 19 |
valid_ds = torch.load(VALID_DS_PATH)
|
| 20 |
|
| 21 |
|
| 22 |
+
|
| 23 |
+
|
| 24 |
from transformers import ViTModel
|
| 25 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 26 |
import torch.nn as nn
|
|
|
|
| 72 |
model.eval()
|
| 73 |
# feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k', do_rescale=False)
|
| 74 |
|
| 75 |
+
transform = transforms.Compose([
|
| 76 |
+
transforms.Resize((224, 224)), # Resize to the model's input size
|
| 77 |
+
transforms.ToTensor(),
|
| 78 |
+
])
|
| 79 |
+
|
| 80 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
| 81 |
input_tensor = transform(image)
|
| 82 |
input_tensor = input_tensor.unsqueeze(0) # Add a batch dimension
|