import os import sys import argparse import numpy as np import tensorflow as tf import pandas as pd sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.utils import read_binary_file from src.model import MalConv def predict_file(model_path, file_path, max_length=2_000_000): # 2,000,000 """ 단일 파일에 대한 예측 Args: model_path: 저장된 모델 경로 file_path: 예측할 파일 경로 max_length: 최대 입력 길이 Returns: float: 예측 확률 (0에 가까우면 악성코드, 1에 가까우면 정상) """ # 모델 로드 model = MalConv(max_input_length=max_length) # 모델의 가중치를 로드하기 전에 빌드 dummy_input = tf.zeros((1, max_length), dtype=tf.int32) model(dummy_input) # 모델 빌드 model.load_weights(model_path) # 파일 읽기 byte_array = read_binary_file(file_path, max_length) # 배치 차원 추가 input_data = np.expand_dims(byte_array, axis=0) # 예측 prediction = model.predict(input_data, verbose=0)[0][0] return prediction def predict_batch(model_path, csv_path, output_path=None, max_length=2**20): """ 배치 예측 Args: model_path: 저장된 모델 경로 csv_path: 예측할 파일들의 CSV 경로 output_path: 결과 저장 경로 max_length: 최대 입력 길이 """ # 모델 로드 print("모델 로딩 중...") model = MalConv(max_input_length=max_length) # 모델의 가중치를 로드하기 전에 빌드 dummy_input = tf.zeros((1, max_length), dtype=tf.int32) model(dummy_input) # 모델 빌드 model.load_weights(model_path) # CSV 파일 읽기 df = pd.read_csv(csv_path) predictions = [] labels = [] print("예측 중...") for idx, row in df.iterrows(): file_path = row['filepath'] if os.path.exists(file_path): try: # 파일 읽기 byte_array = read_binary_file(file_path, max_length) input_data = np.expand_dims(byte_array, axis=0) # 예측 pred = model.predict(input_data, verbose=0)[0][0] predictions.append(pred) # 라벨이 있는 경우 if 'label' in row: labels.append(row['label']) # 결과 출력 status = "정상" if pred > 0.5 else "악성코드" confidence = pred if pred > 0.5 else 1 - pred print(f"{file_path}: {status} (신뢰도: {confidence:.4f})") except Exception as e: print(f"Error processing {file_path}: {e}") predictions.append(-1) # 에러 표시 else: print(f"파일을 찾을 수 없습니다: {file_path}") predictions.append(-1) # 결과 저장 result_df = df.copy() result_df['prediction'] = predictions result_df['predicted_label'] = (np.array(predictions) > 0.5).astype(int) result_df['prediction_text'] = ['정상' if p > 0.5 else '악성코드' if p >= 0 else '에러' for p in predictions] if output_path: result_df.to_csv(output_path, index=False) print(f"결과가 저장되었습니다: {output_path}") # 정확도 계산 (라벨이 있는 경우) if labels and len(labels) == len(predictions): valid_predictions = [p for p in predictions if p >= 0] valid_labels = [labels[i] for i, p in enumerate(predictions) if p >= 0] if valid_predictions: pred_binary = (np.array(valid_predictions) > 0.5).astype(int) accuracy = np.mean(pred_binary == np.array(valid_labels)) print(f"\n정확도: {accuracy:.4f}") return result_df def main(): parser = argparse.ArgumentParser(description='MalConv 모델 예측') parser.add_argument('model_path', help='저장된 모델 경로') parser.add_argument('--file', help='단일 파일 예측') parser.add_argument('--csv', help='배치 예측용 CSV 파일') parser.add_argument('--output', help='결과 저장 경로') parser.add_argument('--max_length', type=int, default=2**20, help='최대 입력 길이') args = parser.parse_args() if args.file: # 단일 파일 예측 prediction = predict_file(args.model_path, args.file, args.max_length) status = "정상" if prediction > 0.5 else "악성코드" confidence = prediction if prediction > 0.5 else 1 - prediction print(f"파일: {args.file}") print(f"예측: {status} (신뢰도: {confidence:.4f})") elif args.csv: # 배치 예측 predict_batch(args.model_path, args.csv, args.output, args.max_length) else: print("--file 또는 --csv 옵션을 지정해주세요.") if __name__ == "__main__": main()