mulasagg commited on
Commit
38b2615
·
1 Parent(s): 064e7b4
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from src.config import config
3
+ from src.datasets.datasets import CombinedDataset
4
+ from src.models.multimodal import CombinedMalwareDetectionModel
5
+ from src.models.bigru import CNNBiGRU
6
+ from src.models.cnn import ImprovedCNN
7
+ from torchvision import transforms
8
+ import pickle
9
+ import torch
10
+ from PIL import Image
11
+ from src.utils.get_features import get_img_api
12
+
13
+ # Path to the dataset
14
+ data_path = 'src/data/subset_dataset.csv'
15
+
16
+ # Define the transform
17
+ simple_transform = transforms.Compose([
18
+ transforms.Resize((128, 128)),
19
+ transforms.RandomHorizontalFlip(p=0.5),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize(mean=[0.5], std=[0.5])
22
+ ])
23
+
24
+ # Load the model
25
+ def load_model(model_path, device='cpu'):
26
+ """Loads the model from a pickle file and moves it to the specified device."""
27
+ with open(model_path, 'rb') as f:
28
+ model = pickle.load(f)
29
+ return model.to(device)
30
+
31
+ # Get prediction
32
+ def get_prediction(model, padded_sequences, img_x, device='cuda'):
33
+ malware_classes = ["Benign", "RedLine Stealer", "Downloader", "RAT",
34
+ "Banking Trojan", "Snake Keylogger", "Spyware"]
35
+
36
+ # Move inputs to the device
37
+ padded_sequences, img_x = padded_sequences.to(device), img_x.to(device)
38
+
39
+ # Perform inference
40
+ outputs = model(padded_sequences, img_x)
41
+ _, predicted = torch.max(outputs, 1)
42
+
43
+ return malware_classes[predicted]
44
+
45
+ # Define the prediction function for Gradio
46
+ def predict_malware(sha256_hash):
47
+ # Get the image path and API call list for the given SHA256 hash
48
+ image_path, api_call_list = get_img_api(sha256_hash, data_path)
49
+
50
+ # If the hash is not found, return an error message
51
+ if image_path is None or api_call_list is None:
52
+ return "Hash not found in the dataset.", "", ""
53
+
54
+ # Load the dataset
55
+ dataset = CombinedDataset(api_call_list, image_path, transforms=simple_transform ,sequence_length=config.configuration["sequence_length"])
56
+ padded_sequences, img_x = next(iter(dataset))
57
+ img_x = img_x.unsqueeze(0) # type: ignore
58
+
59
+ # Load the model
60
+ model_path = "model_dump/model_malware_lstm (1).pkl"
61
+ model = load_model(model_path, device=config.configuration["device"])
62
+
63
+ # Get the prediction
64
+ predicted_class = get_prediction(model, padded_sequences, img_x, config.configuration["device"])
65
+
66
+ # Load the image for display
67
+ img_x_display = Image.open(image_path)
68
+
69
+ # Format the API sequence for display
70
+ api_sequence_display = "\n".join(api_call_list[0])
71
+
72
+ # Return the extracted features and the final prediction
73
+ return img_x_display, api_sequence_display, predicted_class
74
+
75
+ # Create the Gradio interface with custom layout
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("# Malware Detection")
78
+ gr.Markdown("Enter a SHA256 hash to detect the type of malware. The extracted image and API sequence will also be displayed.")
79
+
80
+ with gr.Row():
81
+ with gr.Column():
82
+ # Input for SHA256 hash
83
+ sha256_input = gr.Textbox(label="Enter SHA256 Hash")
84
+ submit_button = gr.Button("Submit")
85
+
86
+ # Output for extracted image
87
+ image_output = gr.Image(label="Extracted Image", height=256, width=256)
88
+
89
+ with gr.Column():
90
+ # Output for API sequence
91
+ api_output = gr.Textbox(label="Extracted API Sequence")
92
+ # Output for predicted malware class
93
+ malware_output = gr.Textbox(label="Predicted Malware Class")
94
+
95
+
96
+
97
+
98
+ submit_button.click(
99
+ predict_malware,
100
+ inputs=sha256_input,
101
+ outputs=[image_output, api_output, malware_output]
102
+ )
103
+
104
+
105
+ demo.launch()