# marshmallow / app.py
# yangfang236
# add torch serialization
# 840968e
import pathlib
import platform

import gradio as gr
from fastai.learner import load_learner
from fastai.vision.core import PILImage
plt = platform.system()
# Pickled fastai exports may embed concrete pathlib classes from the machine
# that created them, and a foreign one cannot be instantiated here (e.g. an
# export made on Windows contains WindowsPath, which fails on Linux). Alias the
# foreign class to the native one so unpickling works in either direction.
if plt == 'Windows':
    pathlib.PosixPath = pathlib.WindowsPath
else:
    pathlib.WindowsPath = pathlib.PosixPath
from pathlib import Path
import torch
import torch.serialization
def load_learner_safe(fname, map_location=None, pickle_module=None):
    """Unpickle the object stored in ``fname`` with ``torch.load``.

    Intended for this app's fastai ``export.pkl``, which holds a whole pickled
    ``Learner`` object rather than a plain state dict.

    Args:
        fname: Path (``str`` or ``pathlib.Path``) of the file to load.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` remaps
            tensors that were saved on a GPU. ``None`` restores tensors to the
            devices they were saved from.
        pickle_module: Forwarded to ``torch.load``; ``None`` lets torch use
            the standard ``pickle`` module.

    Returns:
        The unpickled object (for ``export.pkl``, the fastai ``Learner``).
    """
    # weights_only=False is required: the restricted weights-only unpickler
    # rejects the many arbitrary classes inside a pickled Learner, and
    # allow-listing just ``Learner`` is not enough. This executes arbitrary
    # pickle code, so only use it on trusted files (here: the app's own export).
    with open(fname, 'rb') as f:
        return torch.load(
            f,
            map_location=map_location,
            pickle_module=pickle_module,
            weights_only=False,
        )
# Exported fastai Learner file. It is loaded once, into ``learn``, further down;
# the file already contains the complete Learner, so there is nothing to
# reassemble from separate model/loss/metric pieces here.
path_to_model = Path('export.pkl')
# %% Untitled.ipynb 3
def predict(img):
    """Classify one image and return a per-class probability mapping.

    Args:
        img: The image passed in by the Gradio ``Image`` input (configured with
            ``type="numpy"``, so presumably a NumPy array — confirm if the
            input component changes).

    Returns:
        dict mapping each class label in ``learn.dls.vocab`` to its predicted
        probability as a Python ``float``, the format ``gr.Label`` displays.
    """
    labels = learn.dls.vocab
    img = PILImage.create(img)
    # learn.predict returns (decoded prediction, class index, probabilities);
    # only the probabilities are needed for the label display.
    _pred, _pred_idx, probs = learn.predict(img)
    return {label: float(prob) for label, prob in zip(labels, probs)}
# %% Untitled.ipynb 5
# Load the exported Learner once at startup; predict() reads this global.
learn = load_learner_safe(fname='export.pkl')
# %% Untitled.ipynb 8
# Sample images offered as one-click examples in the UI.
examples = [
    "covid-19.jpg",
    "normal.jpg",
    "viral pneumonia.jpg",
]

# Build the input/output components first, then wire them to predict().
image_input = gr.Image(type="numpy", label="Upload an Image (256x256)")
label_output = gr.Label(label="Prediction")
intf = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=examples,
)
intf.launch(inline=False)