# Hosting-page residue (not part of the program), kept for provenance:
#   uploader: parthraninga
#   commit:   "Upload 9 files" (2afd81c, verified)
import json
import os
import zlib
from typing import Any, Dict, List, Optional

import numpy as np
import onnxruntime as ort
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# --- Model configuration ---
# Relative path of the ONNX model file that load_model() opens.
MODEL_PATH = "contextClassifier.onnx"
# ONNX Runtime session shared by all handlers; None until load_model() sets it.
session = None

# --- FastAPI application ---
app = FastAPI(
    title="Content Classifier API",
    description="ONNX-based content classification for threat detection and sentiment analysis",
    version="1.0.0",
)
class TextInput(BaseModel):
    """Request body accepted by the /predict endpoint."""

    # Raw text to classify.
    text: str
    # Sequence length handed to preprocess_text(); defaults to 512.
    max_length: Optional[int] = 512
class PredictionResponse(BaseModel):
    """Response schema of the /predict endpoint."""

    # Final threat / no-threat decision.
    is_threat: bool
    # Confidence attached to the decision.
    final_confidence: float
    # Threat score extracted from the model output.
    threat_prediction: float
    # Sentiment label and score, when available.
    sentiment_analysis: Optional[Dict[str, Any]]
    # Details of the ONNX model output, when available.
    onnx_prediction: Optional[Dict[str, Any]]
    # Names of the models that contributed to the result.
    models_used: List[str]
    # All intermediate predictions, keyed by source.
    raw_predictions: Dict[str, Any]
def load_model():
    """Create the global ONNX Runtime session from ``MODEL_PATH``.

    On success the model's input and output names are printed. Any failure,
    including a missing model file, is printed and then re-raised.
    """
    global session
    try:
        model_missing = not os.path.exists(MODEL_PATH)
        if model_missing:
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

        session = ort.InferenceSession(MODEL_PATH)
        print(f"βœ… Model loaded successfully from {MODEL_PATH}")

        # Report the tensor names the loaded model exposes.
        input_names = [node.name for node in session.get_inputs()]
        output_names = [node.name for node in session.get_outputs()]
        print(f"πŸ“₯ Model inputs: {input_names}")
        print(f"πŸ“€ Model outputs: {output_names}")
    except Exception as err:
        print(f"❌ Error loading model: {err}")
        raise
def _stable_token_id(token: str) -> int:
    """Map a token to a bucket in [0, 30000) that is identical in every process."""
    # Builtin hash() on str is salted per interpreter process, so it would give
    # different ids after every restart and in every worker; CRC32 does not.
    return zlib.crc32(token.encode("utf-8")) % 30000


def preprocess_text(text: str, max_length: int = 512) -> Dict[str, np.ndarray]:
    """
    Turn raw text into the ``input_ids`` / ``attention_mask`` arrays for the model.

    NOTE: This is still placeholder preprocessing (whitespace tokenization and
    hashed token ids) - replace it with the tokenizer the model was trained with.

    Args:
        text: Raw input text; it is stripped and lower-cased before splitting.
        max_length: Exact sequence length to produce. ``None`` falls back to 512.

    Returns:
        Dict with ``input_ids`` and ``attention_mask``, each an ``int64`` array of
        shape ``(1, max_length)``. The mask is 1 for real tokens, 0 for padding.

    Raises:
        HTTPException: 400 when ``max_length`` is smaller than 1, 500 when
            preprocessing fails for any other reason.
    """
    try:
        # The request model allows an explicit null, so treat it as "use default".
        if max_length is None:
            max_length = 512
        # A non-positive length would slice from the wrong end and skip padding,
        # silently producing a tensor of the wrong size.
        # NOTE(review): no upper bound is enforced here - consider capping it to
        # the sequence length the model actually supports.
        if max_length < 1:
            raise HTTPException(status_code=400, detail="max_length must be a positive integer")

        # Basic normalisation followed by whitespace tokenization, truncated to fit.
        words = text.strip().lower().split()[:max_length]
        pad_count = max_length - len(words)

        # Hash real tokens individually; the padding id is computed only once.
        token_ids = [_stable_token_id(word) for word in words]
        token_ids.extend([_stable_token_id('[PAD]')] * pad_count)
        mask = [1] * len(words) + [0] * pad_count

        return {
            "input_ids": np.array(token_ids, dtype=np.int64).reshape(1, -1),
            "attention_mask": np.array(mask, dtype=np.int64).reshape(1, -1)
        }
    except HTTPException:
        # Keep the status code chosen above instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        print(f"❌ Preprocessing error: {e}")
        raise HTTPException(status_code=500, detail=f"Text preprocessing failed: {str(e)}")
def _threat_probability(scores: np.ndarray) -> float:
    """Pick the threat score out of the model output and map it into [0, 1]."""
    with np.errstate(over="ignore", invalid="ignore"):
        if scores.ndim == 2 and scores.shape[1] >= 2:
            # Binary classification layout: [non_threat, threat]
            row = scores[0]
            # Values outside [0, 1] cannot be probabilities, so they are treated
            # as logits and normalised with a numerically stable softmax.
            # NOTE(review): confirm against the exported model's output head.
            if np.any((row < 0.0) | (row > 1.0)):
                exps = np.exp(row - row.max())
                row = exps / exps.sum()
            value = float(row[1])
        else:
            # Single-column output and any other shape: use the first element.
            value = float(scores.flatten()[0])
            if value < 0.0 or value > 1.0:
                # Sigmoid written with tanh so large magnitudes cannot overflow.
                value = float(0.5 * (1.0 + np.tanh(0.5 * value)))
    if not np.isfinite(value):
        raise ValueError("Model produced a non-finite threat score")
    return value


def postprocess_predictions(outputs: List[np.ndarray]) -> Dict[str, Any]:
    """
    Process ONNX model outputs into the response format of the /predict endpoint.

    Args:
        outputs: Arrays returned by the inference session; only the first is used.

    Returns:
        Dict with the keys ``is_threat``, ``final_confidence``,
        ``threat_prediction``, ``sentiment_analysis``, ``onnx_prediction``,
        ``models_used`` and ``raw_predictions``. ``threat_prediction`` and
        ``final_confidence`` are always within [0, 1]. If the outputs cannot be
        interpreted (empty, non-numeric, NaN/inf), a neutral fallback response
        carrying the error message is returned instead of raising.
    """
    try:
        if not outputs:
            raise ValueError("No outputs received from model")

        # Get the main output (adjust index based on your model); coerce so that
        # list-like outputs are handled the same way as arrays.
        main_output = np.asarray(outputs[0])
        threat_prediction = _threat_probability(main_output.astype(np.float64))

        # Distance from the 0.5 decision boundary, scaled to 0-1.
        final_confidence = abs(threat_prediction - 0.5) * 2
        is_threat = threat_prediction > 0.5

        # Sentiment is derived as the inverse of the threat score.
        sentiment_score = (0.5 - threat_prediction) * 2
        predictions = {
            "onnx": {
                "threat_probability": threat_prediction,
                "raw_output": main_output.tolist(),
                "output_shape": main_output.shape
            },
            "sentiment": {
                "label": "POSITIVE" if sentiment_score > 0 else "NEGATIVE",
                "score": abs(sentiment_score)
            }
        }

        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": threat_prediction,
            "sentiment_analysis": predictions["sentiment"],
            "onnx_prediction": predictions["onnx"],
            "models_used": ["contextClassifier.onnx"],
            "raw_predictions": predictions
        }
    except Exception as e:
        print(f"❌ Postprocessing error: {e}")
        # Deliberate best-effort: return a safe, schema-valid fallback response.
        return {
            "is_threat": False,
            "final_confidence": 0.0,
            "threat_prediction": 0.0,
            "sentiment_analysis": {"label": "NEUTRAL", "score": 0.0},
            "onnx_prediction": {"error": str(e)},
            "models_used": ["contextClassifier.onnx"],
            "raw_predictions": {"error": str(e)}
        }
@app.on_event("startup")
async def startup_event():
    """Startup hook: run load_model() when the application starts."""
    load_model()
@app.get("/")
async def root():
    """Describe the service and list the routes it exposes."""
    # Route table advertised to clients.
    endpoints = {
        "predict": "/predict",
        "health": "/health",
        "model_info": "/model-info",
        "docs": "/docs"
    }
    info = {
        "message": "πŸ” Content Classifier API",
        "description": "ONNX-based content classification for threat detection and sentiment analysis",
        "model": MODEL_PATH,
        "status": "running",
        "endpoints": endpoints
    }
    return info
@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput):
    """
    Classify content for threat detection and sentiment analysis

    Returns the exact format:
    {
        "is_threat": bool,
        "final_confidence": float,
        "threat_prediction": float,
        "sentiment_analysis": dict,
        "onnx_prediction": dict,
        "models_used": list,
        "raw_predictions": dict
    }

    Raises:
        HTTPException: 400 for empty text or a non-positive ``max_length``;
            500 when the model is not loaded, no model input could be
            prepared, or inference fails.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please check server logs.")
    if not input_data.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    # max_length is Optional in the request model: an explicit null means
    # "use the default", and a non-positive value is a client error.
    max_length = 512 if input_data.max_length is None else input_data.max_length
    if max_length < 1:
        raise HTTPException(status_code=400, detail="max_length must be a positive integer")

    try:
        # Preprocess the input text
        model_inputs = preprocess_text(input_data.text, max_length)

        # Feed only the inputs the model declares; warn about any we cannot supply.
        ort_inputs = {}
        for model_input in session.get_inputs():
            name = model_input.name
            if name in model_inputs:
                ort_inputs[name] = model_inputs[name]
            else:
                print(f"⚠️ Warning: Expected input '{name}' not found in processed inputs")
        if not ort_inputs:
            raise HTTPException(status_code=500, detail="No valid inputs prepared for model")

        # Run inference, then process and return results
        outputs = session.run(None, ort_inputs)
        return postprocess_predictions(outputs)
    except HTTPException:
        # Already a deliberate HTTP error (from above or from preprocessing):
        # pass it through instead of re-wrapping it as a generic 500.
        raise
    except Exception as e:
        print(f"❌ Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@app.get("/health")
async def health_check():
    """Report whether the model session is loaded and whether the model file exists."""
    model_loaded = session is not None
    if model_loaded:
        status = "healthy"
    else:
        status = "unhealthy"
    return {
        "status": status,
        "model_loaded": model_loaded,
        "model_path": MODEL_PATH,
        "model_exists": os.path.exists(MODEL_PATH)
    }
@app.get("/model-info")
async def model_info():
    """Return the model's path, file size, input/output metadata and providers.

    Responds with HTTP 500 when the model is not loaded or when collecting
    the metadata fails.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    def describe(meta):
        # Name, type and shape of one input/output node; shape is None when empty.
        return {
            "name": meta.name,
            "type": str(meta.type),
            "shape": list(meta.shape) if meta.shape else None
        }

    try:
        input_specs = [describe(meta) for meta in session.get_inputs()]
        output_specs = [describe(meta) for meta in session.get_outputs()]
        size_in_mb = os.path.getsize(MODEL_PATH) / (1024*1024)
        runtime_info = {
            "providers": session.get_providers(),
            "device": "CPU"
        }
        return {
            "model_path": MODEL_PATH,
            "model_size": f"{size_in_mb:.2f} MB",
            "inputs": input_specs,
            "outputs": output_specs,
            "runtime_info": runtime_info
        }
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(err)}")
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")