lantzmurray commited on
Commit
494fcf3
·
1 Parent(s): 410ba0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -1,17 +1,24 @@
1
- import requests
2
  import gradio as gr
3
- from fastai import *
4
 
 
 
 
 
 
5
 
6
- # Load the trained model
7
- with open("model.pkl", "rb") as f:
8
- learn = pickle.load(f)
9
 
10
  # Define the predict function
11
  def predict(image):
12
  # Load the input image
13
  img = PILImage.create(image)
14
 
 
 
 
 
15
  # Perform inference
16
  predicted_class, _, _ = learn.predict(img)
17
 
 
 
1
  import gradio as gr
2
+ from fastai.vision.all import *
3
 
4
+ # Train or load your fastai Learner object (replace this with your actual model training or loading code)
5
+ path = untar_data(URLs.PETS)
6
+ dls = ImageDataLoaders.from_folder(path/"images", valid_pct=0.2, seed=42)
7
+ learn = cnn_learner(dls, resnet34)
8
+ learn.fine_tune(1)
9
 
10
+ # Save the model using PyTorch's torch.save
11
+ torch.save(learn.model, "model.pth")
 
12
 
13
  # Define the predict function
14
  def predict(image):
15
  # Load the input image
16
  img = PILImage.create(image)
17
 
18
+ # Load the model using PyTorch's torch.load
19
+ model = torch.load("model.pth")
20
+ model.eval()
21
+
22
  # Perform inference
23
  predicted_class, _, _ = learn.predict(img)
24