# flud / test_sintered_ore.py
# Uploaded by Halfotter
# Commit message: "Upload folder using huggingface_hub"
# Commit 2f1dcf0 (verified)
import json
import pickle
import traceback

import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# SimpleClassifier 클래슀 μ •μ˜
class SimpleClassifier(nn.Module):
def __init__(self, input_size, num_classes):
super(SimpleClassifier, self).__init__()
self.fc1 = nn.Linear(input_size, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, num_classes)
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = F.relu(self.fc2(x))
x = self.dropout(x)
x = self.fc3(x)
return x
def test_sintered_ore():
    """Run a smoke test of the classifier on the word "μ†Œκ²°κ΄‘" (sintered ore).

    Reads three artifacts from the current working directory:

    * ``config.json``: its ``id2label`` mapping supplies the class count and
      the display names of the predicted labels.
    * ``pytorch_model.bin``: a state dict that is loaded into ``SimpleClassifier``.
    * ``vectorizer.pkl``: a fitted vectorizer, loaded with ``joblib``, that
      turns the input word into the model's feature vector.

    Prints the highest class probability, then the top predictions. At most
    10 predictions are shown, or fewer if the model has fewer classes.

    Any exception is caught here, at the script's top-level boundary. Its
    message and traceback are printed, and the function returns ``None``
    without re-raising.
    """
    max_predictions = 10
    print("=== μ†Œκ²°κ΄‘ ν…ŒμŠ€νŠΈ ===")
    try:
        # Load the label mapping. JSON object keys are always strings,
        # which is why the lookups below use str(label_id).
        with open('config.json', 'r', encoding='utf-8') as f:
            config = json.load(f)
        id2label = config.get('id2label', {})
        if not id2label:
            # With no labels the output layer would have zero units.
            # Fail here with a clear message rather than an opaque shape mismatch.
            raise ValueError("config.json has no 'id2label' entries")

        # Build the model and load its weights. input_size must equal the
        # width of the vectorizer's output; it is hard-coded here.
        # NOTE(review): torch.load and joblib.load both unpickle their input.
        # Only run this against trusted artifacts.
        input_size = 3000
        num_classes = len(id2label)
        model = SimpleClassifier(input_size, num_classes)
        model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
        model.eval()  # inference mode: disables dropout

        vectorizer = joblib.load('vectorizer.pkl')

        test_word = "μ†Œκ²°κ΄‘"
        print(f"μž…λ ₯: '{test_word}'")

        # Vectorize the word. toarray() densifies the result, which is
        # presumably a sparse matrix; then convert it to a float32 tensor.
        word_vector = vectorizer.transform([test_word]).toarray()
        word_tensor = torch.as_tensor(word_vector, dtype=torch.float32)

        with torch.no_grad():
            outputs = model(word_tensor)
            probabilities = F.softmax(outputs, dim=1)

        # torch.topk raises if k exceeds the size of the dimension,
        # so clamp k to the number of classes.
        top_k = min(max_predictions, probabilities.size(1))
        top_probs, top_indices = torch.topk(probabilities, top_k, dim=1)

        # topk returns values in descending order, so the first entry is
        # the maximum probability.
        max_prob = top_probs[0, 0].item()
        print(f"μ΅œλŒ€ ν™•λ₯ : {max_prob:.4f} ({max_prob*100:.1f}%)")
        print(f"μƒμœ„ {top_k}개 예츑:")
        ranked = zip(top_probs[0].tolist(), top_indices[0].tolist())
        for rank, (probability, label_id) in enumerate(ranked, start=1):
            label = id2label.get(str(label_id), f"Unknown_{label_id}")
            print(f" {rank}. {label}: {probability:.4f} ({probability*100:.1f}%)")
    except Exception as e:
        # Top-level boundary of a manual test script: report the error
        # and its traceback instead of re-raising.
        print(f"μ—λŸ¬ λ°œμƒ: {e}")
        traceback.print_exc()
if __name__ == "__main__":
    # Script entry point: run the smoke test only when this file is executed
    # directly, not when it is imported as a module.
    test_sintered_ore()