import onnxruntime as ort
import numpy as np
from typing import Dict, List, Any
import json
from pathlib import Path
import os


class ContentClassifierInference:
    """Thin inference wrapper around an ONNX content-classification model.

    Loads an ONNX model and a JSON config (labels, max input length,
    decision threshold), preprocesses raw text into a fixed-length float
    vector, and returns a structured prediction dict.
    """

    def __init__(self, model_path: str = "contentclassifier.onnx", config_path: str = "config.json"):
        """Load the ONNX session and configuration.

        Args:
            model_path: Path to the .onnx model file. If missing, a dummy
                model is created via ``create_dummy_model`` when available.
            config_path: Path to a JSON config; falls back to defaults.

        Raises:
            FileNotFoundError: Model missing and no dummy-model helper.
            RuntimeError: ONNX Runtime failed to load the model.
            ValueError: Model has no inputs or outputs.
        """
        self.model_path = model_path
        self.config_path = config_path

        # Fall back to a generated dummy model so local testing works
        # without the real artifact.
        if not os.path.exists(model_path):
            print(f"Warning: Model file {model_path} not found!")
            print("Creating a dummy model for testing...")
            try:
                from create_dummy_model import create_dummy_onnx_model
                create_dummy_onnx_model(model_path)
            except ImportError:
                raise FileNotFoundError(f"Model file {model_path} not found and couldn't create dummy model")

        # Load ONNX model; wrap any backend error in a clearer RuntimeError.
        try:
            self.session = ort.InferenceSession(model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

        # Load configuration, defaulting when the file is absent.
        if Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            print(f"Warning: Config file {config_path} not found, using default config")
            self.config = self._default_config()

        # Cache the first input/output tensor names for session.run().
        try:
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
        except IndexError:
            raise ValueError("Model doesn't have expected inputs/outputs")

    def _default_config(self) -> Dict:
        """Return the built-in fallback configuration."""
        return {
            "labels": ["safe", "unsafe"],
            "max_length": 512,
            "threshold": 0.5
        }

    def preprocess(self, text: str) -> np.ndarray:
        """Convert text to a (1, max_length) float32 vector in [0, 1].

        Placeholder byte-level encoding: each UTF-8 byte is scaled by 1/255
        and the vector is zero-padded/truncated to ``max_length``. Replace
        with real tokenization to match the production model.

        Args:
            text: Input string to classify.

        Returns:
            np.ndarray of shape (1, max_length), dtype float32.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"Input must be string, got {type(text)}")

        max_length = self.config["max_length"]
        # Truncate at the byte level; the slice guarantees the fill below
        # never overruns the buffer (the old per-index bounds check was dead).
        encoded = text.encode('utf-8')[:max_length]

        input_array = np.zeros(max_length, dtype=np.float32)
        if encoded:
            # Vectorized fill instead of a per-byte Python loop.
            input_array[:len(encoded)] = np.frombuffer(encoded, dtype=np.uint8).astype(np.float32) / 255.0

        # Sanity-check against the model's declared input shape. Dynamic
        # dims are reported as strings or None, so only compare when every
        # dim is a concrete int (the old check only recognized 'batch').
        expected_shape = self.session.get_inputs()[0].shape
        produced_shape = [1, max_length]
        if all(isinstance(d, int) for d in expected_shape) and list(expected_shape) != produced_shape:
            print(f"Warning: Model expects input shape {expected_shape}, but preprocessing produces {produced_shape}")

        return input_array.reshape(1, -1)

    def predict(self, text: str) -> Dict[str, Any]:
        """Run inference on one text and return a structured result.

        Returns a dict with keys: ``is_threat``, ``final_confidence``,
        ``threat_prediction``, ``sentiment_analysis``, ``onnx_prediction``,
        ``models_used``, ``raw_predictions``.
        """
        input_data = self.preprocess(text)

        # Single-input, single-output run.
        outputs = self.session.run([self.output_name], {self.input_name: input_data})
        predictions = outputs[0]

        # Drop the batch dimension if present.
        if len(predictions.shape) > 1:
            predictions = predictions[0]

        # Numerically stable softmax (subtract the max before exp).
        exp_scores = np.exp(predictions - np.max(predictions))
        probabilities = exp_scores / np.sum(exp_scores)

        # NOTE: if the model emits more classes than config["labels"], this
        # indexing raises IndexError — config and model must agree.
        predicted_class_idx = int(np.argmax(probabilities))
        predicted_class = self.config["labels"][predicted_class_idx]
        confidence = float(probabilities[predicted_class_idx])

        onnx_prediction = {
            label: float(prob)
            for label, prob in zip(self.config["labels"], probabilities)
        }

        # Flag as a threat only when the "unsafe" class wins AND its
        # confidence clears the configured threshold. For the default
        # binary config this matches the old label-only check, since the
        # argmax probability of a 2-class softmax is always >= 0.5.
        threshold = self.config.get("threshold", 0.5)
        is_threat = predicted_class == "unsafe" and confidence >= threshold
        final_confidence = confidence

        raw_predictions = {
            "onnx": onnx_prediction,
            "sentiment": None  # No sentiment analysis in this model, but included for compatibility
        }

        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": predicted_class,
            "sentiment_analysis": raw_predictions.get("sentiment"),
            "onnx_prediction": raw_predictions.get("onnx"),
            "models_used": ["onnx"],
            "raw_predictions": raw_predictions
        }

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Run inference on multiple texts, one result dict per input."""
        return [self.predict(text) for text in texts]