Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.dirname(__file__))) | |
| import gradio as gr | |
| from src.config import config | |
| from src.datasets.datasets import CombinedDataset | |
| from src.models.multimodal import CombinedMalwareDetectionModel | |
| from src.models.bigru import CNNBiGRU | |
| from src.models.cnn import ImprovedCNN | |
| from torchvision import transforms | |
| import pickle | |
| import torch | |
| from PIL import Image | |
| from src.utils.get_features import get_img_api | |
| import joblib | |
| import io | |
# Custom unpickler to handle device mapping
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that deserialises embedded torch storage bytes onto the CPU.

    Only the lookup of ``torch.storage._load_from_bytes`` is intercepted;
    every other global reference is resolved by ``pickle.Unpickler`` itself.
    """

    def find_class(self, module, name):
        """Resolve the global ``module.name`` referenced by the pickle stream.

        Args:
            module: Dotted module name recorded in the pickle.
            name: Attribute name recorded in the pickle.

        Returns:
            A CPU-mapping loader when the pickle asks for
            ``torch.storage._load_from_bytes``; otherwise whatever the base
            class resolves.
        """
        if (module, name) != ("torch.storage", "_load_from_bytes"):
            return super().find_class(module, name)

        cpu = torch.device('cpu')

        def _cpu_load(raw_bytes):
            # Same job as the torch helper it replaces, but with an explicit
            # map_location so the bytes are loaded for the CPU device.
            return torch.load(io.BytesIO(raw_bytes), map_location=cpu)

        return _cpu_load
# Path to the dataset
data_path = 'src/data/subset_dataset.csv'
device = torch.device('cpu')

# Define the transform.
# This pipeline is applied at prediction time, so it must be deterministic:
# a random horizontal flip here would mirror the same sample on roughly half
# of the requests and could change the predicted class between identical
# queries. Random augmentation belongs in the training pipeline only.
simple_transform = transforms.Compose([
    transforms.Resize((128, 128)),               # fixed 128x128 input size
    transforms.ToTensor(),                       # PIL image -> float tensor
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# Load the model
def load_model(model_path, device='cpu'):
    """Unpickle a model from ``model_path`` using ``CPU_Unpickler``.

    Args:
        model_path: Path of the pickled model file; opened in binary mode.
        device: Device a ``torch.nn.Module`` result is moved to.

    Returns:
        The unpickled object. When it is a ``torch.nn.Module`` it has been
        moved to ``device`` and switched to evaluation mode; any other object
        is returned exactly as unpickled.

    Note:
        Unpickling executes code embedded in the file, so only load model
        files from a trusted source.
    """
    with open(model_path, 'rb') as handle:
        unpickler = CPU_Unpickler(handle)
        loaded = unpickler.load()

    # Nothing to relocate for objects that are not PyTorch modules.
    if not isinstance(loaded, torch.nn.Module):
        return loaded

    loaded = loaded.to(device)
    loaded.eval()  # Set to evaluation mode
    return loaded
# Get prediction
def get_prediction(model, padded_sequences, img_x, device='cpu'):
    """Run ``model`` on a single sample and return its predicted class name.

    Args:
        model: Callable invoked as ``model(padded_sequences, img_x)``. Its
            output is reduced with a max over dimension 1, so it is expected
            to be 2-D with one row per sample.
        padded_sequences: Tensor passed as the model's first input.
        img_x: Tensor passed as the model's second input.
        device: Device both inputs are moved to before the forward pass.

    Returns:
        The class name at the predicted index, as a plain ``str``.

    Raises:
        ValueError: If the model output does not describe exactly one sample.
    """
    malware_classes = ("Benign", "RedLine Stealer", "Downloader", "RAT",
                       "Banking Trojan", "Snake Keylogger", "Spyware")

    # Move inputs to the device
    padded_sequences = padded_sequences.to(device)
    img_x = img_x.to(device)

    # Perform inference
    with torch.no_grad():  # Disable gradient calculation for inference
        outputs = model(padded_sequences, img_x)
        _, predicted = torch.max(outputs, 1)

    # Indexing a list with a tensor only works through the implicit
    # single-element conversion; check the size and convert explicitly so a
    # multi-sample batch fails with a clear message instead.
    if predicted.numel() != 1:
        raise ValueError(
            f"get_prediction expects exactly one sample, got {predicted.numel()}"
        )
    return malware_classes[int(predicted.item())]
# Location of the pickled model used for every prediction.
_MODEL_PATH = "model_dump/model_malware_lstm (1).pkl"

# Models already loaded by this process, keyed by (path, device).
_LOADED_MODELS = {}


def _get_cached_model(model_path, device='cpu'):
    """Return the model for ``model_path``, unpickling it only on first use."""
    key = (model_path, str(device))
    if key not in _LOADED_MODELS:
        _LOADED_MODELS[key] = load_model(model_path, device=device)
    return _LOADED_MODELS[key]


# Define the prediction function for Gradio
def predict_malware(sha256_hash):
    """Look up a sample by SHA256 hash and classify it.

    Args:
        sha256_hash: Hash text entered by the user; surrounding whitespace
            is ignored.

    Returns:
        A 3-tuple in the order the UI outputs are wired:
        ``(image, api_sequence_text, predicted_class)``. When the hash is not
        found the tuple is ``(None, "", "Hash not found in the dataset.")``.
    """
    # Hashes are usually pasted, so drop stray spaces and newlines.
    sha256_hash = (sha256_hash or "").strip()

    # Get the image path and API call list for the given SHA256 hash
    image_path, api_call_list = get_img_api(sha256_hash, data_path)

    # If the hash is not found, report it through the text output. The first
    # slot feeds the image component, which expects image data rather than a
    # message, so it is cleared with None.
    if image_path is None or api_call_list is None:
        return None, "", "Hash not found in the dataset."

    # Build the single-sample dataset and take its only item
    dataset = CombinedDataset(
        api_call_list,
        image_path,
        transforms=simple_transform,
        sequence_length=config.configuration["sequence_length"],
    )
    padded_sequences, img_x = next(iter(dataset))
    img_x = img_x.unsqueeze(0)  # type: ignore  # add the leading batch dimension

    # Reuse the already-loaded model instead of unpickling it per request
    model = _get_cached_model(_MODEL_PATH, device='cpu')

    # Get the prediction
    predicted_class = get_prediction(model, padded_sequences, img_x, 'cpu')

    # Load the image for display
    img_x_display = Image.open(image_path)

    # Format the API sequence for display, one call per line
    api_sequence_display = "\n".join(api_call_list[0])

    # Return the extracted features and the final prediction
    return img_x_display, api_sequence_display, predicted_class
# Build the Gradio page: a title, a short description, then two columns.
with gr.Blocks() as demo:
    gr.Markdown("# Malware Detection")
    gr.Markdown("Enter a SHA256 hash to detect the type of malware. The extracted image and API sequence will also be displayed.")

    with gr.Row():
        # Left column: hash entry, submit button and the sample's image.
        with gr.Column():
            sha256_input = gr.Textbox(label="Enter SHA256 Hash")
            submit_button = gr.Button(value="Submit")
            image_output = gr.Image(label="Extracted Image", height=256, width=256)

        # Right column: the API call sequence and the predicted class.
        with gr.Column():
            api_output = gr.Textbox(label="Extracted API Sequence")
            malware_output = gr.Textbox(label="Predicted Malware Class")

    # The output order must match the tuple returned by predict_malware:
    # (image, API sequence text, predicted class).
    submit_button.click(
        fn=predict_malware,
        inputs=sha256_input,
        outputs=[image_output, api_output, malware_output],
    )

# Start the web server for the interface.
demo.launch()