File size: 2,377 Bytes
9a24bb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import tensorflow as tf
from tensorflow import keras
import numpy as np
# --- 1. ๋ชจ๋ธ๊ณผ ๋จ์ด ์ฌ์ ๋ก๋ ---
# ์ ์ฅ๋ Keras ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
model_path = "my_rnn_model_imdb.keras"
try:
loaded_model = keras.models.load_model(model_path)
print(f"'{model_path}' ๋ชจ๋ธ์ ์ฑ๊ณต์ ์ผ๋ก ๋ถ๋ฌ์์ต๋๋ค.")
except Exception as e:
print(f"๋ชจ๋ธ ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์: {e}")
exit()
# IMDB ๋ฐ์ดํฐ์
์ ๋จ์ด-์ธ๋ฑ์ค ์ฌ์ ๋ก๋
word_index = keras.datasets.imdb.get_word_index()
# Keras์ ์์ฝ๋ ์ธ๋ฑ์ค๋ฅผ ๋ฐ์ํ์ฌ 3๋งํผ ์คํ์
์ถ๊ฐ
word_index = {k: (v + 3) for k, v in word_index.items()}
word_index["<pad>"] = 0
word_index["<start>"] = 1
word_index["<unk>"] = 2 # ์๋ ค์ง์ง ์์ ๋จ์ด(out-of-vocabulary)
word_index["<unused>"] = 3
# --- 2. ์์ธก์ ์ํ ์ ์ฒ๋ฆฌ ํจ์ ---
MAX_LEN = 256
def preprocess_text(text):
"""
์๋ก์ด ํ
์คํธ๋ฅผ ๋ชจ๋ธ ์
๋ ฅ ํ์์ ๋ง๊ฒ ์ ์ฒ๋ฆฌํฉ๋๋ค.
"""
# ํ
์คํธ๋ฅผ ํ ํฐํํ๊ณ ์ ์๋ก ์ธ์ฝ๋ฉ
tokens = [word_index.get(word, 2) for word in text.lower().split()]
# <start> ์ธ๋ฑ์ค ์ถ๊ฐ
tokens = [word_index["<start>"]] + tokens
# ์ํ์ค ํจ๋ฉ
padded_sequence = keras.preprocessing.sequence.pad_sequences(
[tokens], maxlen=MAX_LEN, padding='pre'
)
return padded_sequence
# --- 3. ์ฌ์ฉ์ ์
๋ ฅ ๊ธฐ๋ฐ ์์ธก ์คํ ---
print("\n์ํ ๋ฆฌ๋ทฐ ๊ฐ์ฑ ๋ถ์๊ธฐ (์ข
๋ฃํ๋ ค๋ฉด 'exit'๋ฅผ ์
๋ ฅํ์ธ์)")
print("-" * 50)
while True:
# ์ฌ์ฉ์๋ก๋ถํฐ ๋ฆฌ๋ทฐ ์
๋ ฅ๋ฐ๊ธฐ
review_text = input("๋ฆฌ๋ทฐ๋ฅผ ์
๋ ฅํ์ธ์: ")
if review_text.lower() == 'exit':
print("ํ๋ก๊ทธ๋จ์ ์ข
๋ฃํฉ๋๋ค.")
break
if not review_text.strip():
print("์
๋ ฅ๋ ๋ด์ฉ์ด ์์ต๋๋ค. ๋ค์ ์๋ํด์ฃผ์ธ์.")
continue
# ํ
์คํธ ์ ์ฒ๋ฆฌ
processed_input = preprocess_text(review_text)
# ๋ชจ๋ธ๋ก ์์ธก ์ํ
prediction = loaded_model.predict(processed_input)
# ๊ฒฐ๊ณผ ํด์ (sigmoid ์ถ๋ ฅ > 0.5 ์ด๋ฉด ๊ธ์ )
score = prediction[0][0]
sentiment = "๊ธ์ (Positive)" if score > 0.5 else "๋ถ์ (Negative)"
print(f"๊ฒฐ๊ณผ: {sentiment} (์์ธก ์ ์: {score:.4f})")
print("-" * 50) |