lantzmurray commited on
Commit
313dbee
·
1 Parent(s): 494fcf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -14
app.py CHANGED
@@ -1,28 +1,21 @@
1
  import gradio as gr
2
  from fastai.vision.all import *
3
 
4
- # Train or load your fastai Learner object (replace this with your actual model training or loading code)
5
- path = untar_data(URLs.PETS)
6
- dls = ImageDataLoaders.from_folder(path/"images", valid_pct=0.2, seed=42)
7
- learn = cnn_learner(dls, resnet34)
8
- learn.fine_tune(1)
9
 
10
- # Save the model using PyTorch's torch.save
11
- torch.save(learn.model, "model.pth")
12
 
13
  # Define the predict function
14
  def predict(image):
15
  # Load the input image
16
  img = PILImage.create(image)
17
 
18
- # Load the model using PyTorch's torch.load
19
- model = torch.load("model.pth")
20
- model.eval()
21
 
22
- # Perform inference
23
- predicted_class, _, _ = learn.predict(img)
24
-
25
- return predicted_class
26
 
27
  # Create Gradio interface
28
  iface = gr.Interface(
 
1
  import gradio as gr
2
  from fastai.vision.all import *
3
 
4
+ path = Path()
5
+ path.ls(file_exts='.pkl')
 
 
 
6
 
7
+ # Specify the path to your fastai model file
8
+ learn_inf = load_learner(path/'export.pkl')
9
 
10
  # Define the predict function
11
  def predict(image):
12
  # Load the input image
13
  img = PILImage.create(image)
14
 
15
+ # Perform inference using the loaded fastai Learner
16
+ pred, pred_idx, probs = learn_inf.predict(img)
 
17
 
18
+ return f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'
 
 
 
19
 
20
  # Create Gradio interface
21
  iface = gr.Interface(