# flud/inference.py — uploaded via huggingface_hub by Halfotter (revision 25b8ff1, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import os
from transformers import PreTrainedModel, PretrainedConfig, XLMRobertaModel, XLMRobertaConfig
class XLMSteelConfig(PretrainedConfig):
    """Configuration for the XLM-RoBERTa steel-product classifier.

    Besides ``num_labels``, callers pass ``id2label``/``label2id`` and the
    TF-IDF metadata (``feature_names``, ``input_size``) through ``**kwargs``;
    see ``load_model`` below.
    """

    # Registered model type identifier used by the transformers config machinery.
    model_type = "xlm_steel_classifier"

    def __init__(self, num_labels: int = 66, **kwargs):
        super().__init__(**kwargs)
        # Stored after super().__init__ so it wins over any kwargs default.
        self.num_labels = num_labels
class XLMIntegratedModel(PreTrainedModel):
    """XLM-RoBERTa + TF-IDF hybrid steel-product classifier.

    NOTE(review): the classification head only consumes the bag-of-token
    ("TF-IDF") features built in :meth:`_vectorize_from_ids`; the
    XLM-RoBERTa encoder output is computed but never fed into the head.
    The encoder is kept so the state_dict layout matches the shipped
    checkpoint — confirm before removing it.
    """

    config_class = XLMSteelConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone encoder. NOTE(review): this downloads pretrained weights at
        # construction time; they are then overwritten by load_state_dict().
        self.xlm_roberta = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        # TF-IDF vectorizer metadata saved alongside the checkpoint.
        self.feature_names = getattr(config, 'feature_names', [])
        self.input_size = getattr(config, 'input_size', 3000)
        # Classification head over the bag-of-token vector (original
        # TF-IDF model architecture: input_size -> 256 -> 128 -> num_labels).
        self.fc1 = nn.Linear(self.input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, config.num_labels)
        self.dropout = nn.Dropout(0.3)
        # Label mapping (config.id2label has string keys from config.json).
        self.id2label = config.id2label
        self.num_classes = config.num_labels
        # Hashed feature names stored as a buffer so they travel with the
        # state_dict. NOTE(review): built-in hash() is process-salted, but the
        # in-memory values are replaced by the checkpoint on load_state_dict().
        self.register_buffer(
            'feature_names_list',
            torch.tensor([hash(f) for f in self.feature_names], dtype=torch.long),
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the combined forward pass.

        Returns ``{"logits": ...}`` and additionally ``"loss"`` when
        ``labels`` is given.
        """
        head_dtype = self.fc1.weight.dtype
        head_device = self.fc1.weight.device

        if input_ids is not None:
            # Encoder pass — output currently unused (see class docstring).
            self.xlm_roberta(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )
            # FIX: vectorize every row of the batch. The original only used
            # input_ids[0], so any batch > 1 silently produced batch-size-1
            # logits.
            rows = np.stack([self._vectorize_from_ids(row) for row in input_ids])
            # FIX: place the features on the head's device/dtype instead of
            # defaulting to CPU float32.
            tfidf_features = torch.as_tensor(rows, dtype=head_dtype, device=head_device)
        else:
            tfidf_features = torch.zeros(1, self.input_size, dtype=head_dtype, device=head_device)

        # Head: only the TF-IDF features are classified.
        x = F.relu(self.fc1(tfidf_features))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        logits = self.fc3(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

    def _vectorize_from_ids(self, input_ids):
        """Build a normalized term-frequency vector from a 1-D sequence of token ids.

        Ids outside ``[0, input_size)`` are ignored; the count vector is
        L1-normalized when non-empty.
        """
        vector = np.zeros(self.input_size)
        # .tolist() avoids per-element 0-d tensor overhead when given a tensor.
        ids = input_ids.tolist() if torch.is_tensor(input_ids) else input_ids
        for token_id in ids:
            # Guard against negative ids as well as out-of-range ones.
            if 0 <= token_id < self.input_size:
                vector[token_id] += 1
        total = vector.sum()
        if total > 0:
            vector /= total
        return vector
# Module-level cache for the lazily loaded model (populated by load_model()).
model = None
def load_model():
    """Load config.json and the checkpoint, cache the model in the module global.

    Reads ``config.json`` and ``xlm_integrated_model.bin`` from the current
    working directory, builds an :class:`XLMIntegratedModel`, loads the
    weights, switches to eval mode, and returns the model.

    NOTE(review): paths resolve against os.getcwd(), so the process must be
    started from the directory holding the artifacts — consider anchoring to
    ``__file__`` instead.
    """
    global model

    config_path = os.path.join(os.getcwd(), "config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        config_data = json.load(f)

    config = XLMSteelConfig(
        num_labels=config_data['num_labels'],
        id2label=config_data['id2label'],
        label2id=config_data['label2id'],
        feature_names=config_data.get('feature_names', []),
        input_size=config_data.get('input_size', 3000),
    )

    model = XLMIntegratedModel(config)
    model_path = os.path.join(os.getcwd(), "xlm_integrated_model.bin")
    # FIX: weights_only=True — a plain state_dict needs no pickled code
    # objects, and this closes torch.load's arbitrary-code-execution hole.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict(inputs):
    """Classify one text and return ``{"label", "confidence", "text"}``.

    Accepts a raw string, a list (first element used), a dict with an
    ``"inputs"`` key, or anything str()-able. Lazily loads the model on
    first call.
    """
    global model
    if model is None:
        model = load_model()

    # Normalize the accepted input shapes down to a single string.
    if isinstance(inputs, str):
        text = inputs
    elif isinstance(inputs, list):
        text = inputs[0] if len(inputs) > 0 else ""
    elif isinstance(inputs, dict) and "inputs" in inputs:
        text = inputs["inputs"]
    else:
        text = str(inputs)

    import zlib  # local import: only needed for the stable token hash below

    # FIX: use a stable hash (zlib.crc32) instead of built-in hash(), which is
    # salted per process (PYTHONHASHSEED) and made predictions non-reproducible
    # across runs. 50000 approximates the XLM-RoBERTa vocab size.
    tokens = text.lower().split()
    ids = [zlib.crc32(token.encode('utf-8')) % 50000 for token in tokens]
    if not ids:
        # FIX: empty text previously produced a float-dtype (1, 0) tensor that
        # the encoder cannot consume; feed a single placeholder id instead.
        ids = [1]

    input_ids = torch.tensor([ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs["logits"]
        probabilities = F.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

    # id2label comes from config.json, so its keys are strings.
    label = model.id2label[str(predicted_class)]
    confidence = probabilities[0][predicted_class].item()

    return {
        "label": label,
        "confidence": confidence,
        "text": text
    }
# Eagerly load the model when run as a script (validates that the artifacts
# in the current directory load cleanly).
if __name__ == "__main__":
    load_model()