import sys
import os
import json
import logging
from typing import List, Dict, Tuple, Optional
import time

import numpy as np
from tqdm import tqdm
import onnxruntime as ort
from transformers import AutoTokenizer


class StopJudgmentONNXInference:
    """Run a stop-judgment text classifier exported to ONNX.

    The class loads a Hugging Face tokenizer and an ONNX Runtime session,
    then exposes single-text and batched prediction helpers that return the
    arg-max class index together with its softmax probability.
    """

    def __init__(self, onnx_model_path: str, tokenizer_path: str, device: str = 'auto'):
        """Load the tokenizer and the ONNX model.

        Args:
            onnx_model_path: Path to the ONNX model file.
            tokenizer_path: Path to the tokenizer files (loaded with
                ``local_files_only=True``).
            device: Execution device: ``'auto'`` (CUDA when available,
                otherwise CPU), ``'cuda'`` or ``'cpu'``. Matching is
                case-insensitive.

        Raises:
            Exception: Whatever the tokenizer or ONNX Runtime raises while
                loading; the error is logged and re-raised unchanged.
        """
        self.onnx_model_path = onnx_model_path
        self.tokenizer_path = tokenizer_path
        # Must be stored before load_model_and_tokenizer() picks providers.
        self.device = device

        self.setup_logging()
        self.load_model_and_tokenizer()

    def setup_logging(self):
        """Configure root logging (INFO level) and create ``self.logger``."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _select_providers(self) -> List[str]:
        """Return the ONNX Runtime provider list for the requested device.

        The CPU provider is always appended last so the session can still be
        created when CUDA is missing or fails to initialise.
        """
        raw_device = getattr(self, 'device', 'auto')
        requested = str(raw_device).lower()
        if requested not in ('auto', 'cuda', 'cpu'):
            # Earlier versions ignored this argument entirely, so an unknown
            # value is downgraded to 'auto' instead of raising.
            self.logger.warning(
                "Unknown device %r; falling back to 'auto'", raw_device
            )
            requested = 'auto'

        providers = []
        if requested != 'cpu':
            if 'CUDAExecutionProvider' in ort.get_available_providers():
                providers.append('CUDAExecutionProvider')
                self.logger.info("CUDA provider is available and will be used")
            elif requested == 'cuda':
                self.logger.warning(
                    "CUDA was requested but CUDAExecutionProvider is not "
                    "available; falling back to CPU"
                )
        providers.append('CPUExecutionProvider')
        return providers

    def load_model_and_tokenizer(self):
        """Load the tokenizer and the ONNX session, then record I/O names.

        Sets ``self.tokenizer``, ``self.ort_session``, ``self.input_names``
        and ``self.output_names``.

        Raises:
            Exception: Re-raises any tokenizer or model loading failure after
                logging it.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path, local_files_only=True
            )
            self.logger.info("Tokenizer loaded successfully")
        except Exception as e:
            self.logger.error(f"Failed to load tokenizer: {e}")
            raise

        providers = self._select_providers()

        try:
            self.ort_session = ort.InferenceSession(
                self.onnx_model_path, providers=providers
            )
            self.logger.info(
                f"ONNX model loaded successfully with providers: "
                f"{self.ort_session.get_providers()}"
            )
        except Exception as e:
            self.logger.error(f"Failed to load ONNX model: {e}")
            raise

        # Cache the graph's input/output names for building feeds later.
        self.input_names = [node.name for node in self.ort_session.get_inputs()]
        self.output_names = [node.name for node in self.ort_session.get_outputs()]

        self.logger.info(f"Input names: {self.input_names}")
        self.logger.info(f"Output names: {self.output_names}")

    def preprocess_text(self, texts: List[str], max_length: int = 128) -> Dict[str, np.ndarray]:
        """Tokenize texts into int64 NumPy arrays padded to ``max_length``.

        Args:
            texts: Texts to encode.
            max_length: Length every sequence is truncated/padded to.

        Returns:
            A dict with ``input_ids`` and ``attention_mask``. When the loaded
            model declares a ``token_type_ids`` input and the tokenizer
            produced one, it is included as well.
        """
        encoding = self.tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='np'  # NumPy arrays, ready for ONNX Runtime
        )

        features = {
            'input_ids': encoding['input_ids'].astype(np.int64),
            'attention_mask': encoding['attention_mask'].astype(np.int64)
        }
        # Only add the extra key when the model needs it, so the returned dict
        # is unchanged for two-input models.
        if ('token_type_ids' in getattr(self, 'input_names', ())
                and 'token_type_ids' in encoding):
            features['token_type_ids'] = encoding['token_type_ids'].astype(np.int64)
        return features

    def _build_ort_inputs(self, features: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Map encoded features onto the model's declared input names."""
        if all(name in features for name in self.input_names):
            # Bind by name: robust to input order and to extra/missing
            # optional inputs.
            return {name: features[name] for name in self.input_names}
        # Model uses custom input names: keep the historical positional
        # binding (first input = ids, second input = mask).
        positional = (features['input_ids'], features['attention_mask'])
        return dict(zip(self.input_names, positional))

    def _predict_probabilities(self, texts: List[str], max_length: int) -> np.ndarray:
        """Run the model on ``texts`` and return row-wise softmax scores.

        The first model output is treated as the logits; it is assumed to be
        2-D (one row per text) because ``softmax`` reduces over axis 1.
        """
        features = self.preprocess_text(texts, max_length)
        ort_outputs = self.ort_session.run(
            self.output_names, self._build_ort_inputs(features)
        )
        return self.softmax(ort_outputs[0])

    def predict_single(self, text: str, max_length: int = 128) -> Tuple[int, float]:
        """Predict one text.

        Args:
            text: Text to classify.
            max_length: Tokenizer truncation/padding length.

        Returns:
            ``(class_index, confidence)`` where confidence is the softmax
            probability of the predicted class.
        """
        probabilities = self._predict_probabilities([text], max_length)
        prediction = int(np.argmax(probabilities[0]))
        return prediction, float(probabilities[0][prediction])

    def predict_batch(self, texts: List[str], max_length: int = 128, batch_size: int = 32) -> Tuple[List[int], List[float]]:
        """Predict many texts in mini-batches.

        Args:
            texts: Texts to classify.
            max_length: Tokenizer truncation/padding length.
            batch_size: Number of texts per inference call; must be positive.

        Returns:
            ``(predictions, confidences)``: two lists aligned with ``texts``,
            holding plain Python ``int`` and ``float`` values.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        all_predictions: List[int] = []
        all_confidences: List[float] = []

        for start in tqdm(range(0, len(texts), batch_size), desc="ONNX Predicting"):
            probabilities = self._predict_probabilities(
                texts[start:start + batch_size], max_length
            )
            predictions = np.argmax(probabilities, axis=1)
            # Pick each row's winning probability in one vectorised step;
            # .tolist() converts NumPy scalars to native Python numbers.
            confidences = probabilities[np.arange(len(predictions)), predictions]

            all_predictions.extend(predictions.tolist())
            all_confidences.extend(confidences.tolist())

        return all_predictions, all_confidences

    @staticmethod
    def softmax(x):
        """Return the softmax of a 2-D array along axis 1.

        The row maximum is subtracted before exponentiating so large logits
        do not overflow.
        """
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=1, keepdims=True)


def main():
    """Command-line entry point: classify one sentence and print the result.

    Positional arguments: tokenizer path, ONNX model path, and an optional
    test sentence (a built-in sample sentence is used when omitted).
    """
    if len(sys.argv) < 3:
        print("Usage: python validate_onnx.py TOKENIZER_PATH ONNX_MODEL_PATH [test_sentence]")
        sys.exit(1)

    tokenizer_path = sys.argv[1]
    onnx_model_path = sys.argv[2]
    test_sentence = sys.argv[3] if len(sys.argv) > 3 else "欢迎测试本判停模型有修正建议请随时提出"

    print("\n ONNX Model Inference...")
    onnx_inferencer = StopJudgmentONNXInference(onnx_model_path, tokenizer_path)

    prediction, confidence = onnx_inferencer.predict_single(
        test_sentence, max_length=128
    )
    print(prediction, confidence)


if __name__ == "__main__":
    main()