|
|
import os
|
|
|
import numpy as np
|
|
|
import onnxruntime as ort
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from pydantic import BaseModel
|
|
|
import uvicorn
|
|
|
import json
|
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
|
|
# FastAPI application object; the route decorators in this file attach to it.
app = FastAPI(
    title="Content Classifier API",
    description="Content classification using ONNX model",
)

# Path handed to onnxruntime when the model is loaded.
MODEL_PATH = "contextClassifier.onnx"

# Module-wide inference session; remains None until load_model() assigns it.
session = None
|
|
|
|
|
|
class TextInput(BaseModel):
    """Request body carrying the text to classify."""

    # Raw text to run through the classifier.
    text: str

    # Number of token positions requested from preprocessing (default 512).
    max_length: Optional[int] = 512
|
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Response schema for a single classification result."""

    # Whether the threat score crossed the decision threshold.
    is_threat: bool

    # Confidence value derived from the threat score.
    final_confidence: float

    # Threat score taken from the model output.
    threat_prediction: float

    # Sentiment summary, or None when no prediction was produced.
    sentiment_analysis: Optional[Dict[str, Any]]

    # Details of the ONNX prediction, or None when no prediction was produced.
    onnx_prediction: Optional[Dict[str, Any]]

    # Names of the models that contributed to the result.
    models_used: List[str]

    # Per-model prediction details collected during postprocessing.
    raw_predictions: Dict[str, Any]
|
|
|
|
|
|
def load_model():
    """Create the ONNX Runtime session for ``MODEL_PATH``.

    The session is stored in the module-level ``session`` variable. After a
    successful load the model's input and output names are printed.

    Raises:
        Exception: Whatever ``ort.InferenceSession`` (or the metadata
            lookups) raised; the error is printed and then propagated.
    """
    global session

    try:
        session = ort.InferenceSession(MODEL_PATH)
        print(f"Model loaded successfully from {MODEL_PATH}")

        input_names = [node.name for node in session.get_inputs()]
        print(f"Model inputs: {input_names}")

        output_names = [node.name for node in session.get_outputs()]
        print(f"Model outputs: {output_names}")
    except Exception as exc:
        print(f"Error loading model: {exc}")
        # Re-raise the active exception so the caller sees the failure.
        raise
|
|
|
|
|
|
def preprocess_text(text: str, max_length: int = 512):
    """Turn raw text into fixed-length model inputs.

    This is a placeholder tokenizer - adjust it to match the vocabulary and
    tokenization the model was actually trained with. The text is lowercased
    and split on whitespace, truncated or padded with ``[PAD]`` to exactly
    ``max_length`` positions, and each token is mapped to an id in
    ``[0, 30000)``.

    Args:
        text: Text to tokenize.
        max_length: Number of token positions to produce. ``None`` is
            treated as the default of 512.

    Returns:
        A dict with two ``int64`` arrays of shape ``(1, max_length)``:
        ``"input_ids"`` (token ids) and ``"attention_mask"`` (1 for real
        tokens, 0 for padding).

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    # Local import keeps this function self-contained; zlib is stdlib.
    import zlib

    if max_length is None:
        max_length = 512
    if max_length < 1:
        raise ValueError(f"max_length must be at least 1, got {max_length}")

    def token_id(token):
        # The built-in hash() is salted per process for str, which would make
        # ids differ between server restarts; CRC32 is stable everywhere.
        # "surrogatepass" keeps lone surrogates from raising on encode.
        return zlib.crc32(token.encode("utf-8", "surrogatepass")) % 30000

    tokens = text.lower().split()[:max_length]
    real_count = len(tokens)
    pad_count = max_length - real_count

    ids = [token_id(token) for token in tokens]
    ids.extend([token_id("[PAD]")] * pad_count)
    input_ids = np.array(ids, dtype=np.int64).reshape(1, -1)

    # Real tokens are lowercased, so none can equal "[PAD]"; the mask is
    # therefore exactly "first real_count positions are 1".
    mask = [1] * real_count + [0] * pad_count
    attention_mask = np.array(mask, dtype=np.int64).reshape(1, -1)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
|
|
|
|
|
|
def postprocess_predictions(outputs, predictions_dict):
    """Convert raw model outputs into the response payload.

    Only the first entry of ``outputs`` is used. Its first row is read as
    per-class scores and the value at index 1 is taken as the threat score;
    when the row has fewer than two values a neutral 0.5 is used instead.
    Adjust this to the model's actual outputs.

    NOTE(review): the arithmetic below assumes the score is a probability in
    [0, 1] rather than a logit - confirm against the exported model.

    Args:
        outputs: Sequence of model outputs. The first entry may be an array
            or a nested list; scalar and 1-D entries are accepted as well.
        predictions_dict: Dict that is updated in place with the ``"onnx"``
            and ``"sentiment"`` entries and returned under
            ``"raw_predictions"``.

    Returns:
        A dict with the keys ``is_threat``, ``final_confidence``,
        ``threat_prediction``, ``sentiment_analysis``, ``onnx_prediction``,
        ``models_used`` and ``raw_predictions``. When ``outputs`` is empty
        (or its first entry holds no values) a neutral result is returned
        and ``predictions_dict`` is left untouched.
    """
    no_prediction = {
        "is_threat": False,
        "final_confidence": 0.0,
        "threat_prediction": 0.0,
        "sentiment_analysis": None,
        "onnx_prediction": None,
        "models_used": [],
        "raw_predictions": predictions_dict,
    }

    if outputs is None or len(outputs) == 0:
        return no_prediction

    # asarray accepts ndarrays and plain (nested) lists alike, so .tolist()
    # below is always available.
    raw_output = np.asarray(outputs[0])
    if raw_output.size == 0:
        # Empty batch: there is no first row to read a score from.
        return no_prediction

    # atleast_2d lets scalar and 1-D outputs be read the same way as the
    # usual (batch, classes) layout instead of failing on len()/indexing.
    class_scores = np.atleast_2d(raw_output)[0]

    threat_prediction = float(class_scores[1]) if len(class_scores) > 1 else 0.5
    # Distance from the 0.5 decision boundary, rescaled to 0..1.
    final_confidence = abs(threat_prediction - 0.5) * 2
    is_threat = threat_prediction > 0.5

    predictions_dict.update({
        "onnx": {
            "threat_probability": threat_prediction,
            "raw_output": raw_output.tolist(),
        }
    })

    # Sentiment is derived from the threat score: scores above 0.5 map to
    # NEGATIVE, everything else to POSITIVE.
    sentiment_score = (threat_prediction - 0.5) * -2
    predictions_dict["sentiment"] = {
        "label": "NEGATIVE" if sentiment_score < 0 else "POSITIVE",
        "score": abs(sentiment_score),
    }

    models_used = ["contextClassifier.onnx"]

    return {
        "is_threat": is_threat,
        "final_confidence": final_confidence,
        "threat_prediction": threat_prediction,
        "sentiment_analysis": predictions_dict.get("sentiment"),
        "onnx_prediction": predictions_dict.get("onnx"),
        "models_used": models_used,
        "raw_predictions": predictions_dict,
    }
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Run ``load_model()`` when the application's startup event fires."""
    load_model()
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return a short status banner naming the configured model file."""
    banner = {"message": "Content Classifier API is running"}
    banner["model"] = MODEL_PATH
    return banner
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput):
    """Classify the submitted text with the loaded ONNX model.

    The text is preprocessed, the arrays whose names match the model's
    declared inputs are fed to the session, and the outputs are converted
    into the response payload.

    Raises:
        HTTPException: Status 500 when no model session is loaded, or when
            any step of preprocessing, inference or postprocessing fails.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        features = preprocess_text(input_data.text, input_data.max_length)

        expected_names = [node.name for node in session.get_inputs()]

        # Feed only the inputs the model declares and preprocessing produced.
        feed = {name: features[name] for name in expected_names if name in features}

        # Declared inputs that preprocessing did not produce are only
        # reported; inference is still attempted with what is available.
        for name in expected_names:
            if name not in features:
                print(f"Warning: Expected input '{name}' not found in processed inputs")

        outputs = session.run(None, feed)

        return postprocess_predictions(outputs, {})

    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service status and whether a model session exists."""
    model_loaded = session is not None
    return {"status": "healthy", "model_loaded": model_loaded, "model_path": MODEL_PATH}
|
|
|
|
|
|
@app.get("/model-info")
async def model_info():
    """Describe the loaded model's inputs and outputs.

    Returns:
        A dict with ``model_path`` plus ``inputs`` and ``outputs``, each a
        list of ``{"name", "type", "shape"}`` entries taken from the
        session metadata.

    Raises:
        HTTPException: Status 500 when no model session is loaded.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    def describe(node_meta):
        # One metadata entry, as exposed in the response.
        return {
            "name": node_meta.name,
            "type": str(node_meta.type),
            "shape": node_meta.shape,
        }

    input_entries = [describe(meta) for meta in session.get_inputs()]
    output_entries = [describe(meta) for meta in session.get_outputs()]

    return {
        "model_path": MODEL_PATH,
        "inputs": input_entries,
        "outputs": output_entries,
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on 0.0.0.0:7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
|
|
|
|