mulasagg committed on
Commit
f11785f
·
1 Parent(s): ac20ecf
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -14,6 +14,9 @@ from PIL import Image
14
  from src.utils.get_features import get_img_api
15
  import joblib
16
 
 
 
 
17
  # Path to the dataset
18
  data_path = 'src/data/subset_dataset.csv'
19
  device = torch.device('cpu')
@@ -28,10 +31,14 @@ simple_transform = transforms.Compose([
28
  # Load the model
29
  def load_model(model_path, device='cpu'):
30
  """Loads the model from a joblib file and moves it to the specified device."""
31
- model = joblib.load(model_path)
32
- # If the model contains PyTorch tensors, move them to the specified device
 
 
 
33
  if isinstance(model, torch.nn.Module):
34
  model = model.to(device)
 
35
  return model
36
 
37
  # Get prediction
@@ -43,10 +50,11 @@ def get_prediction(model, padded_sequences, img_x, device='cpu'):
43
  padded_sequences, img_x = padded_sequences.to(device), img_x.to(device)
44
 
45
  # Perform inference
46
- outputs = model(padded_sequences, img_x)
47
- _, predicted = torch.max(outputs, 1)
 
48
 
49
- return malware_classes[predicted]
50
 
51
  # Define the prediction function for Gradio
52
  def predict_malware(sha256_hash):
@@ -58,9 +66,9 @@ def predict_malware(sha256_hash):
58
  return "Hash not found in the dataset.", "", ""
59
 
60
  # Load the dataset
61
- dataset = CombinedDataset(api_call_list, image_path, transforms=simple_transform ,sequence_length=config.configuration["sequence_length"])
62
  padded_sequences, img_x = next(iter(dataset))
63
- img_x = img_x.unsqueeze(0) # type: ignore
64
 
65
  # Load the model
66
  model_path = "model_dump/model_malware_lstm (1).pkl"
@@ -98,14 +106,10 @@ with gr.Blocks() as demo:
98
  # Output for predicted malware class
99
  malware_output = gr.Textbox(label="Predicted Malware Class")
100
 
101
-
102
-
103
-
104
  submit_button.click(
105
  predict_malware,
106
  inputs=sha256_input,
107
  outputs=[image_output, api_output, malware_output]
108
  )
109
 
110
-
111
  demo.launch()
 
14
  from src.utils.get_features import get_img_api
15
  import joblib
16
 
17
+
18
+ device = torch.device('cpu')
19
+
20
  # Path to the dataset
21
  data_path = 'src/data/subset_dataset.csv'
22
  device = torch.device('cpu')
 
31
  # Load the model
32
  def load_model(model_path, device='cpu'):
33
  """Loads the model from a joblib file and moves it to the specified device."""
34
+ # Use torch.load with map_location to ensure CPU compatibility
35
+ with open(model_path, 'rb') as f:
36
+ model = torch.load(f, map_location=device)
37
+
38
+ # If the model is a PyTorch module, move it to the specified device and set to eval mode
39
  if isinstance(model, torch.nn.Module):
40
  model = model.to(device)
41
+ model.eval()
42
  return model
43
 
44
  # Get prediction
 
50
  padded_sequences, img_x = padded_sequences.to(device), img_x.to(device)
51
 
52
  # Perform inference
53
+ with torch.no_grad(): # Disable gradient calculation for inference
54
+ outputs = model(padded_sequences, img_x)
55
+ _, predicted = torch.max(outputs, 1)
56
 
57
+ return malware_classes[predicted] # NOTE(review): .item() is not called here; `predicted` is used directly as the index
58
 
59
  # Define the prediction function for Gradio
60
  def predict_malware(sha256_hash):
 
66
  return "Hash not found in the dataset.", "", ""
67
 
68
  # Load the dataset
69
+ dataset = CombinedDataset(api_call_list, image_path, transforms=simple_transform, sequence_length=config.configuration["sequence_length"])
70
  padded_sequences, img_x = next(iter(dataset))
71
+ img_x = img_x.unsqueeze(0) #type: ignore
72
 
73
  # Load the model
74
  model_path = "model_dump/model_malware_lstm (1).pkl"
 
106
  # Output for predicted malware class
107
  malware_output = gr.Textbox(label="Predicted Malware Class")
108
 
 
 
 
109
  submit_button.click(
110
  predict_malware,
111
  inputs=sha256_input,
112
  outputs=[image_output, api_output, malware_output]
113
  )
114
 
 
115
  demo.launch()