Update app.py
Browse files
app.py
CHANGED
|
@@ -191,9 +191,14 @@ class ViT(nn.Module):
|
|
| 191 |
|
| 192 |
# Load model weights
|
| 193 |
checkpoint = torch.load("best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
model.load_state_dict(checkpoint["model_state"])
|
| 195 |
model.eval()
|
| 196 |
|
|
|
|
| 197 |
# Image preprocessing
|
| 198 |
transform = transforms.Compose([
|
| 199 |
transforms.Resize((32,32)),
|
|
|
|
| 191 |
|
| 192 |
# Load model weights
|
| 193 |
checkpoint = torch.load("best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device)
|
| 194 |
+
|
| 195 |
+
model = ViT(cfg).to(device)
|
| 196 |
+
|
| 197 |
+
# Load only the model weights
|
| 198 |
model.load_state_dict(checkpoint["model_state"])
|
| 199 |
model.eval()
|
| 200 |
|
| 201 |
+
|
| 202 |
# Image preprocessing
|
| 203 |
transform = transforms.Compose([
|
| 204 |
transforms.Resize((32,32)),
|