# Spaces: Sleeping  (appears to be hosting-page status text captured with the file; not program code)
| import os | |
| import numpy as np | |
| import onnxruntime as ort | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import json | |
| from typing import Dict, Any, List, Optional | |
# Model configuration
# File name of the ONNX model; used as given by load_model() and the handlers.
MODEL_PATH = "contextClassifier.onnx"
# Shared ONNX Runtime session. It stays None until load_model() assigns it.
session = None

# Application object handed to uvicorn.run() in the __main__ guard.
app = FastAPI(
    version="1.0.0",
    title="Content Classifier API",
    description="ONNX-based content classification for threat detection and sentiment analysis",
)
class TextInput(BaseModel):
    """Request payload read by the predict handler."""

    # Raw text to classify.
    text: str
    # Token budget passed on to preprocess_text(); 512 when omitted.
    max_length: Optional[int] = 512
class PredictionResponse(BaseModel):
    """Schema mirroring the dictionary built by postprocess_predictions()."""

    # Whether the threat score is above 0.5.
    is_threat: bool
    # Distance of the threat score from 0.5, rescaled by the postprocessor.
    final_confidence: float
    # Threat score read from the model's first output.
    threat_prediction: float
    # Derived sentiment entry with "label" and "score" keys.
    sentiment_analysis: Optional[Dict[str, Any]]
    # ONNX details (or an "error" entry when postprocessing failed).
    onnx_prediction: Optional[Dict[str, Any]]
    # Names of the models that contributed to the result.
    models_used: List[str]
    # All intermediate prediction entries, keyed by source.
    raw_predictions: Dict[str, Any]
def load_model():
    """Create the ONNX Runtime session for ``MODEL_PATH`` and store it globally.

    On success the module-level ``session`` is set and the model's input and
    output names are printed.

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
        Exception: Any other error raised while creating or inspecting the
            session is printed and re-raised unchanged.
    """
    global session
    try:
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
        loaded = ort.InferenceSession(MODEL_PATH)
        print(f"β Model loaded successfully from {MODEL_PATH}")
        # Print model info
        input_names = [node.name for node in loaded.get_inputs()]
        output_names = [node.name for node in loaded.get_outputs()]
        print(f"π₯ Model inputs: {input_names}")
        print(f"π€ Model outputs: {output_names}")
        # Publish the session only after it was created and inspected, so a
        # failure above never leaves a half-checked session behind.
        session = loaded
    except Exception as e:
        print(f"β Error loading model: {e}")
        # Bare raise re-raises the active exception with its traceback intact.
        raise
def preprocess_text(text: str, max_length: int = 512) -> Dict[str, np.ndarray]:
    """Turn raw text into fixed-length ``input_ids`` / ``attention_mask`` arrays.

    NOTE: This is placeholder preprocessing (whitespace tokenization plus a
    hashed vocabulary of 30000 buckets). Replace it with the tokenizer the
    model was trained with.

    Args:
        text: Text to encode; it is stripped and lower-cased first.
        max_length: Number of token slots to emit. ``None`` falls back to 512.

    Returns:
        A dict with two ``int64`` arrays of shape ``(1, max_length)``:
        ``"input_ids"`` (hashed token ids, padded with the id of ``'[PAD]'``)
        and ``"attention_mask"`` (1 for real tokens, 0 for padding).

    Raises:
        HTTPException: Status 500 if preprocessing fails, including when
            ``max_length`` is smaller than 1.
    """
    # Function-scope import keeps this block self-contained.
    import zlib

    def stable_id(token: str) -> int:
        # The builtin hash() of a str is salted per interpreter process, so
        # it would map the same token to different ids after every restart.
        # CRC32 of the UTF-8 bytes is stable across runs and machines.
        return zlib.crc32(token.encode("utf-8", "surrogatepass")) % 30000

    try:
        if max_length is None:
            # The request schema allows null; use the documented default.
            max_length = 512
        if max_length < 1:
            raise ValueError(f"max_length must be a positive integer, got {max_length}")

        # Basic text preprocessing (customize based on your model)
        tokens = text.strip().lower().split()[:max_length]
        pad_count = max_length - len(tokens)

        # Convert to numerical representation, padding up to max_length.
        ids = [stable_id(token) for token in tokens]
        ids.extend([stable_id('[PAD]')] * pad_count)
        # Real tokens are lower-cased and can never equal '[PAD]', so the mask
        # is fully determined by position.
        mask = [1] * len(tokens) + [0] * pad_count

        return {
            "input_ids": np.array(ids, dtype=np.int64).reshape(1, -1),
            "attention_mask": np.array(mask, dtype=np.int64).reshape(1, -1)
        }
    except Exception as e:
        print(f"β Preprocessing error: {e}")
        raise HTTPException(status_code=500, detail=f"Text preprocessing failed: {str(e)}")
def postprocess_predictions(outputs: List[np.ndarray]) -> Dict[str, Any]:
    """Convert raw ONNX outputs into the API response dictionary.

    Only the first output is used. For a 2-D output with at least two
    columns, the threat score is row 0, column 1; for every other shape it is
    the first element of the flattened output.

    Args:
        outputs: Outputs returned by the ONNX Runtime session.

    Returns:
        A dict with the keys ``is_threat``, ``final_confidence``,
        ``threat_prediction``, ``sentiment_analysis``, ``onnx_prediction``,
        ``models_used`` and ``raw_predictions``. ``final_confidence`` and the
        sentiment ``score`` are always within ``[0, 1]`` for finite scores.

        If anything fails, the error is printed and a fallback dict is
        returned instead of raising. NOTE(review): that fallback reports
        ``is_threat=False``, i.e. it fails open; callers should inspect the
        ``"error"`` entries.
    """
    try:
        if not outputs:
            raise ValueError("No outputs received from model")

        # Get the main output (adjust index based on your model). asarray is
        # a no-op for ndarrays and lets plain nested lists through as well.
        main_output = np.asarray(outputs[0])

        # Extract threat prediction
        if main_output.ndim == 2 and main_output.shape[1] >= 2:
            # Binary classification: [non_threat_prob, threat_prob]
            threat_prediction = float(main_output[0][1])
        else:
            # Single-column output or any other shape: take the first value.
            threat_prediction = float(main_output.flatten()[0])

        is_threat = threat_prediction > 0.5

        # Distance from the 0.5 decision boundary, scaled to 0-1. The clamp
        # keeps the value in range when the model emits scores outside
        # [0, 1] (for example raw logits).
        final_confidence = min(abs(threat_prediction - 0.5) * 2, 1.0)

        # Sentiment is modelled as the inverse of the threat score.
        sentiment_score = (0.5 - threat_prediction) * 2
        predictions = {
            "onnx": {
                "threat_probability": threat_prediction,
                "raw_output": main_output.tolist(),
                "output_shape": main_output.shape
            },
            "sentiment": {
                "label": "POSITIVE" if sentiment_score > 0 else "NEGATIVE",
                "score": min(abs(sentiment_score), 1.0)
            }
        }

        # Return the exact format specified
        return {
            "is_threat": is_threat,
            "final_confidence": final_confidence,
            "threat_prediction": threat_prediction,
            "sentiment_analysis": predictions["sentiment"],
            "onnx_prediction": predictions["onnx"],
            "models_used": ["contextClassifier.onnx"],
            "raw_predictions": predictions
        }
    except Exception as e:
        print(f"β Postprocessing error: {e}")
        # Return safe fallback response
        return {
            "is_threat": False,
            "final_confidence": 0.0,
            "threat_prediction": 0.0,
            "sentiment_analysis": {"label": "NEUTRAL", "score": 0.0},
            "onnx_prediction": {"error": str(e)},
            "models_used": ["contextClassifier.onnx"],
            "raw_predictions": {"error": str(e)}
        }
# Registered as a startup hook; without the registration nothing ever calls
# load_model() and every prediction request would see session as None.
@app.on_event("startup")
async def startup_event():
    """Load the ONNX model when the application starts.

    Raises:
        Exception: Whatever load_model() raises is propagated.
    """
    load_model()
# Route registration: the handler is unreachable unless it is attached to app.
@app.get("/")
async def root():
    """Return static API information.

    Returns:
        A dict with the API message and description, the configured model
        path, a fixed "running" status and a map of endpoint paths.
    """
    endpoints = {
        "predict": "/predict",
        "health": "/health",
        "model_info": "/model-info",
        "docs": "/docs"
    }
    return {
        "message": "π Content Classifier API",
        "description": "ONNX-based content classification for threat detection and sentiment analysis",
        "model": MODEL_PATH,
        "status": "running",
        "endpoints": endpoints
    }
# Route registration: the path matches the "predict" entry advertised by
# root(), and the response is validated against PredictionResponse.
@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput):
    """
    Classify content for threat detection and sentiment analysis

    Returns the exact format:
    {
        "is_threat": bool,
        "final_confidence": float,
        "threat_prediction": float,
        "sentiment_analysis": dict,
        "onnx_prediction": dict,
        "models_used": list,
        "raw_predictions": dict
    }

    Raises:
        HTTPException: 500 if the model is not loaded, no model input could
            be prepared, or inference fails; 400 if the text is blank.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please check server logs.")
    if not input_data.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    # The schema allows a null max_length; fall back to its default.
    max_length = 512 if input_data.max_length is None else input_data.max_length

    try:
        # Preprocess the input text
        model_inputs = preprocess_text(input_data.text, max_length)

        # Feed only the inputs the model declares and we actually produced.
        ort_inputs = {}
        for node in session.get_inputs():
            name = node.name
            if name in model_inputs:
                ort_inputs[name] = model_inputs[name]
            else:
                print(f"β οΈ Warning: Expected input '{name}' not found in processed inputs")
        if not ort_inputs:
            raise HTTPException(status_code=500, detail="No valid inputs prepared for model")

        # Run inference
        outputs = session.run(None, ort_inputs)

        # Process and return results
        return postprocess_predictions(outputs)
    except HTTPException:
        # Already carries the intended status and detail; re-wrapping it in
        # the generic handler below would garble the message.
        raise
    except Exception as e:
        print(f"β Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
# Route registration: the path matches the "health" entry advertised by root().
@app.get("/health")
async def health_check():
    """Report whether the model session is available.

    Returns:
        A dict with "status" ("healthy" when a session is loaded, otherwise
        "unhealthy"), "model_loaded", "model_path" and "model_exists"
        (whether the model file is currently present on disk).
    """
    model_loaded = session is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "model_path": MODEL_PATH,
        "model_exists": os.path.exists(MODEL_PATH)
    }
# Route registration: the path matches the "model_info" entry advertised by
# root().
@app.get("/model-info")
async def model_info():
    """Get detailed model information.

    Returns:
        A dict with the model path, its file size in MB, the name/type/shape
        of every model input and output, and the session's providers.

    Raises:
        HTTPException: 500 if the model is not loaded or its metadata cannot
            be read (for example when the model file is no longer on disk).
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    def describe(meta):
        # One serialisable entry per model input/output node.
        return {
            "name": meta.name,
            "type": str(meta.type),
            "shape": list(meta.shape) if meta.shape else None
        }

    try:
        size_mb = os.path.getsize(MODEL_PATH) / (1024*1024)
        return {
            "model_path": MODEL_PATH,
            "model_size": f"{size_mb:.2f} MB",
            "inputs": [describe(meta) for meta in session.get_inputs()],
            "outputs": [describe(meta) for meta in session.get_outputs()],
            "runtime_info": {
                "providers": session.get_providers(),
                # NOTE(review): hard-coded; not derived from the providers.
                "device": "CPU"
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")