# Last commit eb13a14 by mulasagg: "add more samples"
import sys
import os
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
import gradio as gr
from src.config import config
from src.datasets.datasets import CombinedDataset
from src.models.multimodal import CombinedMalwareDetectionModel
from src.models.bigru import CNNBiGRU
from src.models.cnn import ImprovedCNN
from torchvision import transforms
import pickle
import torch
from PIL import Image
from src.utils.get_features import get_img_api
import joblib
import io
# Custom unpickler that forces torch storages onto the CPU while unpickling.
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps serialized torch storages onto the CPU.

    Pickled PyTorch objects restore their storages through
    ``torch.storage._load_from_bytes``. That hook is intercepted here so each
    storage payload is deserialized by ``torch.load`` with ``map_location``
    set to the CPU. Every other global is resolved by the default
    ``pickle.Unpickler.find_class``.

    NOTE(review): unpickling can execute arbitrary code; only use this on
    trusted files.
    """

    def find_class(self, module, name):
        wants_storage_hook = (module, name) == ("torch.storage", "_load_from_bytes")
        if not wants_storage_hook:
            return super().find_class(module, name)

        cpu = torch.device('cpu')

        def _restore_on_cpu(payload):
            # Deserialize the storage bytes directly onto the CPU device.
            return torch.load(io.BytesIO(payload), map_location=cpu)

        return _restore_on_cpu
# Path to the dataset
data_path = 'src/data/subset_dataset.csv'
device = torch.device('cpu')
# Define the transform
simple_transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# Load the model
def load_model(model_path, device='cpu'):
    """Load a pickled model from ``model_path`` and prepare it for inference.

    The file is read with ``CPU_Unpickler`` (plain ``pickle``, not joblib).
    ``CPU_Unpickler`` remaps embedded torch storages onto the CPU, so a model
    that was saved on a GPU machine can be restored without CUDA.

    Args:
        model_path: Path to the pickled model file.
        device: Device (``str`` or ``torch.device``) that the model is moved
            to when it is a ``torch.nn.Module``.

    Returns:
        The unpickled object. A ``torch.nn.Module`` is moved to ``device``
        and switched to evaluation mode; any other object is returned as-is.

    Security:
        Unpickling can execute arbitrary code. Only load trusted files.
    """
    with open(model_path, 'rb') as f:
        model = CPU_Unpickler(f).load()

    if isinstance(model, torch.nn.Module):
        model = model.to(device)
        # Switch layers such as dropout / batch-norm to inference behaviour.
        model.eval()
    return model
# Get prediction
def get_prediction(model, padded_sequences, img_x, device='cpu'):
    """Run ``model`` on one sample and return the predicted class label.

    Both inputs are moved to ``device`` and fed to ``model`` under
    ``torch.no_grad()``. The label is the entry whose position matches the
    highest score along dimension 1 of the model output.

    Args:
        model: Callable taking ``(padded_sequences, img_x)`` and returning a
            score tensor whose dimension 1 indexes the labels below.
        padded_sequences: API-call sequence tensor.
        img_x: Image tensor.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        One of the seven label strings. Assumes a single-sample batch:
        indexing the labels with the argmax tensor only works when it holds
        exactly one element.
    """
    labels = (
        "Benign",
        "RedLine Stealer",
        "Downloader",
        "RAT",
        "Banking Trojan",
        "Snake Keylogger",
        "Spyware",
    )

    sequences_on_device = padded_sequences.to(device)
    image_on_device = img_x.to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        scores = model(sequences_on_device, image_on_device)
        best = torch.max(scores, dim=1).indices

    return labels[best]
# Define the prediction function for Gradio

# Pickled model served by the app.
_MODEL_PATH = "model_dump/model_malware_lstm (1).pkl"

# Models already loaded from disk, keyed by (path, device). Each file is
# unpickled once per process instead of on every request.
_MODEL_CACHE = {}


def _get_cached_model(model_path, device='cpu'):
    """Return the model stored at ``model_path``, loading it on first use only.

    NOTE(review): the cached model is shared by all requests. This assumes its
    forward pass (run under no_grad in eval mode) does not mutate model state.
    Two simultaneous first requests may both load the file; that is harmless.
    """
    key = (model_path, str(device))
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = load_model(model_path, device=device)
    return _MODEL_CACHE[key]


def predict_malware(sha256_hash):
    """Look up a sample by SHA256 hash and classify it.

    Args:
        sha256_hash: Hash entered by the user. Surrounding whitespace is
            ignored.

    Returns:
        A 3-tuple matching the Gradio outputs
        ``(image_output, api_output, malware_output)``:
        - the sample's image (PIL), or ``None`` when the hash is unknown;
        - the first API call list as newline-separated text;
        - the predicted class label, or an error message.
    """
    # Pasted hashes often carry trailing spaces or newlines.
    sha256_hash = (sha256_hash or "").strip()

    image_path, api_call_list = get_img_api(sha256_hash, data_path)
    if image_path is None or api_call_list is None:
        # The message goes to the text output. The image slot gets None,
        # because a plain string there would be treated as an image path.
        return None, "", "Hash not found in the dataset."

    dataset = CombinedDataset(
        api_call_list,
        image_path,
        transforms=simple_transform,
        sequence_length=config.configuration["sequence_length"],
    )
    # NOTE(review): padded_sequences is used as returned by the dataset;
    # presumably it is already batched — confirm against CombinedDataset.
    padded_sequences, img_x = next(iter(dataset))
    img_x = img_x.unsqueeze(0)  # type: ignore  # add a leading (batch) dimension

    model = _get_cached_model(_MODEL_PATH, device='cpu')
    predicted_class = get_prediction(model, padded_sequences, img_x, 'cpu')

    # Image.open is lazy and keeps the file open. Load a copy of the pixels and
    # close the handle right away.
    with Image.open(image_path) as img:
        img_x_display = img.copy()

    api_sequence_display = "\n".join(api_call_list[0])
    return img_x_display, api_sequence_display, predicted_class
# Create the Gradio interface with custom layout: the hash input, submit
# button and extracted image are on the left; the API sequence and the
# predicted class are on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Malware Detection")
    gr.Markdown("Enter a SHA256 hash to detect the type of malware. The extracted image and API sequence will also be displayed.")
    with gr.Row():
        with gr.Column():
            # Input for SHA256 hash
            sha256_input = gr.Textbox(label="Enter SHA256 Hash")
            submit_button = gr.Button("Submit")
            # Output for extracted image
            image_output = gr.Image(label="Extracted Image", height=256, width=256)
        with gr.Column():
            # Output for API sequence
            api_output = gr.Textbox(label="Extracted API Sequence")
            # Output for predicted malware class
            malware_output = gr.Textbox(label="Predicted Malware Class")

    # The order of `outputs` must match the tuple returned by predict_malware.
    submit_button.click(
        predict_malware,
        inputs=sha256_input,
        outputs=[image_output, api_output, malware_output],
    )

# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module does not block or spin up a server.
if __name__ == "__main__":
    demo.launch()