mulasagg commited on
Commit
ac20ecf
·
1 Parent(s): 7166d38
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -12,6 +12,7 @@ import pickle
12
  import torch
13
  from PIL import Image
14
  from src.utils.get_features import get_img_api
 
15
 
16
  # Path to the dataset
17
  data_path = 'src/data/subset_dataset.csv'
@@ -26,10 +27,12 @@ simple_transform = transforms.Compose([
26
 
27
  # Load the model
28
  def load_model(model_path, device='cpu'):
29
- """Loads the model from a pickle file and moves it to the specified device."""
30
- with open(model_path, 'rb') as f:
31
- model = pickle.load(f)
32
- return model.to(device)
 
 
33
 
34
  # Get prediction
35
  def get_prediction(model, padded_sequences, img_x, device='cpu'):
 
12
  import torch
13
  from PIL import Image
14
  from src.utils.get_features import get_img_api
15
+ import joblib
16
 
17
  # Path to the dataset
18
  data_path = 'src/data/subset_dataset.csv'
 
27
 
28
  # Load the model
29
  def load_model(model_path, device='cpu'):
30
+ """Loads the model from a joblib file and moves it to the specified device."""
31
+ model = joblib.load(model_path)
32
+ # If the model contains PyTorch tensors, move them to the specified device
33
+ if isinstance(model, torch.nn.Module):
34
+ model = model.to(device)
35
+ return model
36
 
37
  # Get prediction
38
  def get_prediction(model, padded_sequences, img_x, device='cpu'):