mulasagg commited on
Commit
d0f4387
·
1 Parent(s): 201eb2f
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -12,7 +12,17 @@ import pickle
12
  import torch
13
  from PIL import Image
14
  from src.utils.get_features import get_img_api
15
- import joblib
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Path to the dataset
18
  data_path = 'src/data/subset_dataset.csv'
@@ -27,16 +37,15 @@ simple_transform = transforms.Compose([
27
 
28
  # Load the model
29
  def load_model(model_path, device='cpu'):
30
- """Loads the model from a joblib file and moves it to the specified device."""
31
- # Load the model using joblib
32
- model = joblib.load(model_path)
 
33
 
34
  # If the model is a PyTorch module, move it to the specified device
35
  if isinstance(model, torch.nn.Module):
36
- # Move model to CPU and handle any CUDA tensors
37
  model = model.to(device)
38
- # Set to evaluation mode
39
- model.eval()
40
  return model
41
 
42
  # Get prediction
 
12
  import torch
13
  from PIL import Image
14
  from src.utils.get_features import get_img_api
15
+ import joblib
16
+ import io
17
+
18
+ # Custom unpickler to handle device mapping
19
+ class CPU_Unpickler(pickle.Unpickler):
20
+ def find_class(self, module, name):
21
+ if module == "torch.storage" and name == "_load_from_bytes":
22
+ def _load_from_bytes(b):
23
+ return torch.load(io.BytesIO(b), map_location=torch.device('cpu'))
24
+ return _load_from_bytes
25
+ return super().find_class(module, name)
26
 
27
  # Path to the dataset
28
  data_path = 'src/data/subset_dataset.csv'
 
37
 
38
  # Load the model
39
  def load_model(model_path, device='cpu'):
40
+ """Loads the model from a joblib file and ensures it runs on the specified device."""
41
+ # Load the model using joblib with custom unpickler
42
+ with open(model_path, 'rb') as f:
43
+ model = CPU_Unpickler(f).load()
44
 
45
  # If the model is a PyTorch module, move it to the specified device
46
  if isinstance(model, torch.nn.Module):
 
47
  model = model.to(device)
48
+ model.eval() # Set to evaluation mode
 
49
  return model
50
 
51
  # Get prediction