import hashlib
import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
import onnxruntime as ort
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI(
    title="Content Classifier API",
    description="ONNX-based content classification for threat detection and sentiment analysis",
    version="1.0.0",
)

# Model configuration
MODEL_PATH = "contextClassifier.onnx"
# Populated by load_model() at startup; None means "not loaded".
session = None


class TextInput(BaseModel):
    """Request body for POST /predict."""

    text: str
    # Token budget for preprocessing; None is treated as the default (512).
    max_length: Optional[int] = 512


class PredictionResponse(BaseModel):
    """Response schema for POST /predict (the exact contract documented there)."""

    is_threat: bool
    final_confidence: float
    threat_prediction: float
    sentiment_analysis: Optional[Dict[str, Any]]
    onnx_prediction: Optional[Dict[str, Any]]
    models_used: List[str]
    raw_predictions: Dict[str, Any]


def load_model():
    """Load the ONNX model from MODEL_PATH into the module-level `session`.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist on disk.
        Exception: anything onnxruntime raises on a corrupt/incompatible model
            (re-raised after logging so startup fails loudly).
    """
    global session
    try:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

        session = ort.InferenceSession(MODEL_PATH)
        print(f"✅ Model loaded successfully from {MODEL_PATH}")

        # Print model info (avoid shadowing the builtins `input`/`output`).
        input_names = [meta.name for meta in session.get_inputs()]
        output_names = [meta.name for meta in session.get_outputs()]
        print(f"📥 Model inputs: {input_names}")
        print(f"📤 Model outputs: {output_names}")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise e


def _stable_token_id(token: str, vocab_size: int = 30000) -> int:
    """Map a token to a deterministic pseudo-id in [0, vocab_size).

    The original code used the builtin ``hash()``, which is salted per process
    (PYTHONHASHSEED) for strings — the same text would produce different input
    ids, and therefore different predictions, across server restarts. MD5 is
    stable across processes and platforms.
    """
    digest = hashlib.md5(token.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % vocab_size


def preprocess_text(text: str, max_length: int = 512) -> Dict[str, np.ndarray]:
    """Preprocess text for the ONNX model.

    NOTE: Adjust this function based on your model's specific requirements —
    the tokenization below is a placeholder (whitespace split + hashed ids),
    not a real tokenizer.

    Args:
        text: raw input text.
        max_length: fixed sequence length to pad/truncate to; ``None`` falls
            back to 512 (a client can legitimately send ``null``).

    Returns:
        dict with ``input_ids`` and ``attention_mask``, each shaped
        (1, max_length) with dtype int64.

    Raises:
        HTTPException(500): if preprocessing fails for any reason.
    """
    try:
        # Guard against TextInput.max_length being explicitly null.
        if max_length is None:
            max_length = 512

        # Basic text preprocessing (customize based on your model)
        text = text.strip().lower()

        # Simple tokenization (replace with actual tokenizer if needed)
        tokens = text.split()[:max_length]

        # Pad to fixed length so the model always sees the same shape.
        if len(tokens) < max_length:
            tokens.extend(['[PAD]'] * (max_length - len(tokens)))

        # Convert to numerical representation.
        # NOTE: placeholder — replace with your actual preprocessing.
        input_ids = np.array(
            [_stable_token_id(token) for token in tokens], dtype=np.int64
        ).reshape(1, -1)
        attention_mask = np.array(
            [1 if token != '[PAD]' else 0 for token in tokens], dtype=np.int64
        ).reshape(1, -1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
    except Exception as e:
        print(f"❌ Preprocessing error: {e}")
        raise HTTPException(status_code=500, detail=f"Text preprocessing failed: {str(e)}")


def postprocess_predictions(outputs: List[np.ndarray]) -> Dict[str, Any]:
    """Process ONNX model outputs into the required response format.

    Never raises: on any failure it returns a safe "non-threat" fallback
    payload carrying the error text, so a malformed model output does not
    turn into a 500 for the caller.
    """
    try:
        predictions: Dict[str, Any] = {}

        if not outputs:
            raise ValueError("No outputs received from model")

        # Get the main output (adjust index based on your model).
        main_output = outputs[0]

        # Extract threat prediction from whichever shape the model emits.
        if len(main_output.shape) == 2 and main_output.shape[1] >= 2:
            # Binary classification: [non_threat_prob, threat_prob]
            threat_prediction = float(main_output[0][1])
        elif len(main_output.shape) == 2 and main_output.shape[1] == 1:
            # Single output probability
            threat_prediction = float(main_output[0][0])
        else:
            # Fallback for other output shapes
            threat_prediction = float(main_output.flatten()[0])

        # Distance from the 0.5 decision boundary, rescaled to 0-1.
        final_confidence = abs(threat_prediction - 0.5) * 2
        is_threat = threat_prediction > 0.5

        # Store ONNX predictions (shape as a list so the payload is pure JSON).
        predictions["onnx"] = {
            "threat_probability": threat_prediction,
            "raw_output": main_output.tolist(),
            "output_shape": list(main_output.shape),
        }

        # Generate sentiment analysis (inverse relationship with threat).
        sentiment_score = (0.5 - threat_prediction) * 2
        predictions["sentiment"] = {
            "label": "POSITIVE" if sentiment_score > 0 else "NEGATIVE",
            "score": abs(sentiment_score),
        }

        models_used = ["contextClassifier.onnx"]

        # Return the exact format specified.
        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": threat_prediction,
            "sentiment_analysis": predictions.get("sentiment"),
            "onnx_prediction": predictions.get("onnx"),
            "models_used": models_used,
            "raw_predictions": predictions,
        }
    except Exception as e:
        print(f"❌ Postprocessing error: {e}")
        # Return safe fallback response
        return {
            "is_threat": False,
            "final_confidence": 0.0,
            "threat_prediction": 0.0,
            "sentiment_analysis": {"label": "NEUTRAL", "score": 0.0},
            "onnx_prediction": {"error": str(e)},
            "models_used": ["contextClassifier.onnx"],
            "raw_predictions": {"error": str(e)},
        }


@app.on_event("startup")
async def startup_event():
    """Load model on application startup."""
    load_model()


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "🔍 Content Classifier API",
        "description": "ONNX-based content classification for threat detection and sentiment analysis",
        "model": MODEL_PATH,
        "status": "running",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "model_info": "/model-info",
            "docs": "/docs",
        },
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput):
    """Classify content for threat detection and sentiment analysis.

    Returns the exact format:
    {
        "is_threat": bool,
        "final_confidence": float,
        "threat_prediction": float,
        "sentiment_analysis": dict,
        "onnx_prediction": dict,
        "models_used": list,
        "raw_predictions": dict
    }
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please check server logs.")

    if not input_data.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        # Preprocess the input text
        model_inputs = preprocess_text(input_data.text, input_data.max_length)

        # Feed only the inputs the loaded model actually declares.
        ort_inputs = {}
        for name in (meta.name for meta in session.get_inputs()):
            if name in model_inputs:
                ort_inputs[name] = model_inputs[name]
            else:
                print(f"⚠️ Warning: Expected input '{name}' not found in processed inputs")

        if not ort_inputs:
            raise HTTPException(status_code=500, detail="No valid inputs prepared for model")

        # Run inference
        outputs = session.run(None, ort_inputs)

        # Process and return results
        return postprocess_predictions(outputs)
    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged; the broad handler below
        # would otherwise rewrap them (losing e.g. a 400 status) as generic 500s.
        raise
    except Exception as e:
        print(f"❌ Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy" if session is not None else "unhealthy",
        "model_loaded": session is not None,
        "model_path": MODEL_PATH,
        "model_exists": os.path.exists(MODEL_PATH),
    }


@app.get("/model-info")
async def model_info():
    """Get detailed model information (inputs, outputs, size, providers)."""
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        inputs = []
        for input_meta in session.get_inputs():
            inputs.append({
                "name": input_meta.name,
                "type": str(input_meta.type),
                "shape": list(input_meta.shape) if input_meta.shape else None,
            })

        outputs = []
        for output_meta in session.get_outputs():
            outputs.append({
                "name": output_meta.name,
                "type": str(output_meta.type),
                "shape": list(output_meta.shape) if output_meta.shape else None,
            })

        return {
            "model_path": MODEL_PATH,
            "model_size": f"{os.path.getsize(MODEL_PATH) / (1024*1024):.2f} MB",
            "inputs": inputs,
            "outputs": outputs,
            "runtime_info": {
                "providers": session.get_providers(),
                "device": "CPU",
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)