File size: 5,273 Bytes
b92918a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
import argparse
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model import create_malconv_model
from src.utils import (
    configure_gpu_memory, 
    plot_training_history, 
    evaluate_model,
    get_file_paths_and_labels,
    data_generator,
    read_binary_file
)

def _steps_for(num_samples, batch_size):
    """Return how many generator batches to draw per epoch for *num_samples* files.

    Same floor division as ``num_samples // batch_size``, but never 0: Keras
    cannot run an epoch (or a validation pass) of zero steps, so a split with
    fewer files than one batch would otherwise abort ``fit``.
    """
    return max(1, num_samples // batch_size)


def train_malconv(data_source,
                  epochs=10,
                  batch_size=256,
                  max_length=2_000_000,
                  validation_split=0.2,
                  save_path="models/malconv_model.h5",
                  max_eval_samples=1024):
    """Train a MalConv model, feeding samples through data generators.

    Args:
        data_source: ``(malware_dir, benign_dir)`` pair (a tuple or a list).
        epochs: Number of training epochs.
        batch_size: Batch size for the training/validation generators; must be
            >= 1. The final evaluation uses half of it (at least 1).
        max_length: Maximum input length in bytes (default 2,000,000).
        validation_split: Fraction of files held out for validation.
        save_path: File the ModelCheckpoint callback writes (best ``val_auc``
            epoch). Its parent directory is created if missing.
        max_eval_samples: Cap on how many validation files are loaded into
            memory for the final evaluation (default 1024).

    Returns:
        ``(model, history, results)``: the trained model, the object returned
        by ``model.fit``, and the value returned by ``evaluate_model``
        (``{}`` when there was nothing to evaluate).

    Raises:
        ValueError: if ``data_source`` is not a 2-item tuple/list, or
            ``batch_size`` is smaller than 1.
    """
    # Validate cheap arguments before touching the GPU or the filesystem.
    if not (isinstance(data_source, (tuple, list)) and len(data_source) == 2):
        raise ValueError("data_source๋Š” (malware_dir, benign_dir) ํŠœํ”Œ์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print("=" * 60)
    print("MalConv ๋ชจ๋ธ ํ›ˆ๋ จ ์‹œ์ž‘ (๋ฐ์ดํ„ฐ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ ๋ชจ๋“œ)")
    print("=" * 60)

    # GPU setup.
    configure_gpu_memory()

    # Collect file paths and their labels.
    malware_dir, benign_dir = data_source
    filepaths, labels = get_file_paths_and_labels(malware_dir, benign_dir)

    # Stratified train/validation split on file paths only; the file contents
    # are presumably read later by data_generator / read_binary_file.
    filepaths_train, filepaths_val, labels_train, labels_val = train_test_split(
        filepaths, labels, test_size=validation_split, random_state=42, stratify=labels
    )

    print(f"์ด ๋ฐ์ดํ„ฐ: {len(filepaths)}")
    print(f"ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ: {len(filepaths_train)}, ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ: {len(filepaths_val)}")

    # Data generators; validation batches are not shuffled.
    train_gen = data_generator(filepaths_train, labels_train, batch_size, max_length)
    val_gen = data_generator(filepaths_val, labels_val, batch_size, max_length, shuffle=False)

    # Build the model.
    print("MalConv ๋ชจ๋ธ ์ƒ์„ฑ ์ค‘...")
    model = create_malconv_model(max_length)

    # One dummy forward pass builds the weights before summary()/count_params().
    dummy_input = np.zeros((1, max_length), dtype=np.uint8)
    model(dummy_input)

    print("\n=== ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ===")
    model.summary()
    print(f"์ด ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {model.count_params():,}")

    # Saving the checkpoint to a path under a missing directory fails, so make
    # sure the parent exists. A bare file name has dirname '' and needs nothing.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): early stopping and its restored weights follow val_loss,
    # while the checkpoint keeps the best val_auc epoch. The in-memory model
    # evaluated below may therefore differ from the file at save_path.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,  # increased patience
            restore_best_weights=True,
            verbose=1,
        ),
        # Assumes the model is compiled with a metric logged as "auc" (so that
        # "val_auc" exists). Otherwise nothing is ever saved.
        # TODO confirm in create_malconv_model.
        tf.keras.callbacks.ModelCheckpoint(
            save_path,
            monitor='val_auc',
            save_best_only=True,
            verbose=1,
            mode='max',  # higher AUC is better
        ),
    ]

    # Training.
    print(f"\n=== ํ›ˆ๋ จ ์‹œ์ž‘ ===")
    print(f"๋ฐฐ์น˜ ํฌ๊ธฐ: {batch_size}")
    print(f"์—ํฌํฌ: {epochs}")

    history = model.fit(
        train_gen,
        steps_per_epoch=_steps_for(len(filepaths_train), batch_size),
        epochs=epochs,
        validation_data=val_gen,
        validation_steps=_steps_for(len(filepaths_val), batch_size),
        callbacks=callbacks,
        verbose=1,
    )

    # Final evaluation on at most max_eval_samples validation files. Every
    # sample is materialized in memory at max_length, so the full validation
    # split may not fit.
    print("\n=== ์ตœ์ข… ํ‰๊ฐ€ ===")
    num_eval_samples = max(0, min(len(filepaths_val), max_eval_samples))
    X_eval = np.array([read_binary_file(fp, max_length) for fp in filepaths_val[:num_eval_samples]])
    y_eval = np.array(labels_val[:num_eval_samples])

    if X_eval.size > 0:
        # Half the training batch size, clamped so batch_size=1 does not give 0.
        results = evaluate_model(model, X_eval, y_eval, batch_size=max(1, batch_size // 2))
    else:
        print("ํ‰๊ฐ€ํ•  ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
        results = {}

    # Plot the training curves.
    plot_training_history(history)

    # NOTE(review): printed unconditionally. ModelCheckpoint only writes the
    # file when val_auc is available and improves.
    print(f"\n๋ชจ๋ธ์ด ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {save_path}")

    return model, history, results

def main(argv=None):
    """Command-line entry point: parse arguments and run ``train_malconv``.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='MalConv ๋ชจ๋ธ ํ›ˆ๋ จ')

    # Data source options.
    parser.add_argument('--malware_dir', required=True, help='์•…์„ฑ์ฝ”๋“œ ๋””๋ ‰ํ† ๋ฆฌ')
    parser.add_argument('--benign_dir', required=True, help='์ •์ƒํŒŒ์ผ ๋””๋ ‰ํ† ๋ฆฌ')

    # Training options.
    # The CLI uses more epochs (20) than train_malconv's own default (10).
    parser.add_argument('--epochs', type=int, default=20, help='์—ํฌํฌ ์ˆ˜')
    # The CLI uses a smaller batch (64) than train_malconv's own default (256).
    parser.add_argument('--batch_size', type=int, default=64, help='๋ฐฐ์น˜ ํฌ๊ธฐ')
    parser.add_argument('--max_length', type=int, default=2_000_000, help='์ตœ๋Œ€ ์ž…๋ ฅ ๊ธธ์ด')
    # Same default as train_malconv's validation_split parameter.
    parser.add_argument('--validation_split', type=float, default=0.2,
                        help='fraction of files held out for validation')
    parser.add_argument('--save_path', default='models/malconv_model.h5', help='๋ชจ๋ธ ์ €์žฅ ๊ฒฝ๋กœ')

    args = parser.parse_args(argv)

    data_source = (args.malware_dir, args.benign_dir)

    # Create the directory for the saved model. os.path.dirname() returns ''
    # for a bare file name such as 'model.h5', and os.makedirs('') raises,
    # so only create a directory when there is one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Train the model.
    train_malconv(
        data_source=data_source,
        epochs=args.epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        validation_split=args.validation_split,
        save_path=args.save_path,
    )


if __name__ == "__main__":
    main()