Patrick Daniel commited on
Commit
f728991
·
1 Parent(s): 3aaf42a
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -47,20 +47,17 @@ model.to(device)
47
  def predict(image):
48
  try:
49
  transform = transforms.Compose([
50
- transforms.Resize(256),
51
- transforms.CenterCrop(224),
52
- transforms.ToTensor(), # Converts to [0, 1] range
53
- transforms.Normalize(
54
- mean=[0.485, 0.456, 0.406],
55
- std=[0.229, 0.224, 0.225]
56
- )
57
  ])
58
 
59
  pixel_values = transform(image).unsqueeze(0).to(device)
60
 
61
  with torch.no_grad():
62
  logits = model(pixel_values).logits
63
- probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
 
64
 
65
  topk = torch.topk(probs, k=5)
66
  top_indices = topk.indices.tolist()
 
47
  def predict(image):
48
  try:
49
  transform = transforms.Compose([
50
+ transforms.Resize((224, 224)), # match ViT input size
51
+ transforms.Normalize(mean=(0.485, 0.456, 0.406),
52
+ std=(0.229, 0.224, 0.225))
 
 
 
 
53
  ])
54
 
55
  pixel_values = transform(image).unsqueeze(0).to(device)
56
 
57
  with torch.no_grad():
58
  logits = model(pixel_values).logits
59
+ probs = torch.nn.functional.softmax(logits, dim=1).squeeze()
60
+
61
 
62
  topk = torch.topk(probs, k=5)
63
  top_indices = topk.indices.tolist()