File size: 4,458 Bytes
b1904a9
 
 
38b2615
 
 
 
 
 
 
 
 
 
 
d0f4387
 
 
 
 
 
 
 
 
 
 
38b2615
 
 
7166d38
38b2615
 
 
 
 
 
 
 
 
 
d0f4387
 
 
 
f11785f
201eb2f
ac20ecf
 
d0f4387
ac20ecf
38b2615
 
707c74b
38b2615
 
 
 
201eb2f
 
38b2615
 
f11785f
 
 
38b2615
eb13a14
38b2615
 
 
 
 
 
 
 
 
 
 
f11785f
38b2615
eb13a14
38b2615
 
 
707c74b
38b2615
 
707c74b
38b2615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import sys
import os
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
import gradio as gr
from src.config import config
from src.datasets.datasets import CombinedDataset
from src.models.multimodal import CombinedMalwareDetectionModel
from src.models.bigru import CNNBiGRU
from src.models.cnn import ImprovedCNN
from torchvision import transforms
import pickle
import torch
from PIL import Image
from src.utils.get_features import get_img_api
import joblib
import io

# Custom unpickler to handle device mapping
class CPU_Unpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that re-loads torch storages onto the CPU.

    Torch storages pickle themselves through
    ``torch.storage._load_from_bytes``. Intercepting that single global lets
    the raw bytes be re-deserialized with ``map_location`` pinned to the CPU.
    Every other global is resolved normally.

    Security: like plain ``pickle``, this executes arbitrary code embedded in
    the file -- only use it on trusted model files.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``; swap in a CPU loader for torch storages."""
        wants_storage_loader = (module, name) == ("torch.storage", "_load_from_bytes")
        if not wants_storage_loader:
            return super().find_class(module, name)

        def _load_from_bytes(b):
            # Re-read the serialized storage, mapping every tensor to the CPU.
            buffer = io.BytesIO(b)
            return torch.load(buffer, map_location=torch.device('cpu'))

        return _load_from_bytes

# Path to the dataset
data_path = 'src/data/subset_dataset.csv'
device = torch.device('cpu')
# Preprocessing applied to each sample at inference time. It must be
# deterministic: a random augmentation such as RandomHorizontalFlip (a
# training-time technique) would make repeated predictions for the same
# hash disagree with each other.
simple_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    # Single-value mean/std -- presumably the malware images are
    # single-channel (grayscale); TODO confirm against the training pipeline.
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Load the model
def load_model(model_path, device='cpu'):
    """Deserialize a pickled model and prepare it for inference.

    The file is read with ``CPU_Unpickler`` (plain ``pickle`` with torch
    storages remapped to the CPU); joblib is not used.

    Security: unpickling executes arbitrary code, so only load trusted files.

    Args:
        model_path: Path of the pickled model file.
        device: Device (string or ``torch.device``) that a ``torch.nn.Module``
            is moved to.

    Returns:
        The unpickled object. When it is a ``torch.nn.Module`` it has been
        moved to ``device`` and switched to evaluation mode; any other object
        is returned unchanged.
    """
    with open(model_path, 'rb') as handle:
        loaded = CPU_Unpickler(handle).load()

    if not isinstance(loaded, torch.nn.Module):
        return loaded

    loaded = loaded.to(device)
    loaded.eval()  # disable dropout / batch-norm updates for inference
    return loaded

# Get prediction
def get_prediction(model, padded_sequences, img_x, device='cpu'):
    """Run ``model`` on a single sample and return the predicted class name.

    Args:
        model: Callable taking ``(padded_sequences, img_x)`` and returning a
            tensor of class scores; dimension 1 indexes the classes listed
            below.
        padded_sequences: API-call sequence tensor for the sample.
        img_x: Image tensor for the sample.
        device: Device both inputs are moved to before inference.

    Returns:
        The label of the highest-scoring class.

    Raises:
        ValueError: if the model produced scores for more than one sample.
    """
    malware_classes = ["Benign", "RedLine Stealer", "Downloader", "RAT",
                       "Banking Trojan", "Snake Keylogger", "Spyware"]

    # Move inputs to the device
    padded_sequences = padded_sequences.to(device)
    img_x = img_x.to(device)

    # Perform inference without building an autograd graph.
    with torch.no_grad():
        outputs = model(padded_sequences, img_x)
        _, predicted = torch.max(outputs, 1)

    # Convert explicitly: indexing a list with a tensor only works for a
    # single-element tensor and otherwise fails with an opaque TypeError.
    if predicted.numel() != 1:
        raise ValueError(
            f"expected scores for exactly one sample, got {predicted.numel()}"
        )
    return malware_classes[predicted.item()]

# Define the prediction function for Gradio
_MODEL_PATH = "model_dump/model_malware_lstm (1).pkl"
# Loaded models keyed by path, so the pickle is read once rather than per request.
_model_cache = {}


def _get_model():
    """Return the inference model, unpickling it from disk only on first use."""
    if _MODEL_PATH not in _model_cache:
        _model_cache[_MODEL_PATH] = load_model(_MODEL_PATH, device='cpu')
    return _model_cache[_MODEL_PATH]


def predict_malware(sha256_hash):
    """Gradio handler: look up a hash, classify it and return display values.

    Args:
        sha256_hash: Hash typed by the user; surrounding whitespace is ignored.

    Returns:
        A 3-tuple ``(image, api_sequence_text, predicted_class)`` matching the
        outputs ``[image_output, api_output, malware_output]``. On invalid or
        unknown input the image is ``None``, the API text is empty and the
        message is placed in the class textbox, where the user can read it.
    """
    sha256_hash = (sha256_hash or "").strip()
    if not sha256_hash:
        return None, "", "Please enter a SHA256 hash."

    # Get the image path and API call list for the given SHA256 hash
    image_path, api_call_list = get_img_api(sha256_hash, data_path)

    # The first output is a gr.Image, so the error text must go to the
    # class textbox instead of the image slot.
    if image_path is None or api_call_list is None:
        return None, "", "Hash not found in the dataset."

    # Build a dataset for this sample and take its first (sequence, image) pair.
    dataset = CombinedDataset(api_call_list, image_path, transforms=simple_transform, sequence_length=config.configuration["sequence_length"])
    padded_sequences, img_x = next(iter(dataset))
    img_x = img_x.unsqueeze(0)  # type: ignore
    # NOTE(review): only the image gets a batch dimension here; presumably the
    # sequence tensor already has one -- confirm against CombinedDataset.

    predicted_class = get_prediction(_get_model(), padded_sequences, img_x, 'cpu')

    # Load the image for display; copy() reads the pixels so the file can be closed.
    with Image.open(image_path) as img:
        img_x_display = img.copy()

    # One API call per line for the textbox.
    api_sequence_display = "\n".join(api_call_list[0])

    return img_x_display, api_sequence_display, predicted_class

# Create the Gradio interface with custom layout
with gr.Blocks() as demo:
    gr.Markdown("# Malware Detection")
    gr.Markdown("Enter a SHA256 hash to detect the type of malware. The extracted image and API sequence will also be displayed.")

    with gr.Row():
        with gr.Column():
            # Input for SHA256 hash
            sha256_input = gr.Textbox(label="Enter SHA256 Hash")
            submit_button = gr.Button("Submit")

            # Output for extracted image
            image_output = gr.Image(label="Extracted Image", height=256, width=256)

        with gr.Column():
            # Output for API sequence
            api_output = gr.Textbox(label="Extracted API Sequence")
            # Output for predicted malware class
            malware_output = gr.Textbox(label="Predicted Malware Class")

    # Order must match the 3-tuple returned by predict_malware.
    prediction_outputs = [image_output, api_output, malware_output]

    submit_button.click(
        predict_malware,
        inputs=sha256_input,
        outputs=prediction_outputs,
    )
    # Pressing Enter in the hash box runs the same prediction as the button.
    sha256_input.submit(
        predict_malware,
        inputs=sha256_input,
        outputs=prediction_outputs,
    )

# Start the server only when executed as a script, so importing this module
# (tests, reload tooling, another app) does not block on launch().
if __name__ == "__main__":
    demo.launch()