RNN_test_Model / test.py
OneclickAI's picture
Upload 4 files
9a24bb1 verified
import tensorflow as tf
from tensorflow import keras
import numpy as np
# --- 1. ๋ชจ๋ธ๊ณผ ๋‹จ์–ด ์‚ฌ์ „ ๋กœ๋“œ ---
# ์ €์žฅ๋œ Keras ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
model_path = "my_rnn_model_imdb.keras"
try:
loaded_model = keras.models.load_model(model_path)
print(f"'{model_path}' ๋ชจ๋ธ์„ ์„ฑ๊ณต์ ์œผ๋กœ ๋ถˆ๋Ÿฌ์™”์Šต๋‹ˆ๋‹ค.")
except Exception as e:
print(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
exit()
# IMDB ๋ฐ์ดํ„ฐ์…‹์˜ ๋‹จ์–ด-์ธ๋ฑ์Šค ์‚ฌ์ „ ๋กœ๋“œ
word_index = keras.datasets.imdb.get_word_index()
# Keras์˜ ์˜ˆ์•ฝ๋œ ์ธ๋ฑ์Šค๋ฅผ ๋ฐ˜์˜ํ•˜์—ฌ 3๋งŒํผ ์˜คํ”„์…‹ ์ถ”๊ฐ€
word_index = {k: (v + 3) for k, v in word_index.items()}
word_index["<pad>"] = 0
word_index["<start>"] = 1
word_index["<unk>"] = 2 # ์•Œ๋ ค์ง€์ง€ ์•Š์€ ๋‹จ์–ด(out-of-vocabulary)
word_index["<unused>"] = 3
# --- 2. ์˜ˆ์ธก์„ ์œ„ํ•œ ์ „์ฒ˜๋ฆฌ ํ•จ์ˆ˜ ---
MAX_LEN = 256
def preprocess_text(text):
"""
์ƒˆ๋กœ์šด ํ…์ŠคํŠธ๋ฅผ ๋ชจ๋ธ ์ž…๋ ฅ ํ˜•์‹์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
"""
# ํ…์ŠคํŠธ๋ฅผ ํ† ํฐํ™”ํ•˜๊ณ  ์ •์ˆ˜๋กœ ์ธ์ฝ”๋”ฉ
tokens = [word_index.get(word, 2) for word in text.lower().split()]
# <start> ์ธ๋ฑ์Šค ์ถ”๊ฐ€
tokens = [word_index["<start>"]] + tokens
# ์‹œํ€€์Šค ํŒจ๋”ฉ
padded_sequence = keras.preprocessing.sequence.pad_sequences(
[tokens], maxlen=MAX_LEN, padding='pre'
)
return padded_sequence
# --- 3. ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๊ธฐ๋ฐ˜ ์˜ˆ์ธก ์‹คํ–‰ ---
print("\n์˜ํ™” ๋ฆฌ๋ทฐ ๊ฐ์„ฑ ๋ถ„์„๊ธฐ (์ข…๋ฃŒํ•˜๋ ค๋ฉด 'exit'๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”)")
print("-" * 50)
while True:
# ์‚ฌ์šฉ์ž๋กœ๋ถ€ํ„ฐ ๋ฆฌ๋ทฐ ์ž…๋ ฅ๋ฐ›๊ธฐ
review_text = input("๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”: ")
if review_text.lower() == 'exit':
print("ํ”„๋กœ๊ทธ๋žจ์„ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
break
if not review_text.strip():
print("์ž…๋ ฅ๋œ ๋‚ด์šฉ์ด ์—†์Šต๋‹ˆ๋‹ค. ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”.")
continue
# ํ…์ŠคํŠธ ์ „์ฒ˜๋ฆฌ
processed_input = preprocess_text(review_text)
# ๋ชจ๋ธ๋กœ ์˜ˆ์ธก ์ˆ˜ํ–‰
prediction = loaded_model.predict(processed_input)
# ๊ฒฐ๊ณผ ํ•ด์„ (sigmoid ์ถœ๋ ฅ > 0.5 ์ด๋ฉด ๊ธ์ •)
score = prediction[0][0]
sentiment = "๊ธ์ • (Positive)" if score > 0.5 else "๋ถ€์ • (Negative)"
print(f"๊ฒฐ๊ณผ: {sentiment} (์˜ˆ์ธก ์ ์ˆ˜: {score:.4f})")
print("-" * 50)