parthraninga committed on
Commit
3d76333
·
verified ·
1 Parent(s): c28eb35

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +127 -0
inference.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from typing import Dict, List, Any
4
+ import json
5
+ from pathlib import Path
6
+ import os
7
+
8
class ContentClassifierInference:
    """ONNX-based content safety classifier.

    Loads an ONNX model plus a JSON config (``labels``, ``max_length``,
    ``threshold``) and exposes single-text (:meth:`predict`) and batch
    (:meth:`predict_batch`) inference returning a structured result dict.
    """

    def __init__(self, model_path: str = "contentclassifier.onnx", config_path: str = "config.json"):
        """Load the ONNX session and the classifier configuration.

        Args:
            model_path: Path to the ONNX model file.
            config_path: Path to a JSON config with "labels", "max_length"
                and "threshold" keys; defaults are used if the file is absent.

        Raises:
            FileNotFoundError: Model file missing and no dummy model could
                be generated.
            RuntimeError: The ONNX session failed to initialize.
            ValueError: The model exposes no inputs or outputs.
        """
        self.model_path = model_path
        self.config_path = config_path

        # Fall back to a generated dummy model so local testing works
        # without the real artifact.
        if not os.path.exists(model_path):
            print(f"Warning: Model file {model_path} not found!")
            print("Creating a dummy model for testing...")
            try:
                from create_dummy_model import create_dummy_onnx_model
                create_dummy_onnx_model(model_path)
            except ImportError:
                raise FileNotFoundError(f"Model file {model_path} not found and couldn't create dummy model")

        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

        # Load configuration, falling back to built-in defaults.
        if Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            print(f"Warning: Config file {config_path} not found, using default config")
            self.config = self._default_config()

        # Cache the tensor names used by session.run().
        try:
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
        except IndexError:
            raise ValueError("Model doesn't have expected inputs/outputs")

    def _default_config(self) -> Dict:
        """Return the default configuration used when config.json is absent."""
        return {
            "labels": ["safe", "unsafe"],
            "max_length": 512,
            "threshold": 0.5
        }

    def _encode_text(self, text: str) -> np.ndarray:
        """Encode *text* as a fixed-length float32 vector of scaled UTF-8 bytes.

        Each byte is normalized to [0, 1]; the vector is zero-padded (or the
        byte string truncated) to exactly ``config["max_length"]`` entries.
        """
        max_length = self.config["max_length"]
        encoded = text.encode('utf-8')[:max_length]
        input_array = np.zeros(max_length, dtype=np.float32)
        if encoded:
            # Vectorized replacement for the original per-byte Python loop;
            # the slice above already enforces the length bound.
            input_array[:len(encoded)] = (
                np.frombuffer(encoded, dtype=np.uint8).astype(np.float32) / 255.0
            )
        return input_array

    def preprocess(self, text: str) -> np.ndarray:
        """Preprocess text input for the model.

        This is a placeholder byte-level encoding — replace with the model's
        real tokenization if required. Returns a (1, max_length) float32 array.

        Raises:
            TypeError: If *text* is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"Input must be string, got {type(text)}")

        input_array = self._encode_text(text)

        # Warn only when a *static* dimension disagrees with what we produce.
        # Non-int dims ('batch', 'batch_size', None, ...) are dynamic and
        # always compatible — the old check hard-coded the name 'batch' and
        # warned spuriously for any other dynamic-axis label.
        expected_shape = self.session.get_inputs()[0].shape
        input_shape = [1, self.config["max_length"]]
        mismatch = (
            len(expected_shape) != len(input_shape)
            or any(
                isinstance(exp, int) and exp != got
                for exp, got in zip(expected_shape, input_shape)
            )
        )
        if mismatch:
            print(f"Warning: Model expects input shape {expected_shape}, but preprocessing produces {input_shape}")

        return input_array.reshape(1, -1)

    def predict(self, text: str) -> Dict[str, Any]:
        """Run inference on input text and return a structured result.

        Returns:
            Dict with keys ``is_threat``, ``final_confidence``,
            ``threat_prediction``, ``sentiment_analysis`` (always None —
            kept for compatibility), ``onnx_prediction``, ``models_used``
            and ``raw_predictions``.

        Raises:
            ValueError: If the model emits more classes than the config
                defines labels for.
        """
        input_data = self.preprocess(text)

        # Run inference
        outputs = self.session.run([self.output_name], {self.input_name: input_data})
        predictions = outputs[0]

        # Drop the batch dimension if present.
        if len(predictions.shape) > 1:
            predictions = predictions[0]

        # Softmax, shifted by the max for numerical stability.
        exp_scores = np.exp(predictions - np.max(predictions))
        probabilities = exp_scores / np.sum(exp_scores)

        # int() so np.int64 never leaks into downstream (e.g. JSON) handling.
        predicted_class_idx = int(np.argmax(probabilities))
        labels = self.config["labels"]
        if predicted_class_idx >= len(labels):
            # Previously this surfaced as a bare IndexError (or the label/prob
            # zip silently truncated); fail with an explicit diagnosis instead.
            raise ValueError(
                f"Model produced {len(probabilities)} classes but config defines {len(labels)} labels"
            )
        predicted_class = labels[predicted_class_idx]
        confidence = float(probabilities[predicted_class_idx])

        # Per-label probability map for the ONNX head.
        onnx_prediction = {
            label: float(prob) for label, prob in zip(labels, probabilities)
        }

        # Determine if content is a threat based on the predicted label.
        is_threat = predicted_class == "unsafe"
        final_confidence = confidence

        raw_predictions = {
            "onnx": onnx_prediction,
            "sentiment": None  # No sentiment analysis in this model, but included for compatibility
        }

        # Return the expected structure
        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": predicted_class,
            "sentiment_analysis": raw_predictions.get("sentiment"),
            "onnx_prediction": raw_predictions.get("onnx"),
            "models_used": ["onnx"],
            "raw_predictions": raw_predictions
        }

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Run inference on multiple texts, returning one result dict per input."""
        return [self.predict(text) for text in texts]