bran138 committed on
Commit
786e3d4
·
1 Parent(s): 336b8a3

Skip learn.predict and run manual inference

Browse files
Files changed (1) hide show
  1. app.py +39 -5
app.py CHANGED
@@ -31,9 +31,20 @@ import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
31
  #/export
32
  from fastai.vision.all import *
33
  import gradio as gr
 
34
 
35
  learn = load_learner('model-2.pkl')
36
 
 
 
 
 
 
 
 
 
 
 
37
  #Pydantic Warnings
38
  # UnsupportedFieldAttributeWarning: The 'repr' attribute with value False ...
39
  # UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True ...
@@ -89,12 +100,23 @@ learn = load_learner('model-2.pkl')
89
  # img = PILImage.create(img)
90
  # pred, pred_idx, probs = learn.predict(img)
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def classify_flower(img):
93
  if img is None:
94
  return None
95
- # Gradio can pass PIL Image, numpy array, or dict (e.g. {"path": "..."}).
96
- # Fastai's transforms expect its own Image type; raw PIL/dict causes:
97
- # TypeError: unsupported operand type(s) for +: 'Image' and 'dict'
98
  if isinstance(img, dict):
99
  path = img.get("path")
100
  if isinstance(path, str):
@@ -104,8 +126,20 @@ def classify_flower(img):
104
  img = img["image"]
105
  else:
106
  return None
107
- img = PILImage.create(img)
108
- pred, pred_idx, probs = learn.predict(img)
 
 
 
 
 
 
 
 
 
 
 
 
109
  return {
110
  learn.dls.vocab[i]: float(probs[i])
111
  for i in range(len(probs))
 
31
  #/export
32
  from fastai.vision.all import *
33
  import gradio as gr
34
+ import torch
35
 
36
  learn = load_learner('model-2.pkl')
37
 
38
+ # Use valid batch to get input size and device for manual inference (avoids learn.predict
39
+ # passing (Image, dict) into transforms and triggering TypeError: 'PILImage' + 'dict').
40
+ try:
41
+ _inf_batch = next(iter(learn.dls.valid))
42
+ _INFERENCE_DEVICE = _inf_batch[0].device
43
+ _INFERENCE_SIZE = _inf_batch[0].shape[-1]
44
+ except Exception:
45
+ _INFERENCE_DEVICE = torch.device("cpu")
46
+ _INFERENCE_SIZE = 224
47
+
48
  #Pydantic Warnings
49
  # UnsupportedFieldAttributeWarning: The 'repr' attribute with value False ...
50
  # UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True ...
 
100
  # img = PILImage.create(img)
101
  # pred, pred_idx, probs = learn.predict(img)
102
 
103
def _preprocess_for_learner(pil_img):
    """Convert a PIL image into a normalized batch tensor for ``learn.model``.

    Resizes to the probed inference size, converts to a float tensor, applies
    ImageNet mean/std normalization (typical fastai preprocessing — presumably
    matching the learner's training stats; verify against the dataloaders),
    then adds a batch dimension and moves the tensor to the inference device.
    """
    from torchvision import transforms as T
    # Match typical fastai/ImageNet preprocessing.
    pipeline = T.Compose([
        T.Resize((_INFERENCE_SIZE, _INFERENCE_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch = pipeline(pil_img).unsqueeze(0)
    return batch.to(_INFERENCE_DEVICE)
114
+
115
+
116
  def classify_flower(img):
117
  if img is None:
118
  return None
119
+ # Normalize Gradio input to PIL (can be dict with "path" or "image", or numpy, or PIL).
 
 
120
  if isinstance(img, dict):
121
  path = img.get("path")
122
  if isinstance(path, str):
 
126
  img = img["image"]
127
  else:
128
  return None
129
+ from PIL import Image as PILImageModule
130
+ if not isinstance(img, PILImageModule.Image):
131
+ img = np.asarray(img)
132
+ if img.ndim == 2:
133
+ img = np.stack([img] * 3, axis=-1)
134
+ img = PILImageModule.fromarray(img).convert("RGB")
135
+ else:
136
+ img = img.convert("RGB")
137
+ # Bypass learn.predict() to avoid (PILImage, dict) in the transform pipeline.
138
+ x = _preprocess_for_learner(img)
139
+ learn.model.eval()
140
+ with torch.no_grad():
141
+ logits = learn.model(x)
142
+ probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
143
  return {
144
  learn.dls.vocab[i]: float(probs[i])
145
  for i in range(len(probs))