Pablogps commited on
Commit
97b9bc9
·
verified ·
1 Parent(s): d72f2f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -25,16 +25,21 @@ def predict(img):
25
  print("Starting prediction...", flush=True)
26
  img = PILImage.create(img)
27
 
28
- fimg = PILImage.create(img)
29
- x = Resize(224)(fimg) # -> fastai PILImage
30
- x = ToTensor()(x) # -> float tensor [0,1], shape [3,224,224]
31
- x = Normalize.from_stats(*imagenet_stats)(x)
32
- x = x.unsqueeze(0) # -> [1,3,224,224]
 
 
 
 
 
33
 
34
- # --- Forward pass ---
35
  learn.model.eval()
36
- logits = learn.model(x)
37
- probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
38
 
39
  # pred, pred_idx, probs = learn.predict(img)
40
  return {labels[i]: float(probs[i]) for i in range(len(labels))}
 
25
  print("Starting prediction...", flush=True)
26
  img = PILImage.create(img)
27
 
28
+ # 2) Torchvision preprocessing:
29
+ # Resize shorter side to 256, center-crop 224, ToTensor, ImageNet normalize
30
+ tfm = T.Compose([
31
+ T.Resize(256),
32
+ T.CenterCrop(224),
33
+ T.ToTensor(), # -> [C,H,W] in [0,1]
34
+ T.Normalize(mean=[0.485, 0.456, 0.406],
35
+ std=[0.229, 0.224, 0.225]),
36
+ ])
37
+ x = tfm(img).unsqueeze(0) # [1,3,224,224]
38
 
39
+ # 3) Forward pass
40
  learn.model.eval()
41
+ logits = learn.model(x) # [1, C]
42
+ probs = torch.softmax(logits, dim=1)[0] # [C]
43
 
44
  # pred, pred_idx, probs = learn.predict(img)
45
  return {labels[i]: float(probs[i]) for i in range(len(labels))}