mulasagg commited on
Commit
201eb2f
·
1 Parent(s): f11785f
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -14,9 +14,6 @@ from PIL import Image
14
  from src.utils.get_features import get_img_api
15
  import joblib
16
 
17
-
18
- device = torch.device('cpu')
19
-
20
  # Path to the dataset
21
  data_path = 'src/data/subset_dataset.csv'
22
  device = torch.device('cpu')
@@ -31,13 +28,14 @@ simple_transform = transforms.Compose([
31
  # Load the model
32
  def load_model(model_path, device='cpu'):
33
  """Loads the model from a joblib file and moves it to the specified device."""
34
- # Use torch.load with map_location to ensure CPU compatibility
35
- with open(model_path, 'rb') as f:
36
- model = torch.load(f, map_location=device)
37
 
38
- # If the model is a PyTorch module, move it to the specified device and set to eval mode
39
  if isinstance(model, torch.nn.Module):
 
40
  model = model.to(device)
 
41
  model.eval()
42
  return model
43
 
@@ -47,7 +45,8 @@ def get_prediction(model, padded_sequences, img_x, device='cpu'):
47
  "Banking Trojan", "Snake Keylogger", "Spyware"]
48
 
49
  # Move inputs to the device
50
- padded_sequences, img_x = padded_sequences.to(device), img_x.to(device)
 
51
 
52
  # Perform inference
53
  with torch.no_grad(): # Disable gradient calculation for inference
@@ -68,7 +67,7 @@ def predict_malware(sha256_hash):
68
  # Load the dataset
69
  dataset = CombinedDataset(api_call_list, image_path, transforms=simple_transform, sequence_length=config.configuration["sequence_length"])
70
  padded_sequences, img_x = next(iter(dataset))
71
- img_x = img_x.unsqueeze(0) #type: ignore
72
 
73
  # Load the model
74
  model_path = "model_dump/model_malware_lstm (1).pkl"
 
14
  from src.utils.get_features import get_img_api
15
  import joblib
16
 
 
 
 
17
  # Path to the dataset
18
  data_path = 'src/data/subset_dataset.csv'
19
  device = torch.device('cpu')
 
28
  # Load the model
29
  def load_model(model_path, device='cpu'):
30
  """Loads the model from a joblib file and moves it to the specified device."""
31
+ # Load the model using joblib
32
+ model = joblib.load(model_path)
 
33
 
34
+ # If the model is a PyTorch module, move it to the specified device
35
  if isinstance(model, torch.nn.Module):
36
+ # Move model to CPU and handle any CUDA tensors
37
  model = model.to(device)
38
+ # Set to evaluation mode
39
  model.eval()
40
  return model
41
 
 
45
  "Banking Trojan", "Snake Keylogger", "Spyware"]
46
 
47
  # Move inputs to the device
48
+ padded_sequences = padded_sequences.to(device)
49
+ img_x = img_x.to(device)
50
 
51
  # Perform inference
52
  with torch.no_grad(): # Disable gradient calculation for inference
 
67
  # Load the dataset
68
  dataset = CombinedDataset(api_call_list, image_path, transforms=simple_transform, sequence_length=config.configuration["sequence_length"])
69
  padded_sequences, img_x = next(iter(dataset))
70
+ img_x = img_x.unsqueeze(0) # Add batch dimension
71
 
72
  # Load the model
73
  model_path = "model_dump/model_malware_lstm (1).pkl"