Spaces:
Sleeping
Sleeping
final?
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ import pickle
|
|
| 12 |
import torch
|
| 13 |
from PIL import Image
|
| 14 |
from src.utils.get_features import get_img_api
|
|
|
|
| 15 |
|
| 16 |
# Path to the dataset
|
| 17 |
data_path = 'src/data/subset_dataset.csv'
|
|
@@ -26,10 +27,12 @@ simple_transform = transforms.Compose([
|
|
| 26 |
|
| 27 |
# Load the model
|
| 28 |
def load_model(model_path, device='cpu'):
|
| 29 |
-
"""Loads the model from a
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# Get prediction
|
| 35 |
def get_prediction(model, padded_sequences, img_x, device='cpu'):
|
|
|
|
| 12 |
import torch
|
| 13 |
from PIL import Image
|
| 14 |
from src.utils.get_features import get_img_api
|
| 15 |
+
import joblib
|
| 16 |
|
| 17 |
# Path to the dataset
|
| 18 |
data_path = 'src/data/subset_dataset.csv'
|
|
|
|
| 27 |
|
| 28 |
# Load the model
|
| 29 |
def load_model(model_path, device='cpu'):
|
| 30 |
+
"""Loads the model from a joblib file and moves it to the specified device."""
|
| 31 |
+
model = joblib.load(model_path)
|
| 32 |
+
# If the model contains PyTorch tensors, move them to the specified device
|
| 33 |
+
if isinstance(model, torch.nn.Module):
|
| 34 |
+
model = model.to(device)
|
| 35 |
+
return model
|
| 36 |
|
| 37 |
# Get prediction
|
| 38 |
def get_prediction(model, padded_sequences, img_x, device='cpu'):
|