File size: 5,230 Bytes
3d76333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import onnxruntime as ort
import numpy as np
from typing import Dict, List, Any
import json
from pathlib import Path
import os

class ContentClassifierInference:
    """Run inference with an ONNX text content classifier.

    Loads an ONNX model plus a JSON config (labels, max_length, threshold),
    converts raw text into a fixed-length float vector of normalized UTF-8
    byte values, and returns a structured prediction dict.
    """

    def __init__(self, model_path: str = "contentclassifier.onnx", config_path: str = "config.json"):
        """Load the ONNX session and configuration.

        Args:
            model_path: Path to the ONNX model file.
            config_path: Path to the JSON config file.

        Raises:
            FileNotFoundError: Model file missing and the dummy-model
                fallback module is unavailable.
            RuntimeError: The ONNX session failed to load.
            ValueError: The model lacks the expected inputs/outputs.
        """
        self.model_path = model_path
        self.config_path = config_path

        # Fall back to a generated dummy model so local testing can proceed
        # without shipping the real model artifact.
        if not os.path.exists(model_path):
            print(f"Warning: Model file {model_path} not found!")
            print("Creating a dummy model for testing...")
            try:
                from create_dummy_model import create_dummy_onnx_model
                create_dummy_onnx_model(model_path)
            except ImportError:
                raise FileNotFoundError(f"Model file {model_path} not found and couldn't create dummy model")

        # Load ONNX model; chain the original error for debuggability.
        try:
            self.session = ort.InferenceSession(model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}") from e

        # Config file is optional; defaults keep the classifier usable.
        if Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            print(f"Warning: Config file {config_path} not found, using default config")
            self.config = self._default_config()

        # Cache the first input/output tensor names for session.run().
        try:
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
        except IndexError:
            raise ValueError("Model doesn't have expected inputs/outputs")

    def _default_config(self) -> Dict:
        """Return the built-in fallback configuration."""
        return {
            "labels": ["safe", "unsafe"],
            "max_length": 512,
            "threshold": 0.5
        }

    def preprocess(self, text: str) -> np.ndarray:
        """Preprocess text input for the model.

        Encodes the text as UTF-8, truncates to ``max_length`` bytes, and
        scales each byte to [0, 1], zero-padding the remainder.

        Args:
            text: Raw input string.

        Returns:
            A float32 array of shape ``(1, max_length)``.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"Input must be string, got {type(text)}")

        # NOTE: placeholder preprocessing (byte-level encoding); replace with
        # the model's real tokenization/encoding scheme if it has one.
        max_length = self.config["max_length"]
        encoded = text.encode('utf-8')[:max_length]

        # Vectorized byte -> [0, 1] float conversion; `encoded` is already
        # truncated to max_length, so no per-element bounds check is needed.
        input_array = np.zeros(max_length, dtype=np.float32)
        if encoded:
            input_array[:len(encoded)] = (
                np.frombuffer(encoded, dtype=np.uint8).astype(np.float32) / 255.0
            )

        # Sanity-check only the fixed trailing dimension: dynamic dims are
        # reported as None or arbitrary strings ('batch', 'N', ...), so exact
        # whole-shape comparisons produce spurious warnings.
        expected_shape = self.session.get_inputs()[0].shape
        if expected_shape:
            last_dim = expected_shape[-1]
            if isinstance(last_dim, int) and last_dim != max_length:
                print(f"Warning: Model expects input shape {expected_shape}, "
                      f"but preprocessing produces {[1, max_length]}")

        return input_array.reshape(1, -1)

    def predict(self, text: str) -> Dict[str, Any]:
        """Run inference on input text and return a structured result.

        Args:
            text: Raw input string to classify.

        Returns:
            Dict with keys ``is_threat``, ``final_confidence``,
            ``threat_prediction``, ``sentiment_analysis``, ``onnx_prediction``,
            ``models_used``, and ``raw_predictions``.
        """
        input_data = self.preprocess(text)

        outputs = self.session.run([self.output_name], {self.input_name: input_data})
        predictions = outputs[0]

        # Drop the batch dimension if present.
        if len(predictions.shape) > 1:
            predictions = predictions[0]

        # Numerically stable softmax, applied unconditionally — the model is
        # expected to emit raw logits (already-normalized outputs would be
        # re-normalized here).
        exp_scores = np.exp(predictions - np.max(predictions))
        probabilities = exp_scores / np.sum(exp_scores)

        # int() so the index is a plain Python int (np.int64 is not
        # JSON-serializable if it leaks into the result downstream).
        predicted_class_idx = int(np.argmax(probabilities))
        predicted_class = self.config["labels"][predicted_class_idx]
        confidence = float(probabilities[predicted_class_idx])

        onnx_prediction = {
            label: float(prob) for label, prob in zip(self.config["labels"], probabilities)
        }

        # Threat decision is label-based, not threshold-based.
        is_threat = predicted_class == "unsafe"
        final_confidence = confidence

        raw_predictions = {
            "onnx": onnx_prediction,
            "sentiment": None  # No sentiment analysis in this model, but included for compatibility
        }

        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": predicted_class,
            "sentiment_analysis": raw_predictions.get("sentiment"),
            "onnx_prediction": raw_predictions.get("onnx"),
            "models_used": ["onnx"],
            "raw_predictions": raw_predictions
        }

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Run inference on multiple texts, one result dict per input."""
        return [self.predict(text) for text in texts]