Spaces:
Sleeping
Sleeping
test
Browse files
app.py
CHANGED
|
@@ -12,7 +12,17 @@ import pickle
|
|
| 12 |
import torch
|
| 13 |
from PIL import Image
|
| 14 |
from src.utils.get_features import get_img_api
|
| 15 |
-
import joblib
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Path to the dataset
|
| 18 |
data_path = 'src/data/subset_dataset.csv'
|
|
@@ -27,16 +37,15 @@ simple_transform = transforms.Compose([
|
|
| 27 |
|
| 28 |
# Load the model
|
| 29 |
def load_model(model_path, device='cpu'):
|
| 30 |
-
"""Loads the model from a joblib file and
|
| 31 |
-
# Load the model using joblib
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
# If the model is a PyTorch module, move it to the specified device
|
| 35 |
if isinstance(model, torch.nn.Module):
|
| 36 |
-
# Move model to CPU and handle any CUDA tensors
|
| 37 |
model = model.to(device)
|
| 38 |
-
# Set to evaluation mode
|
| 39 |
-
model.eval()
|
| 40 |
return model
|
| 41 |
|
| 42 |
# Get prediction
|
|
|
|
| 12 |
import torch
|
| 13 |
from PIL import Image
|
| 14 |
from src.utils.get_features import get_img_api
|
| 15 |
+
import joblib
|
| 16 |
+
import io
|
| 17 |
+
|
| 18 |
+
# Custom unpickler to handle device mapping
|
| 19 |
+
class CPU_Unpickler(pickle.Unpickler):
|
| 20 |
+
def find_class(self, module, name):
|
| 21 |
+
if module == "torch.storage" and name == "_load_from_bytes":
|
| 22 |
+
def _load_from_bytes(b):
|
| 23 |
+
return torch.load(io.BytesIO(b), map_location=torch.device('cpu'))
|
| 24 |
+
return _load_from_bytes
|
| 25 |
+
return super().find_class(module, name)
|
| 26 |
|
| 27 |
# Path to the dataset
|
| 28 |
data_path = 'src/data/subset_dataset.csv'
|
|
|
|
| 37 |
|
| 38 |
# Load the model
|
| 39 |
def load_model(model_path, device='cpu'):
|
| 40 |
+
"""Loads the model from a joblib file and ensures it runs on the specified device."""
|
| 41 |
+
# Load the model using joblib with custom unpickler
|
| 42 |
+
with open(model_path, 'rb') as f:
|
| 43 |
+
model = CPU_Unpickler(f).load()
|
| 44 |
|
| 45 |
# If the model is a PyTorch module, move it to the specified device
|
| 46 |
if isinstance(model, torch.nn.Module):
|
|
|
|
| 47 |
model = model.to(device)
|
| 48 |
+
model.eval() # Set to evaluation mode
|
|
|
|
| 49 |
return model
|
| 50 |
|
| 51 |
# Get prediction
|