# malconv/src/model.py
# Uploaded by cycloevan ("Upload 17 files", commit b92918a, verified)
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
class DeCorrelationLoss(tf.keras.layers.Layer):
    """Identity layer that adds a DeCov (feature decorrelation) penalty.

    The layer returns its input unchanged and registers, via ``add_loss``,
    ``lambda_decov * 0.5 * sum(offdiag(C) ** 2)``, where ``C`` is the
    feature covariance matrix estimated over the current batch.

    Intended for rank-2 ``(batch, features)`` input, such as the output of a
    Dense layer. For that shape, the ``transpose_a`` matmul below produces a
    ``(features, features)`` covariance matrix.

    Args:
        lambda_decov: Weight applied to the DeCov penalty before it is added
            to the model losses.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, lambda_decov=1e-4, **kwargs):
        super().__init__(**kwargs)
        self.lambda_decov = lambda_decov

    def call(self, inputs):
        # Do the arithmetic in the input dtype so float16 (mixed precision)
        # inputs are not combined with a float32 scalar.
        num_samples = tf.cast(tf.shape(inputs)[0], inputs.dtype)

        # Center every feature on its batch mean.
        centered = inputs - tf.reduce_mean(inputs, axis=0, keepdims=True)

        # Unbiased (N - 1) covariance estimate. Clamp the denominator so a
        # batch of one gives a zero penalty instead of 0 / 0 = NaN.
        # NOTE(review): the original DeCov formulation (Cogswell et al., 2016)
        # normalizes by N; N - 1 is kept here for backward compatibility.
        denominator = tf.maximum(num_samples - 1.0, 1.0)
        covariance = tf.matmul(centered, centered, transpose_a=True) / denominator

        # Only cross-feature covariances are penalized, so remove the
        # diagonal (per-feature variances).
        off_diagonal = covariance - tf.linalg.diag(tf.linalg.diag_part(covariance))

        decov_loss = 0.5 * tf.reduce_sum(tf.square(off_diagonal))
        self.add_loss(self.lambda_decov * decov_loss)
        return inputs

    def get_config(self):
        """Return the layer config, including ``lambda_decov``, for serialization."""
        config = super().get_config()
        config.update({"lambda_decov": self.lambda_decov})
        return config
class MalConv(Model):
    """MalConv byte-level binary classifier with a gated convolution.

    Forward pass: byte embedding -> gated Conv1D (ReLU branch multiplied
    elementwise by a sigmoid gate branch) -> global max pooling -> fully
    connected ReLU layer -> optional DeCov regularization -> dropout (0.5)
    -> single sigmoid output unit.

    Inputs must be integer byte values in ``[0, 255]``: the embedding table
    has exactly 256 rows and reserves no extra padding index.

    NOTE(review): the convolutions use ``padding='valid'``, so an input
    shorter than ``filter_size`` yields no convolution windows before global
    max pooling. Callers presumably pad inputs upstream — confirm.

    Args:
        max_input_length: Maximum input length in bytes. Stored on the
            instance only; ``call`` does not truncate or pad inputs.
        embedding_size: Dimension of each byte embedding vector.
        filter_size: Convolution kernel width, in input positions.
        stride: Convolution stride, in input positions.
        num_filters: Number of filters in each of the two convolution branches.
        fc_size: Width of the fully connected hidden layer.
        use_decov: If true, apply ``DeCorrelationLoss`` to the hidden layer.
        lambda_decov: Weight of the DeCov penalty (used only if ``use_decov``).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self,
                 max_input_length=2_000_000,
                 embedding_size=8,
                 filter_size=500,
                 stride=500,
                 num_filters=128,
                 fc_size=128,
                 use_decov=True,
                 lambda_decov=1e-4,
                 **kwargs):
        super().__init__(**kwargs)
        self.max_input_length = max_input_length
        self.use_decov = use_decov

        # Vocabulary covers raw byte values 0-255 only (changed from 257 to 256).
        # Sequence length is left unconstrained, so inputs may vary in length.
        self.embedding = layers.Embedding(
            input_dim=256,
            output_dim=embedding_size,
            mask_zero=False,
            name='byte_embedding',
        )

        # Gated convolution: conv_A carries ReLU features, conv_B a sigmoid gate.
        self.conv_A = layers.Conv1D(
            filters=num_filters,
            kernel_size=filter_size,
            strides=stride,
            padding='valid',
            activation='relu',
            name='conv_A',
        )
        self.conv_B = layers.Conv1D(
            filters=num_filters,
            kernel_size=filter_size,
            strides=stride,
            padding='valid',
            activation='sigmoid',
            name='conv_B',
        )

        # Global max pooling over the sequence axis.
        self.global_max_pool = layers.GlobalMaxPooling1D(name='global_max_pool')

        # Fully connected hidden layer.
        self.fc = layers.Dense(fc_size, activation='relu', name='fc_layer')

        # Optional DeCov regularization applied to the hidden layer output.
        if use_decov:
            self.decov_layer = DeCorrelationLoss(lambda_decov=lambda_decov)

        self.dropout = layers.Dropout(0.5, name='dropout')
        self.output_layer = layers.Dense(1, activation='sigmoid', name='output')

    def call(self, inputs, training=None):
        # 1. Byte embedding.
        x = self.embedding(inputs)

        # 2. Gated convolution. A plain tensor product is used instead of
        #    layers.multiply(), which would construct a new Multiply layer
        #    on every forward pass.
        gated_conv = self.conv_A(x) * self.conv_B(x)

        # 3. Global max pooling.
        pooled = self.global_max_pool(gated_conv)

        # 4. Fully connected layer.
        hidden = self.fc(pooled)

        # 5. DeCov regularization on the penultimate layer (adds a loss and
        #    passes the activations through unchanged).
        if self.use_decov:
            hidden = self.decov_layer(hidden)

        # 6. Dropout. The layer is already an identity when training is
        #    falsy, so no Python-level guard is needed. A guard would also
        #    fail when `training` is a symbolic tensor.
        hidden = self.dropout(hidden, training=training)

        # 7. Sigmoid output.
        return self.output_layer(hidden)
def create_malconv_model(max_input_length=2_000_000, *,
                         initial_lr=0.01,
                         decay_steps=1000,
                         decay_rate=0.96,
                         **model_kwargs):
    """Build and compile a ``MalConv`` model for binary classification.

    The optimizer is SGD with Nesterov momentum 0.9, driven by a staircase
    exponential-decay learning-rate schedule. The model is compiled with
    binary cross-entropy loss and accuracy, precision, recall and AUC metrics.

    Args:
        max_input_length: Forwarded to ``MalConv``.
        initial_lr: Learning rate at step 0.
        decay_steps: Number of optimizer steps between learning-rate decays.
        decay_rate: Factor the learning rate is multiplied by every
            ``decay_steps`` steps.
        **model_kwargs: Extra keyword arguments forwarded to ``MalConv``
            (e.g. ``embedding_size``, ``use_decov``).

    Returns:
        The compiled ``MalConv`` instance. It has not been built yet; weights
        are created on the first call.
    """
    model = MalConv(max_input_length=max_input_length, **model_kwargs)

    # Exponential learning-rate decay; staircase=True applies it in discrete steps.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=True,
    )
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=lr_schedule,
        momentum=0.9,
        nesterov=True,
    )

    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc'),
        ],
    )
    return model