Spaces:
Sleeping
Sleeping
| import onnxruntime as ort | |
| import numpy as np | |
| from typing import Dict, List, Any | |
| import json | |
| from pathlib import Path | |
| import os | |
class ContentClassifierInference:
    """ONNX-backed text content classifier.

    Loads an ONNX model plus a JSON config (``labels``, ``max_length``,
    ``threshold``) and exposes single-text and batch prediction that return a
    fixed result schema (``is_threat``, ``final_confidence``, ...).
    """

    def __init__(self, model_path: str = "contentclassifier.onnx", config_path: str = "config.json"):
        """Create the ONNX session and load configuration.

        Args:
            model_path: Path to the ONNX model file.
            config_path: Path to a JSON config; defaults are used if missing.

        Raises:
            FileNotFoundError: Model file is absent and the dummy-model
                helper module is unavailable.
            RuntimeError: The ONNX session could not be created.
            ValueError: The model declares no inputs or no outputs.
        """
        self.model_path = model_path
        self.config_path = config_path

        # Best-effort fallback: generate a dummy model so local testing works
        # without shipping model weights.
        if not os.path.exists(model_path):
            print(f"Warning: Model file {model_path} not found!")
            print("Creating a dummy model for testing...")
            try:
                from create_dummy_model import create_dummy_onnx_model
                create_dummy_onnx_model(model_path)
            except ImportError:
                raise FileNotFoundError(f"Model file {model_path} not found and couldn't create dummy model")

        # Load the ONNX model; chain the original cause for debuggability.
        try:
            self.session = ort.InferenceSession(model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}") from e

        # Load configuration, falling back to built-in defaults.
        if Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            print(f"Warning: Config file {config_path} not found, using default config")
            self.config = self._default_config()

        # Cache the first input/output tensor names for session.run().
        try:
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
        except IndexError:
            raise ValueError("Model doesn't have expected inputs/outputs")

    def _default_config(self) -> Dict:
        """Return the built-in fallback configuration.

        NOTE: ``threshold`` is carried in the config but not applied by
        :meth:`predict`; classification uses argmax only.
        """
        return {
            "labels": ["safe", "unsafe"],
            "max_length": 512,
            "threshold": 0.5
        }

    def preprocess(self, text: str) -> np.ndarray:
        """Encode *text* as a float32 array of shape ``(1, max_length)``.

        Each UTF-8 byte is scaled to ``[0, 1]``; shorter texts are
        zero-padded, longer ones truncated at ``max_length`` bytes.

        NOTE(review): this byte-level featurization is a placeholder — swap
        in the tokenizer/encoding the model was actually trained with.

        Raises:
            TypeError: If *text* is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"Input must be string, got {type(text)}")

        max_length = self.config["max_length"]
        # Truncation here guarantees len(encoded) <= max_length below.
        encoded = text.encode('utf-8')[:max_length]

        input_array = np.zeros(max_length, dtype=np.float32)
        if encoded:
            input_array[:len(encoded)] = (
                np.frombuffer(encoded, dtype=np.uint8).astype(np.float32) / 255.0
            )

        # Sanity-check the model's declared input shape against what we
        # produce. ONNX dynamic axes show up as arbitrary strings or None,
        # so only concrete (int) dimensions are compared — checking for the
        # literal axis name 'batch' would mis-warn on other dynamic names.
        expected_shape = self.session.get_inputs()[0].shape
        input_shape = [1, max_length]
        compatible = len(expected_shape) == len(input_shape) and all(
            not isinstance(expected, int) or expected == produced
            for expected, produced in zip(expected_shape, input_shape)
        )
        if not compatible:
            print(f"Warning: Model expects input shape {expected_shape}, but preprocessing produces {input_shape}")

        return input_array.reshape(1, -1)

    def predict(self, text: str) -> Dict[str, Any]:
        """Run inference on *text* and return a structured result dict.

        Returns:
            Dict with keys ``is_threat``, ``final_confidence``,
            ``threat_prediction``, ``sentiment_analysis``,
            ``onnx_prediction``, ``models_used`` and ``raw_predictions``.
        """
        input_data = self.preprocess(text)
        outputs = self.session.run([self.output_name], {self.input_name: input_data})
        predictions = outputs[0]

        # Strip the batch dimension if the model emitted one.
        if len(predictions.shape) > 1:
            predictions = predictions[0]

        # Numerically stable softmax (max-shifted before exponentiation).
        # NOTE(review): applied unconditionally — if the model already
        # outputs probabilities this re-normalizes them.
        exp_scores = np.exp(predictions - np.max(predictions))
        probabilities = exp_scores / np.sum(exp_scores)

        # Cast to plain int so the index (and downstream JSON) stays
        # free of NumPy scalar types.
        predicted_class_idx = int(np.argmax(probabilities))
        predicted_class = self.config["labels"][predicted_class_idx]
        confidence = float(probabilities[predicted_class_idx])

        # Per-label probability map keyed by configured label names.
        onnx_prediction = {
            label: float(prob) for label, prob in zip(self.config["labels"], probabilities)
        }

        # Threat decision is purely the argmax label; no threshold applied.
        is_threat = predicted_class == "unsafe"
        final_confidence = confidence

        raw_predictions = {
            "onnx": onnx_prediction,
            "sentiment": None  # No sentiment analysis in this model, but included for compatibility
        }

        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": predicted_class,
            "sentiment_analysis": raw_predictions.get("sentiment"),
            "onnx_prediction": raw_predictions.get("onnx"),
            "models_used": ["onnx"],
            "raw_predictions": raw_predictions
        }

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Run :meth:`predict` over *texts* and return results in order."""
        return [self.predict(text) for text in texts]