# content-classifier / inference.py
# Uploaded by parthraninga via huggingface_hub (commit 3d76333, verified)
import onnxruntime as ort
import numpy as np
from typing import Dict, List, Any
import json
from pathlib import Path
import os
class ContentClassifierInference:
    """ONNX-backed binary content classifier (``safe`` / ``unsafe``).

    Wraps an ``onnxruntime.InferenceSession`` plus a small JSON config
    (``labels``, ``max_length``, ``threshold``). The preprocessing is a
    byte-level placeholder (each UTF-8 byte scaled to [0, 1]) — replace it
    with real tokenization to match your actual model.
    """

    def __init__(self, model_path: str = "contentclassifier.onnx", config_path: str = "config.json"):
        """Load the ONNX session and the JSON config.

        Args:
            model_path: Path to the ONNX model file.
            config_path: Path to a JSON config providing ``labels``,
                ``max_length`` and ``threshold``; missing keys are filled
                from :meth:`_default_config`.

        Raises:
            FileNotFoundError: Model file is missing and the dummy-model
                helper module is unavailable.
            RuntimeError: The ONNX session failed to initialize.
            ValueError: The model exposes no inputs or outputs.
        """
        self.model_path = model_path
        self.config_path = config_path

        # Fall back to a locally generated dummy model so the class can be
        # smoke-tested without shipping real weights.
        if not os.path.exists(model_path):
            print(f"Warning: Model file {model_path} not found!")
            print("Creating a dummy model for testing...")
            try:
                from create_dummy_model import create_dummy_onnx_model
                create_dummy_onnx_model(model_path)
            except ImportError:
                raise FileNotFoundError(f"Model file {model_path} not found and couldn't create dummy model")

        # Load ONNX model
        try:
            self.session = ort.InferenceSession(model_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

        # Load configuration
        if Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
            # Robustness fix: fill in any keys a partial user config omits,
            # so later self.config[...] lookups cannot raise KeyError.
            for key, value in self._default_config().items():
                self.config.setdefault(key, value)
        else:
            print(f"Warning: Config file {config_path} not found, using default config")
            self.config = self._default_config()

        # Cache the first input/output tensor names for session.run().
        try:
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
        except IndexError:
            raise ValueError("Model doesn't have expected inputs/outputs")

    def _default_config(self) -> Dict:
        """Return the built-in fallback configuration."""
        return {
            "labels": ["safe", "unsafe"],
            "max_length": 512,
            "threshold": 0.5
        }

    def preprocess(self, text: str) -> np.ndarray:
        """Encode ``text`` into a ``(1, max_length)`` float32 array.

        Placeholder preprocessing: UTF-8 bytes scaled to [0, 1], truncated
        or zero-padded to ``config["max_length"]``. Adjust to your model's
        real input requirements (tokenization, attention masks, ...).

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"Input must be string, got {type(text)}")

        max_length = self.config["max_length"]
        encoded = text.encode('utf-8')[:max_length]

        input_array = np.zeros(max_length, dtype=np.float32)
        if encoded:
            # Vectorized replacement for the per-byte Python loop.
            input_array[:len(encoded)] = np.frombuffer(encoded, dtype=np.uint8).astype(np.float32) / 255.0

        # Warn only when a *concrete* model dim disagrees with what we
        # produce. ONNX dynamic axes appear as strings (any name, not just
        # 'batch') or None, and must not trigger spurious warnings.
        expected_shape = self.session.get_inputs()[0].shape
        input_shape = [1, max_length]
        mismatch = len(expected_shape) != len(input_shape) or any(
            isinstance(exp, int) and exp >= 0 and exp != got
            for exp, got in zip(expected_shape, input_shape)
        )
        if mismatch:
            print(f"Warning: Model expects input shape {expected_shape}, but preprocessing produces {input_shape}")

        return input_array.reshape(1, -1)

    def predict(self, text: str) -> Dict[str, Any]:
        """Run inference on ``text`` and return a structured result dict.

        Returns a dict with keys ``is_threat``, ``final_confidence``,
        ``threat_prediction``, ``sentiment_analysis`` (always None here),
        ``onnx_prediction``, ``models_used`` and ``raw_predictions``.
        """
        input_data = self.preprocess(text)

        outputs = self.session.run([self.output_name], {self.input_name: input_data})
        predictions = outputs[0]

        # Strip the batch dimension if the model kept it.
        if len(predictions.shape) > 1:
            predictions = predictions[0]

        # Numerically stable softmax; idempotent on already-normalized
        # scores only ordering-wise, so raw logits are assumed.
        exp_scores = np.exp(predictions - np.max(predictions))
        probabilities = exp_scores / np.sum(exp_scores)

        labels = self.config["labels"]
        predicted_class_idx = int(np.argmax(probabilities))  # plain int, not np.int64
        predicted_class = labels[predicted_class_idx]
        confidence = float(probabilities[predicted_class_idx])

        onnx_prediction = {
            label: float(prob) for label, prob in zip(labels, probabilities)
        }

        # Fix: honour the configured decision threshold (previously loaded
        # but never used). With the default binary labels and threshold 0.5
        # this is behavior-identical, since the argmax softmax probability
        # of two classes is always >= 0.5.
        threshold = float(self.config.get("threshold", 0.5))
        is_threat = predicted_class == "unsafe" and confidence >= threshold
        final_confidence = confidence

        raw_predictions = {
            "onnx": onnx_prediction,
            "sentiment": None  # No sentiment analysis in this model, but included for compatibility
        }

        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": predicted_class,
            "sentiment_analysis": raw_predictions.get("sentiment"),
            "onnx_prediction": raw_predictions.get("onnx"),
            "models_used": ["onnx"],
            "raw_predictions": raw_predictions
        }

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Run :meth:`predict` over each text; returns one result per input."""
        return [self.predict(text) for text in texts]