Aumkeshchy2003 commited on
Commit
890dc2c
·
verified ·
1 Parent(s): 808c348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -191,9 +191,14 @@ class ViT(nn.Module):
191
 
192
  # Load model weights
193
  checkpoint = torch.load("best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device)
 
 
 
 
194
  model.load_state_dict(checkpoint["model_state"])
195
  model.eval()
196
 
 
197
  # Image preprocessing
198
  transform = transforms.Compose([
199
  transforms.Resize((32,32)),
 
191
 
192
  # Load model weights
193
  checkpoint = torch.load("best_ViT_CIFAR100_baseline_checkpoint.pth", map_location=device)
194
+
195
+ model = ViT(cfg).to(device)
196
+
197
+ # Load only the model weights
198
  model.load_state_dict(checkpoint["model_state"])
199
  model.eval()
200
 
201
+
202
  # Image preprocessing
203
  transform = transforms.Compose([
204
  transforms.Resize((32,32)),