# malconv/src/utils.py
# (provenance: uploaded by cycloevan, "Upload 17 files", commit b92918a verified)
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, balanced_accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
# utils.py์— ๋ˆ„๋ฝ๋œ import
import tensorflow as tf # configure_gpu_memory ํ•จ์ˆ˜์—์„œ ํ•„์š”
def read_binary_file(file_path, max_length=2_000_000):
    """
    Read a binary file and return its contents as a fixed-length uint8 array.

    Following the MalConv paper's specification, at most ``max_length``
    (default 2 MB) bytes are kept: longer files are truncated to the leading
    bytes, shorter files are zero-padded on the right. Any read failure is
    reported and an all-zero array of length ``max_length`` is returned
    instead of raising.
    """
    try:
        with open(file_path, 'rb') as handle:
            # Interpret the raw bytes as integers in [0, 255].
            data = np.frombuffer(handle.read(), dtype=np.uint8)
    except Exception as e:
        print(f"ํŒŒ์ผ ์ฝ๊ธฐ ์˜ค๋ฅ˜ {file_path}: {e}")
        return np.zeros(max_length, dtype=np.uint8)

    if len(data) > max_length:
        # Long file: keep only the first max_length bytes (paper's approach).
        return data[:max_length]

    # Short file: right-pad with zero bytes up to max_length.
    out = np.zeros(max_length, dtype=np.uint8)
    out[:len(data)] = data
    return out
def load_dataset_from_directory(malware_dir, benign_dir, max_length=2_000_000, max_samples_per_class=None):
    """
    Load binary files directly from two class directories into memory.

    Args:
        malware_dir: directory containing malware samples (labelled 0)
        benign_dir: directory containing benign samples (labelled 1)
        max_length: maximum number of bytes kept per file
        max_samples_per_class: optional cap on the number of samples per class

    Returns:
        (X, y): X is an array of fixed-length uint8 byte arrays, y the
        corresponding label array. A missing directory is skipped silently.
    """
    X, y = [], []

    def _collect(directory, label, class_name):
        # One pass over a single class directory; appends into X / y.
        if not os.path.exists(directory):
            return
        filenames = [name for name in os.listdir(directory)
                     if os.path.isfile(os.path.join(directory, name))]
        if max_samples_per_class:
            filenames = filenames[:max_samples_per_class]
        print(f"{class_name} ํŒŒ์ผ ๋กœ๋”ฉ ์ค‘... ({len(filenames)}๊ฐœ)")
        for count, name in enumerate(filenames, start=1):
            X.append(read_binary_file(os.path.join(directory, name), max_length))
            y.append(label)
            if count % 100 == 0:
                # Periodic progress report for long runs.
                print(f" {count}/{len(filenames)} ์ฒ˜๋ฆฌ ์™„๋ฃŒ")

    _collect(malware_dir, 0, "์•…์„ฑ์ฝ”๋“œ")  # malware = 0
    _collect(benign_dir, 1, "์ •์ƒ")       # benign = 1

    X = np.array(X)
    y = np.array(y)
    print(f"\n๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ ์™„๋ฃŒ:")
    print(f" ์ด ์ƒ˜ํ”Œ: {len(X)}")
    print(f" ์•…์„ฑ์ฝ”๋“œ: {np.sum(y == 0)}")
    print(f" ์ •์ƒํŒŒ์ผ: {np.sum(y == 1)}")
    return X, y
def load_dataset_from_csv(csv_path, max_length=2_000_000):
    """
    Load a dataset described by a CSV file with 'filepath' and 'label' columns.

    Rows whose file does not exist are reported and skipped; progress is
    printed every 1000 rows. Returns (X, y) as numpy arrays.
    """
    frame = pd.read_csv(csv_path)
    X, y = [], []
    print("CSV์—์„œ ํŒŒ์ผ ๋กœ๋”ฉ ์ค‘...")
    for row_num, row in frame.iterrows():
        path = row['filepath']
        if not os.path.exists(path):
            print(f"ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {path}")
        else:
            X.append(read_binary_file(path, max_length))
            y.append(row['label'])
        if (row_num + 1) % 1000 == 0:
            print(f" {row_num + 1} ํŒŒ์ผ ์ฒ˜๋ฆฌ ์™„๋ฃŒ")
    return np.array(X), np.array(y)
def configure_gpu_memory():
    """
    Enable memory growth on every visible GPU so TensorFlow allocates GPU
    memory incrementally instead of grabbing it all up front.

    Returns:
        bool: True if at least one GPU was found and configured successfully,
        False otherwise (no GPU present, or configuration failed).
    """
    # NOTE(review): tf.config.experimental.* is deprecated in favor of the
    # stable tf.config.* API (TF >= 2.1); kept as-is to match the project.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        # Bug fix: the original fell through here and implicitly returned
        # None, breaking the documented True/False contract.
        return False
    try:
        # Memory growth must be set before any GPU has been initialized.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"GPU ์„ค์ • ์™„๋ฃŒ: {len(gpus)}๊ฐœ GPU ์‚ฌ์šฉ")
        return True
    except RuntimeError as e:
        # Raised e.g. when GPUs were already initialized.
        print(f"GPU ์„ค์ • ์˜ค๋ฅ˜: {e}")
        return False
def plot_training_history(history):
    """
    Visualize a Keras training run: loss, accuracy, AUC and learning rate.

    Validation curves, the AUC panel and the learning-rate panel are drawn
    only when the corresponding keys exist in ``history.history``.
    """
    metrics = history.history
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    def _draw_pair(ax, train_key, val_key, title, ylabel, train_label, val_label):
        # Plot a training curve plus its validation counterpart if present.
        ax.plot(metrics[train_key], label=train_label)
        if val_key in metrics:
            ax.plot(metrics[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    _draw_pair(axes[0, 0], 'loss', 'val_loss', 'Model Loss', 'Loss',
               'Training Loss', 'Validation Loss')
    _draw_pair(axes[0, 1], 'accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy',
               'Training Accuracy', 'Validation Accuracy')
    if 'auc' in metrics:
        _draw_pair(axes[1, 0], 'auc', 'val_auc', 'Model AUC', 'AUC',
                   'Training AUC', 'Validation AUC')
    if 'lr' in metrics:
        # Learning-rate schedule on a log scale.
        ax = axes[1, 1]
        ax.plot(metrics['lr'], label='Learning Rate', color='red')
        ax.set_title('Learning Rate Schedule')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        ax.set_yscale('log')
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()
def plot_confusion_matrix(y_true, y_pred, title="Confusion Matrix"):
    """Render a labelled confusion-matrix heatmap for the two classes."""
    class_names = ['Malware', 'Benign']
    matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8, 6))
    # Integer annotations per cell on a blue colour ramp.
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
def evaluate_model(model, X_test, y_test, batch_size=16):
    """
    Evaluate a trained model on the test set, print a metric report and
    display the confusion matrix.

    Returns:
        dict with keys 'accuracy', 'balanced_accuracy', 'auc' and
        'predictions' (the raw predicted probabilities).
    """
    print("๋ชจ๋ธ ํ‰๊ฐ€ ์ค‘...")
    probs = model.predict(X_test, batch_size=batch_size, verbose=1)
    # Probability > 0.5 is mapped to class 1 (benign).
    preds = (probs > 0.5).astype(int).flatten()

    results = {
        'accuracy': np.mean(preds == y_test),
        'balanced_accuracy': balanced_accuracy_score(y_test, preds),
        'auc': roc_auc_score(y_test, probs),
        'predictions': probs,
    }

    print(f"\n=== ํ‰๊ฐ€ ๊ฒฐ๊ณผ ===")
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"Balanced Accuracy: {results['balanced_accuracy']:.4f}")
    print(f"AUC Score: {results['auc']:.4f}")
    print(f"\n๋ถ„๋ฅ˜ ๋ฆฌํฌํŠธ:")
    print(classification_report(y_test, preds, target_names=['Malware', 'Benign']))

    plot_confusion_matrix(y_test, preds, "MalConv Performance")
    return results
def get_file_paths_and_labels(malware_dir, benign_dir, max_samples_per_class=None):
    """
    Collect file paths and labels from the two class directories without
    loading any file contents into memory.

    Returns:
        (filepaths, labels): shuffled in unison; filepaths is a list of str,
        labels a numpy array (0 = malware, 1 = benign). A missing directory
        contributes nothing.
    """
    def _paths_in(directory):
        # Regular files only, optionally capped per class.
        candidates = (os.path.join(directory, name) for name in os.listdir(directory))
        found = [p for p in candidates if os.path.isfile(p)]
        return found[:max_samples_per_class] if max_samples_per_class else found

    filepaths, labels = [], []

    if os.path.exists(malware_dir):
        mal = _paths_in(malware_dir)
        filepaths += mal
        labels += [0] * len(mal)  # malware = 0
        print(f"์•…์„ฑ์ฝ”๋“œ ํŒŒ์ผ ๊ฒฝ๋กœ ๋กœ๋”ฉ: {len(mal)}๊ฐœ")
    if os.path.exists(benign_dir):
        ben = _paths_in(benign_dir)
        filepaths += ben
        labels += [1] * len(ben)  # benign = 1
        print(f"์ •์ƒ ํŒŒ์ผ ๊ฒฝ๋กœ ๋กœ๋”ฉ: {len(ben)}๊ฐœ")

    print(f"\n์ด ํŒŒ์ผ ๊ฒฝ๋กœ: {len(filepaths)}")
    print(f" ์•…์„ฑ์ฝ”๋“œ: {labels.count(0)}")
    print(f" ์ •์ƒํŒŒ์ผ: {labels.count(1)}")

    # Shuffle paths and labels with one shared permutation so pairs stay aligned.
    order = np.arange(len(filepaths))
    np.random.shuffle(order)
    shuffled_paths = np.array(filepaths)[order].tolist()
    shuffled_labels = np.array(labels)[order]
    return shuffled_paths, shuffled_labels
def data_generator(filepaths, labels, batch_size, max_length=2_000_000, shuffle=True):
    """
    Infinite generator yielding (X, y) batches of byte arrays and labels.

    Each epoch optionally reshuffles the sample order; files that fail to
    load are skipped with a warning and fully-empty batches are not yielded.
    Yields nothing at all when ``filepaths`` is empty.
    """
    total = len(filepaths)
    if total == 0:
        return
    while True:
        # Fresh permutation of sample indices every epoch.
        order = np.arange(total)
        if shuffle:
            np.random.shuffle(order)
        for start in range(0, total, batch_size):
            xs, ys = [], []
            for idx in order[start:start + batch_size]:
                try:
                    xs.append(read_binary_file(filepaths[idx], max_length))
                    ys.append(labels[idx])
                except Exception as e:
                    print(f"Warning: Skipping file {filepaths[idx]} due to error: {e}")
            if xs:
                yield np.array(xs), np.array(ys)