"""Gradio demo that looks up a malware sample by SHA256 hash and classifies it
with a multimodal (image + API-call sequence) model."""

import functools
import io
import os
import pickle
import sys

# Make the local `src` package importable regardless of the working directory.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

import gradio as gr
import joblib  # noqa: F401 -- not used below; kept from the original file
import torch
from PIL import Image
from torchvision import transforms

from src.config import config
from src.datasets.datasets import CombinedDataset
# NOTE(review): the three model classes below are not referenced by name in
# this file; they are presumably imported so the pickled model's classes
# resolve when unpickling. Confirm before removing.
from src.models.multimodal import CombinedMalwareDetectionModel  # noqa: F401
from src.models.bigru import CNNBiGRU  # noqa: F401
from src.models.cnn import ImprovedCNN  # noqa: F401
from src.utils.get_features import get_img_api


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that forces pickled torch storages onto the CPU.

    Pickled torch storages are rebuilt through
    ``torch.storage._load_from_bytes``. That lookup is intercepted here and
    replaced with a loader that calls ``torch.load`` with
    ``map_location=cpu``, so a model saved on a GPU machine can be loaded on a
    CPU-only host.
    """

    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            def _load_from_bytes(b):
                # NOTE(review): PyTorch >= 2.6 defaults torch.load to
                # weights_only=True; if loading fails there, pass
                # weights_only=False (the outer pickle is already trusted).
                return torch.load(io.BytesIO(b), map_location=torch.device('cpu'))
            return _load_from_bytes
        return super().find_class(module, name)


# Dataset CSV used to look up image/API features by hash
# (resolved relative to the current working directory).
data_path = 'src/data/subset_dataset.csv'
device = torch.device('cpu')

# Serialized model used by the demo (relative to the current working directory).
MODEL_PATH = "model_dump/model_malware_lstm (1).pkl"

# Class label for each output index of the model; presumably in the same order
# as the labels the model was trained with.
MALWARE_CLASSES = [
    "Benign",
    "RedLine Stealer",
    "Downloader",
    "RAT",
    "Banking Trojan",
    "Snake Keylogger",
    "Spyware",
]

# Deterministic preprocessing for inference. No random augmentation (such as
# RandomHorizontalFlip) here: it would make the prediction for the same hash
# change from one request to the next.
simple_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])


def load_model(model_path, device='cpu'):
    """Load a pickled model from ``model_path`` and prepare it for inference.

    The file is read with :class:`CPU_Unpickler` (plain ``pickle``, not
    joblib), so any torch storages inside are mapped to the CPU first.

    Security: unpickling can execute arbitrary code. Only load trusted files.

    Args:
        model_path: Path to the pickled model file.
        device: Device to move the model to if it is a ``torch.nn.Module``.

    Returns:
        The unpickled object. If it is a ``torch.nn.Module``, it has been moved
        to ``device`` and switched to evaluation mode.
    """
    with open(model_path, 'rb') as f:
        model = CPU_Unpickler(f).load()

    if isinstance(model, torch.nn.Module):
        model = model.to(device)
        model.eval()  # disable dropout / use running batch-norm stats
    return model


@functools.lru_cache(maxsize=4)
def _load_model_cached(model_path, device='cpu'):
    """Load the model once per (path, device) and reuse it across requests.

    Changes to the file on disk are not picked up until the process restarts.
    """
    return load_model(model_path, device=device)


def get_prediction(model, padded_sequences, img_x, device='cpu'):
    """Run ``model`` on a single sample and return the predicted class label.

    Args:
        model: Callable taking ``(padded_sequences, img_x)`` and returning
            class scores.
        padded_sequences: API-call sequence tensor for the sample.
        img_x: Image tensor for the sample.
        device: Device the inputs are moved to before inference.

    Returns:
        The entry of ``MALWARE_CLASSES`` for the arg-max of ``outputs`` over
        dim 1.
    """
    padded_sequences = padded_sequences.to(device)
    img_x = img_x.to(device)

    with torch.no_grad():  # inference only: no autograd bookkeeping
        outputs = model(padded_sequences, img_x)
        _, predicted = torch.max(outputs, 1)

    # .item() converts the single-element index tensor to a plain int; it
    # raises if more than one sample was passed.
    return MALWARE_CLASSES[predicted.item()]


def predict_malware(sha256_hash):
    """Gradio callback: look up ``sha256_hash`` and classify the sample.

    Returns:
        ``(image, api_sequence_text, predicted_class)``, matching the three
        output components. If the hash is not in the dataset, the image is
        ``None``, the API text is empty, and the error message is shown in
        the prediction textbox.
    """
    image_path, api_call_list = get_img_api(sha256_hash, data_path)

    if image_path is None or api_call_list is None:
        # The first output is a gr.Image, which cannot display a plain string,
        # so the message goes to the prediction textbox instead.
        return None, "", "Hash not found in the dataset."

    dataset = CombinedDataset(
        api_call_list,
        image_path,
        transforms=simple_transform,
        sequence_length=config.configuration["sequence_length"],
    )
    padded_sequences, img_x = next(iter(dataset))
    # Add a leading (batch) dimension to the image tensor.
    # NOTE(review): padded_sequences is passed without unsqueeze -- presumably
    # the dataset already returns it with a batch dimension; confirm.
    img_x = img_x.unsqueeze(0)  # type: ignore

    model = _load_model_cached(MODEL_PATH, device)
    predicted_class = get_prediction(model, padded_sequences, img_x, device)

    # Copy inside a context manager so the file handle is closed now instead
    # of whenever the lazily-loaded image object is garbage-collected.
    with Image.open(image_path) as img:
        img_x_display = img.copy()

    # NOTE(review): assumes api_call_list[0] is this sample's list of API
    # call names (strings); confirm against get_img_api.
    api_sequence_display = "\n".join(api_call_list[0])

    return img_x_display, api_sequence_display, predicted_class


# Gradio interface: hash input + submit on the left, extracted features and
# the prediction on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Malware Detection")
    gr.Markdown("Enter a SHA256 hash to detect the type of malware. The extracted image and API sequence will also be displayed.")

    with gr.Row():
        with gr.Column():
            sha256_input = gr.Textbox(label="Enter SHA256 Hash")
            submit_button = gr.Button("Submit")
            image_output = gr.Image(label="Extracted Image", height=256, width=256)

        with gr.Column():
            api_output = gr.Textbox(label="Extracted API Sequence")
            malware_output = gr.Textbox(label="Predicted Malware Class")

    submit_button.click(
        predict_malware,
        inputs=sha256_input,
        outputs=[image_output, api_output, malware_output],
    )

demo.launch()