# flud / test_current_model.py
# Uploaded by Halfotter via huggingface_hub ("Upload folder using huggingface_hub"), commit 2f1dcf0 (verified)
import json
import pickle
import traceback

import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# SimpleClassifier class definition
class SimpleClassifier(nn.Module):
    """Three-layer MLP mapping an ``input_size`` feature vector to class scores.

    Layout: Linear(input_size -> 256) -> ReLU -> Dropout(0.3)
            -> Linear(256 -> 128) -> ReLU -> Dropout(0.3)
            -> Linear(128 -> num_classes).
    ``forward`` returns the raw (unnormalized) output of the last layer;
    callers apply softmax themselves if they need probabilities.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Attribute names (and their order) define the state_dict keys
        # ("fc1.weight", ...), so they must stay as-is to load saved weights.
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.3)  # shared by both hidden layers

    def forward(self, x):
        # Each hidden layer: affine -> ReLU -> dropout (dropout is a no-op in eval mode).
        for hidden_layer in (self.fc1, self.fc2):
            x = self.dropout(F.relu(hidden_layer(x)))
        # Output layer has no activation: these are logits.
        return self.fc3(x)
def test_current_model():
    """Smoke-test the saved classifier on a handful of sample words.

    Reads ``config.json`` (for its ``id2label`` mapping), the
    ``pytorch_model.bin`` state dict and the ``vectorizer.pkl`` vectorizer
    from the current working directory, then prints the top predictions
    (at most 5) for each test word. Any exception is reported with a
    traceback instead of being propagated.
    """
    print("=== 현재 모델 테스트 ===")
    try:
        # Load the label mapping. JSON object keys are always strings, which
        # is why label ids are looked up with str(label_id) below.
        with open('config.json', 'r', encoding='utf-8') as f:
            config = json.load(f)
        id2label = config.get('id2label', {})
        print(f"라벨 수: {len(id2label)}")

        # NOTE(security): torch.load and joblib.load both unpickle their input;
        # only run this script against trusted files.
        state_dict = torch.load('pytorch_model.bin', map_location='cpu')

        # Take the layer sizes from the checkpoint instead of hard-coding them
        # (nn.Linear weights are shaped (out_features, in_features)), so the
        # model always matches the weights being loaded.
        input_size = state_dict['fc1.weight'].shape[1]
        num_classes = state_dict['fc3.weight'].shape[0]
        if num_classes != len(id2label):
            print(f"경고: 모델 출력 수({num_classes})가 라벨 수({len(id2label)})와 다릅니다")

        model = SimpleClassifier(input_size, num_classes)
        model.load_state_dict(state_dict)
        model.eval()  # disable dropout for deterministic inference

        vectorizer = joblib.load('vectorizer.pkl')

        # torch.topk raises if k exceeds the number of classes.
        k = min(5, num_classes)

        # Test words; the list includes 환원철 (reduced iron).
        test_words = ["철ㄹ", "CaO", "해면철", "등류", "환원철"]
        for word in test_words:
            print(f"\n{'='*50}")
            print(f"입력: '{word}'")
            print(f"{'='*50}")

            # TF-IDF vectorization: the vectorizer takes a batch of documents.
            word_vector = vectorizer.transform([word]).toarray()
            if word_vector.shape[1] != input_size:
                raise ValueError(
                    f"벡터라이저 특성 수({word_vector.shape[1]})가 "
                    f"모델 입력 크기({input_size})와 다릅니다"
                )
            word_tensor = torch.FloatTensor(word_vector)

            with torch.no_grad():
                probabilities = F.softmax(model(word_tensor), dim=1)
                # topk returns values sorted in descending order by default.
                top_probs, top_indices = torch.topk(probabilities, k, dim=1)

            # The first top-k entry is the overall maximum probability.
            max_prob = top_probs[0, 0].item()
            print(f"최대 확률: {max_prob:.4f} ({max_prob*100:.1f}%)")
            print(f"상위 {k}개 예측:")
            ranked = zip(top_indices[0].tolist(), top_probs[0].tolist())
            for rank, (label_id, probability) in enumerate(ranked, start=1):
                label = id2label.get(str(label_id), f"Unknown_{label_id}")
                print(f" {rank}. {label}: {probability:.4f} ({probability*100:.1f}%)")
    except Exception as e:
        # Top-level boundary of a manual test script: report, don't crash.
        print(f"에러 발생: {e}")
        traceback.print_exc()
if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script,
    # not when it is imported.
    test_current_model()