|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import pandas as pd |
|
|
|
|
|
# Put the parent of this script's directory (presumably the project root —
# confirm) on sys.path so the `src.*` imports below can be resolved when the
# file is executed directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
from src.utils import read_binary_file |
|
|
from src.model import MalConv |
|
|
|
|
|
def predict_file(model_path, file_path, max_length=2_000_000):
    """Predict a single file.

    Args:
        model_path: path to the saved model weights.
        file_path: path of the file to predict.
        max_length: maximum input length, passed to both ``MalConv`` and
            ``read_binary_file``. NOTE(review): this default (2_000_000)
            differs from the ``2**20`` default used by ``predict_batch`` and
            the command-line interface.

    Returns:
        float: prediction probability (close to 0 means malware, close to 1
        means normal).
    """
    net = MalConv(max_input_length=max_length)
    # A single forward pass on an all-zero (1, max_length) int32 batch builds
    # the model's variables, so load_weights() has something to restore into.
    net(tf.zeros((1, max_length), dtype=tf.int32))
    net.load_weights(model_path)

    sample = read_binary_file(file_path, max_length)
    batch = np.expand_dims(sample, axis=0)  # prepend a batch axis of size 1
    output = net.predict(batch, verbose=0)
    return output[0][0]
|
|
|
|
|
def predict_batch(model_path, csv_path, output_path=None, max_length=2**20):
    """Run predictions for every file listed in a CSV.

    Args:
        model_path: path to the saved model weights.
        csv_path: path to a CSV with a ``filepath`` column. If it also has a
            ``label`` column, accuracy over the successfully processed files
            is printed.
        output_path: if given, the result table is written there as CSV.
        max_length: maximum input length, passed to both ``MalConv`` and
            ``read_binary_file``.

    Returns:
        pandas.DataFrame: a copy of the input table with three extra columns:
        ``prediction`` (model score, or -1 if the file was missing or could
        not be processed), ``predicted_label`` (1 if the score is > 0.5,
        otherwise 0, and -1 for failed rows) and ``prediction_text``.
    """
    print("모델 로딩 중...")
    model = MalConv(max_input_length=max_length)
    # Build the model's variables with one forward pass before restoring weights.
    dummy_input = tf.zeros((1, max_length), dtype=tf.int32)
    model(dummy_input)
    model.load_weights(model_path)

    df = pd.read_csv(csv_path)

    # One entry per CSV row; -1 marks a file that was missing or failed.
    predictions = []

    print("예측 중...")
    for _, row in df.iterrows():
        file_path = row['filepath']

        if not os.path.exists(file_path):
            print(f"파일을 찾을 수 없습니다: {file_path}")
            predictions.append(-1)
            continue

        try:
            byte_array = read_binary_file(file_path, max_length)
            input_data = np.expand_dims(byte_array, axis=0)
            pred = model.predict(input_data, verbose=0)[0][0]
        except Exception as e:
            # Batch boundary: report the failure and keep going so one bad
            # file does not abort the whole run.
            print(f"Error processing {file_path}: {e}")
            predictions.append(-1)
            continue

        # Appended outside the try so a row can never get two entries.
        predictions.append(pred)
        status = "정상" if pred > 0.5 else "악성코드"
        confidence = pred if pred > 0.5 else 1 - pred
        print(f"{file_path}: {status} (신뢰도: {confidence:.4f})")

    scores = np.asarray(predictions, dtype=float)
    failed = scores < 0

    result_df = df.copy()
    result_df['prediction'] = predictions
    # Failed rows keep the -1 sentinel instead of being counted as class 0
    # (which means malware).
    result_df['predicted_label'] = np.where(failed, -1, (scores > 0.5).astype(int))
    result_df['prediction_text'] = [
        '정상' if p > 0.5 else '악성코드' if p >= 0 else '에러'
        for p in predictions
    ]

    if output_path:
        result_df.to_csv(output_path, index=False)
        print(f"결과가 저장되었습니다: {output_path}")

    # Labels are read from the CSV column itself, so they stay aligned with
    # the predictions even when some files were missing or failed.
    if 'label' in df.columns and not failed.all():
        valid_labels = df['label'].to_numpy()[~failed]
        pred_binary = (scores[~failed] > 0.5).astype(int)
        accuracy = np.mean(pred_binary == valid_labels)
        print(f"\n정확도: {accuracy:.4f}")

    return result_df
|
|
|
|
|
def main(argv=None):
    """Command-line entry point for MalConv predictions.

    Exactly one of ``--file`` (single-file prediction) or ``--csv`` (batch
    prediction) must be given. Otherwise argparse reports a usage error and
    exits with status 2.

    Args:
        argv: argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, so ``main()`` still behaves as a script entry.
    """
    parser = argparse.ArgumentParser(description='MalConv 모델 예측')
    parser.add_argument('model_path', help='저장된 모델 경로')
    # The two modes are exclusive; reject both at once rather than silently
    # ignoring --csv.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument('--file', help='단일 파일 예측')
    mode.add_argument('--csv', help='배치 예측용 CSV 파일')
    parser.add_argument('--output', help='결과 저장 경로')
    parser.add_argument('--max_length', type=int, default=2**20, help='최대 입력 길이')

    args = parser.parse_args(argv)

    if args.file:
        prediction = predict_file(args.model_path, args.file, args.max_length)
        # Scores above 0.5 are reported as normal, otherwise as malware.
        status = "정상" if prediction > 0.5 else "악성코드"
        confidence = prediction if prediction > 0.5 else 1 - prediction
        print(f"파일: {args.file}")
        print(f"예측: {status} (신뢰도: {confidence:.4f})")
    elif args.csv:
        predict_batch(args.model_path, args.csv, args.output, args.max_length)
    else:
        # A missing mode is a usage error: print usage plus the message to
        # stderr and exit non-zero, instead of exiting successfully.
        parser.error("--file 또는 --csv 옵션을 지정해주세요.")
|
|
|
|
|
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|