cycloevan commited on
Commit
52b5518
ยท
verified ยท
1 Parent(s): b92918a

Update script files

Browse files
Files changed (1) hide show
  1. src/predict.py +11 -2
src/predict.py CHANGED
@@ -8,6 +8,7 @@ import pandas as pd
8
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
 
10
  from src.utils import read_binary_file
 
11
 
12
  def predict_file(model_path, file_path, max_length=2_000_000): # 2,000,000
13
  """
@@ -22,7 +23,11 @@ def predict_file(model_path, file_path, max_length=2_000_000): # 2,000,000
22
  float: ์˜ˆ์ธก ํ™•๋ฅ  (0์— ๊ฐ€๊นŒ์šฐ๋ฉด ์•…์„ฑ์ฝ”๋“œ, 1์— ๊ฐ€๊นŒ์šฐ๋ฉด ์ •์ƒ)
23
  """
24
  # ๋ชจ๋ธ ๋กœ๋“œ
25
- model = tf.keras.models.load_model(model_path)
 
 
 
 
26
 
27
  # ํŒŒ์ผ ์ฝ๊ธฐ
28
  byte_array = read_binary_file(file_path, max_length)
@@ -47,7 +52,11 @@ def predict_batch(model_path, csv_path, output_path=None, max_length=2**20):
47
  """
48
  # ๋ชจ๋ธ ๋กœ๋“œ
49
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
50
- model = tf.keras.models.load_model(model_path)
 
 
 
 
51
 
52
  # CSV ํŒŒ์ผ ์ฝ๊ธฐ
53
  df = pd.read_csv(csv_path)
 
8
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
 
10
  from src.utils import read_binary_file
11
+ from src.model import MalConv
12
 
13
  def predict_file(model_path, file_path, max_length=2_000_000): # 2,000,000
14
  """
 
23
  float: ์˜ˆ์ธก ํ™•๋ฅ  (0์— ๊ฐ€๊นŒ์šฐ๋ฉด ์•…์„ฑ์ฝ”๋“œ, 1์— ๊ฐ€๊นŒ์šฐ๋ฉด ์ •์ƒ)
24
  """
25
  # ๋ชจ๋ธ ๋กœ๋“œ
26
+ model = MalConv(max_input_length=max_length)
27
+ # ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œํ•˜๊ธฐ ์ „์— ๋นŒ๋“œ
28
+ dummy_input = tf.zeros((1, max_length), dtype=tf.int32)
29
+ model(dummy_input) # ๋ชจ๋ธ ๋นŒ๋“œ
30
+ model.load_weights(model_path)
31
 
32
  # ํŒŒ์ผ ์ฝ๊ธฐ
33
  byte_array = read_binary_file(file_path, max_length)
 
52
  """
53
  # ๋ชจ๋ธ ๋กœ๋“œ
54
  print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
55
+ model = MalConv(max_input_length=max_length)
56
+ # ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œํ•˜๊ธฐ ์ „์— ๋นŒ๋“œ
57
+ dummy_input = tf.zeros((1, max_length), dtype=tf.int32)
58
+ model(dummy_input) # ๋ชจ๋ธ ๋นŒ๋“œ
59
+ model.load_weights(model_path)
60
 
61
  # CSV ํŒŒ์ผ ์ฝ๊ธฐ
62
  df = pd.read_csv(csv_path)