| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import pickle
|
| import joblib
|
| import numpy as np
|
|
|
|
|
class SimpleClassifier(nn.Module):
    """Three-layer MLP classifier: input_size -> 256 -> 128 -> num_classes.

    Each hidden layer is followed by ReLU and then by one shared dropout
    module (p=0.3). ``forward`` returns the raw output of the last linear
    layer; no softmax is applied here.

    Args:
        input_size: Number of input features (last dimension of the input).
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        first_width, second_width = 256, 128
        # Keep attribute names (fc1/fc2/fc3/dropout) and creation order fixed:
        # they determine the state_dict keys used when loading saved weights.
        self.fc1 = nn.Linear(input_size, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map a tensor whose last dimension is ``input_size`` to class scores."""
        hidden = x
        # Linear -> ReLU -> dropout for each of the two hidden layers.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc3(hidden)
|
|
|
def test_current_model(test_words=None, top_k=5):
    """Smoke-test the saved classifier on a handful of sample words.

    Reads ``config.json`` (for its ``id2label`` mapping), the model weights
    from ``pytorch_model.bin`` and the fitted vectorizer from
    ``vectorizer.pkl``, all relative to the current working directory. Then,
    for each word, prints the highest probability and the top-k predicted
    labels with their softmax probabilities.

    Any exception raised while loading or predicting is caught. Its message
    is printed, and its traceback is written via ``traceback.print_exc()``.

    Args:
        test_words: Words to classify. Defaults to the built-in sample list.
        top_k: Maximum number of predictions shown per word. It is clamped
            to the number of classes, because ``torch.topk`` raises when k
            exceeds the dimension size.
    """
    import json
    import traceback

    print("=== 현재 모델 테스트 ===")

    if test_words is None:
        # NOTE(review): these literals appear mis-encoded (UTF-8 Korean
        # displayed through a Thai code page; the third looks like "해면철").
        # Kept verbatim -- confirm the intended words against the original.
        test_words = ["์ฒ ใน", "CaO", "ํด๋ฉด์ฒ ", "๋ฑ๋ฅ", "ํ์์ฒ "]

    try:
        with open('config.json', 'r', encoding='utf-8') as f:
            config = json.load(f)

        # JSON object keys are always strings, hence the str(label_id) lookup below.
        id2label = config.get('id2label', {})
        print(f"라벨 수: {len(id2label)}")

        # NOTE(review): torch.load and joblib.load unpickle their input;
        # only run this against trusted model artifacts.
        state_dict = torch.load('pytorch_model.bin', map_location='cpu')

        # Derive layer sizes from the checkpoint itself, so the model always
        # matches the saved weights. A hard-coded input size, or a class
        # count taken from config, could disagree with the checkpoint.
        # nn.Linear stores weights as (out_features, in_features).
        input_size = state_dict['fc1.weight'].shape[1]
        num_classes = state_dict['fc3.weight'].shape[0]

        model = SimpleClassifier(input_size, num_classes)
        model.load_state_dict(state_dict)
        model.eval()  # disables dropout for deterministic inference

        vectorizer = joblib.load('vectorizer.pkl')

        k = min(top_k, num_classes)

        for word in test_words:
            print(f"\n{'='*50}")
            print(f"입력: '{word}'")
            print(f"{'='*50}")

            word_vector = vectorizer.transform([word]).toarray()
            if word_vector.shape[1] != input_size:
                raise ValueError(
                    f"vectorizer produced {word_vector.shape[1]} features, "
                    f"but the model expects {input_size}"
                )
            word_tensor = torch.as_tensor(word_vector, dtype=torch.float32)

            with torch.no_grad():
                probabilities = F.softmax(model(word_tensor), dim=1)

            # topk returns values sorted in descending order, so the first
            # entry is also the overall maximum probability.
            top_probs, top_indices = torch.topk(probabilities, k, dim=1)

            max_prob = top_probs[0, 0].item()
            print(f"최대 확률: {max_prob:.4f} ({max_prob*100:.1f}%)")
            print(f"상위 {k}개 예측:")

            ranked = zip(top_indices[0].tolist(), top_probs[0].tolist())
            for rank, (label_id, probability) in enumerate(ranked, start=1):
                label = id2label.get(str(label_id), f"Unknown_{label_id}")
                print(f" {rank}. {label}: {probability:.4f} ({probability*100:.1f}%)")

    except Exception as e:
        # Top-level boundary of a diagnostic script: report and keep going.
        print(f"에러 발생: {e}")
        traceback.print_exc()
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when this file is executed directly, not on import.
    test_current_model()
|
|
|