# NOTE: the following header is residue from a GitHub web-UI copy-paste
# (author avatar caption, commit message, commit hash). Kept as a comment so
# the file is valid Python:
#   bran138 — "Skip learn.predict and run manual inference" (786e3d4)
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
# In[1]:
# import fastai
# print(fastai.__version__)
# In[2]:
#/export
from fastai.vision.all import *
import gradio as gr
import torch
# Load the exported fastai Learner from its pickle. Pickle can execute
# arbitrary code, so this is only safe because 'model-2.pkl' is our own
# export (see the security note below).
learn = load_learner('model-2.pkl')
# Use valid batch to get input size and device for manual inference (avoids learn.predict
# passing (Image, dict) into transforms and triggering TypeError: 'PILImage' + 'dict').
try:
_inf_batch = next(iter(learn.dls.valid))
_INFERENCE_DEVICE = _inf_batch[0].device
_INFERENCE_SIZE = _inf_batch[0].shape[-1]
except Exception:
_INFERENCE_DEVICE = torch.device("cpu")
_INFERENCE_SIZE = 224
#Pydantic Warnings
# UnsupportedFieldAttributeWarning: The 'repr' attribute with value False ...
# UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True ...
# These are coming from pydantic, a dependency used by some fastai internals or Gradio.
# They don’t affect your model predictions.
# It’s just saying certain metadata on fields is ignored in Python 3.12.
# You can safely ignore these.
# Pickle Security Warning
# UserWarning: load_learner uses Python's insecure pickle module ...
# This is fastai reminding you that load_learner uses pickle, which can execute arbitrary code.
# Safe for you because this is your own export.pkl file.
# The warning is just a general caution; it doesn’t break your notebook.
# ✅ Only a concern if you were loading a pickle from an untrusted source.
# In[3]:
#/export
# Function for multi-class flower predictions
# This returns a Python dictionary with class names as keys and their corresponding probabilities as values.
# def classify_flower(img):
# pred, pred_idx, probs = learn.predict(img)
# return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
# def classify_flower(img):
# pred, pred_idx, probs = learn.predict(img)
# # Convert to dict and sort by probability descending
# result = {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
# return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
# def classify_flower(img):
# pred, pred_idx, probs = learn.predict(img)
# return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))}
# with gr.Blocks() as demo:
# gr.Markdown("## 🌸 16 Flower Classifier")
# with gr.Row():
# image_input = gr.Image(type="pil")
# label_output = gr.Label()
# submit_btn = gr.Button("Submit")
# submit_btn.click(fn=classify_flower, inputs=image_input, outputs=label_output)
# def classify_flower(img):
# img = PILImage.create(img)
# pred, pred_idx, probs = learn.predict(img)
def _preprocess_for_learner(pil_img, size=None, device=None):
    """Convert a PIL image into a normalized (1, 3, size, size) batch tensor.

    Resizes to the learner's expected square input size and applies the
    standard ImageNet mean/std normalization (typical fastai vision setup).

    Args:
        pil_img: input PIL image. Converted to RGB first so grayscale ("L")
            or RGBA inputs don't break the 3-channel Normalize step.
        size: target square edge in pixels; defaults to the size probed from
            the learner's validation batch (module-level _INFERENCE_SIZE).
        device: torch device for the result; defaults to the device probed
            from the learner's validation batch (_INFERENCE_DEVICE).

    Returns:
        A float tensor of shape (1, 3, size, size) on `device`.
    """
    from torchvision import transforms as T
    if size is None:
        size = _INFERENCE_SIZE
    if device is None:
        device = _INFERENCE_DEVICE
    # ToTensor on an "L"/"RGBA" image yields 1/4 channels, which the
    # 3-channel Normalize rejects — force RGB up front.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    transform = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(pil_img).unsqueeze(0).to(device)
def classify_flower(img):
    """Classify a flower image; return {class_name: probability} for gr.Label.

    Accepts whatever Gradio hands over — a PIL image, a numpy array, or a
    dict carrying either a file "path" or an "image" payload — normalizes it
    to an RGB PIL image, then runs the model manually. learn.predict() is
    bypassed on purpose: feeding its (PILImage, dict) pair through the fastai
    transform pipeline raises TypeError: 'PILImage' + 'dict'.

    Returns None when no usable image is provided (gr.Label shows nothing).
    """
    # Single local import (the original duplicated it in two branches).
    from PIL import Image as PILImageModule

    if img is None:
        return None
    # Some Gradio versions/components deliver a dict instead of an image.
    if isinstance(img, dict):
        path = img.get("path")
        if isinstance(path, str):
            img = PILImageModule.open(path).convert("RGB")
        elif img.get("image") is not None:
            img = img["image"]
        else:
            return None
    if not isinstance(img, PILImageModule.Image):
        arr = np.asarray(img)
        if arr.ndim == 2:
            # Grayscale -> replicate into 3 identical channels.
            arr = np.stack([arr] * 3, axis=-1)
        if arr.dtype != np.uint8:
            # Image.fromarray rejects float arrays; rescale floats that look
            # like [0, 1] data, otherwise just clip/cast to uint8.
            arr = np.clip(arr, 0, 255)
            if arr.size and arr.max() <= 1.0:
                arr = arr * 255.0
            arr = arr.astype(np.uint8)
        img = PILImageModule.fromarray(arr).convert("RGB")
    else:
        img = img.convert("RGB")
    # Manual inference (see docstring for why learn.predict is avoided).
    x = _preprocess_for_learner(img)
    learn.model.eval()
    with torch.no_grad():
        logits = learn.model(x)
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    return {
        learn.dls.vocab[i]: float(probs[i])
        for i in range(len(probs))
    }
# Assemble the Gradio UI: an image input, a label listing all 16 class
# probabilities, and a submit button wired to the manual-inference classifier.
with gr.Blocks() as demo:
    gr.Markdown("## 🌸 16 Flower Classifier")
    with gr.Row():
        flower_input = gr.Image(type="pil")
        prediction_output = gr.Label(num_top_classes=16)
    run_button = gr.Button("Submit")
    run_button.click(
        fn=classify_flower,
        inputs=flower_input,
        outputs=prediction_output,
    )
demo.launch()
# ## EXPORT
# In[19]: