File size: 6,928 Bytes
c1f3888
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import json
from typing import Dict, Any, List, Optional

# FastAPI application object; the route decorators further down register their handlers on it.
app = FastAPI(title="Content Classifier API", description="Content classification using ONNX model")

# Model configuration
# Path of the ONNX file handed to onnxruntime by load_model() (a relative path).
MODEL_PATH = "contextClassifier.onnx"
# Module-wide ONNX Runtime session; stays None until load_model() assigns it.
session = None

class TextInput(BaseModel):
    """Request body accepted by the /predict endpoint."""

    # Raw text to classify.
    text: str
    # Sequence length forwarded to preprocess_text(); defaults to 512.
    # NOTE(review): declared Optional, so an explicit null is accepted, and no
    # lower/upper bound is enforced here - confirm downstream code copes.
    max_length: Optional[int] = 512

class PredictionResponse(BaseModel):
    """Response schema of the /predict endpoint (built by postprocess_predictions)."""

    # True when threat_prediction is above 0.5.
    is_threat: bool
    # Distance of threat_prediction from 0.5, doubled.
    final_confidence: float
    # Threat value read from the model output.
    threat_prediction: float
    # {"label", "score"} dict derived from threat_prediction, or None in the fallback response.
    sentiment_analysis: Optional[Dict[str, Any]]
    # {"threat_probability", "raw_output"} dict, or None in the fallback response.
    onnx_prediction: Optional[Dict[str, Any]]
    # Names of the models that contributed to the result.
    models_used: List[str]
    # The predictions dictionary that was filled in during post-processing.
    raw_predictions: Dict[str, Any]

def load_model():
    """
    Create the ONNX Runtime session for ``MODEL_PATH`` and store it globally.

    On success the model path and the names of the model's inputs and outputs
    are printed. On any failure the error is printed and then re-raised, so
    the caller sees the original exception.
    """
    global session
    try:
        session = ort.InferenceSession(MODEL_PATH)
        print(f"Model loaded successfully from {MODEL_PATH}")
        input_names = [node.name for node in session.get_inputs()]
        print(f"Model inputs: {input_names}")
        output_names = [node.name for node in session.get_outputs()]
        print(f"Model outputs: {output_names}")
    except Exception as exc:
        print(f"Error loading model: {exc}")
        raise exc

def preprocess_text(text: str, max_length: int = 512):
    """
    Turn raw text into placeholder ``input_ids`` / ``attention_mask`` arrays.

    This is a stand-in tokenizer, not the model's real one: the text is
    lower-cased, split on whitespace, truncated to ``max_length`` tokens and
    each token is mapped to an integer bucket in ``[0, 30000)``. Replace it
    with the tokenizer the model was trained with.

    Args:
        text: Input text to encode.
        max_length: Fixed sequence length. Longer inputs are truncated,
            shorter ones are padded. ``None`` falls back to 512.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"``, each an int64
        array of shape ``(1, max_length)``. Padding positions carry id 0 and
        mask 0; real tokens carry mask 1.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    # Local import keeps this block self-contained; zlib is standard library.
    import zlib

    if max_length is None:
        max_length = 512
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length}")

    tokens = text.lower().split()[:max_length]
    pad_count = max_length - len(tokens)

    # The builtin hash() of a str is salted per interpreter process, so it
    # would assign different ids to the same token after every restart (and
    # across workers). CRC32 gives a stable bucket instead. "surrogatepass"
    # keeps lone surrogates from raising during encoding.
    token_ids = [
        zlib.crc32(token.encode("utf-8", "surrogatepass")) % 30000
        for token in tokens
    ]

    # Pad to the fixed length: id 0 for padding, masked out via attention_mask.
    input_ids = np.array(token_ids + [0] * pad_count, dtype=np.int64).reshape(1, -1)
    attention_mask = np.array(
        [1] * len(tokens) + [0] * pad_count, dtype=np.int64
    ).reshape(1, -1)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }

def postprocess_predictions(outputs, predictions_dict):
    """
    Convert raw model outputs into the response payload.

    Placeholder logic: the second value of the first row of the first output
    is taken as the threat value (0.5 when that row has a single value). The
    confidence and a mock sentiment are derived from its distance to 0.5.
    ``predictions_dict`` is filled in place with "onnx" and "sentiment"
    entries and is also returned under "raw_predictions".

    NOTE(review): treats the model value as a probability; adjust if the
    model emits something else.
    """
    # Guard clause: the model returned nothing, so answer with neutral defaults.
    if len(outputs) == 0:
        return {
            "is_threat": False,
            "final_confidence": 0.0,
            "threat_prediction": 0.0,
            "sentiment_analysis": None,
            "onnx_prediction": None,
            "models_used": [],
            "raw_predictions": predictions_dict
        }

    raw_output = outputs[0]
    first_row = raw_output[0]

    # Threat value: second column when present, otherwise the neutral midpoint.
    if len(first_row) > 1:
        threat_prediction = float(first_row[1])
    else:
        threat_prediction = 0.5

    offset_from_midpoint = threat_prediction - 0.5
    final_confidence = abs(offset_from_midpoint) * 2  # Scale to 0-1
    is_threat = threat_prediction > 0.5

    predictions_dict["onnx"] = {
        "threat_probability": threat_prediction,
        "raw_output": raw_output.tolist()
    }

    # Mock sentiment: simply the inverse of the threat value.
    sentiment_score = offset_from_midpoint * -2
    if sentiment_score < 0:
        sentiment_label = "NEGATIVE"
    else:
        sentiment_label = "POSITIVE"
    predictions_dict["sentiment"] = {
        "label": sentiment_label,
        "score": abs(sentiment_score)
    }

    return {
        "is_threat": is_threat,
        "final_confidence": final_confidence,
        "threat_prediction": threat_prediction,
        "sentiment_analysis": predictions_dict.get("sentiment"),
        "onnx_prediction": predictions_dict.get("onnx"),
        "models_used": ["contextClassifier.onnx"],
        "raw_predictions": predictions_dict
    }

@app.on_event("startup")
async def startup_event():
    """Load the ONNX model when the application starts."""
    # Not wrapped in try/except: anything load_model() raises propagates out of this handler.
    load_model()

@app.get("/")
async def root():
    """Return a static status message together with the configured model path."""
    return {"message": "Content Classifier API is running", "model": MODEL_PATH}

@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: TextInput):
    """
    Predict content classification for the given text.

    The text is preprocessed, the resulting arrays are matched by name
    against the inputs the loaded model declares, inference is run and the
    outputs are converted by ``postprocess_predictions``.

    Raises:
        HTTPException: 500 when no model is loaded or when preprocessing,
            inference or post-processing fails; 422 when ``max_length`` is
            not a positive integer.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # max_length is Optional in the request schema, so a client may send null
    # or a non-positive number. Validate before the try block so the client
    # error is not converted into a 500 by the generic handler below.
    # NOTE(review): there is still no upper bound on max_length.
    max_length = input_data.max_length
    if max_length is not None and max_length < 1:
        raise HTTPException(status_code=422, detail="max_length must be a positive integer")

    try:
        # Preprocess the text; a null max_length falls back to the
        # preprocessing function's own default.
        if max_length is None:
            model_inputs = preprocess_text(input_data.text)
        else:
            model_inputs = preprocess_text(input_data.text, max_length)

        # Feed only the inputs the model actually declares.
        ort_inputs = {}
        for input_meta in session.get_inputs():
            name = input_meta.name
            if name in model_inputs:
                ort_inputs[name] = model_inputs[name]
            else:
                # Handle case where expected input is not provided
                print(f"Warning: Expected input '{name}' not found in processed inputs")

        # Run inference (None requests every model output).
        outputs = session.run(None, ort_inputs)

        # postprocess_predictions fills this dictionary in place.
        predictions = {}
        return postprocess_predictions(outputs, predictions)

    except Exception as e:
        print(f"Prediction error: {e}")
        # Chain the cause so the original traceback stays available in logs.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e

@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    "status" is always reported as "healthy"; use "model_loaded" to tell
    whether the global session has been created.
    """
    return {
        "status": "healthy",
        "model_loaded": session is not None,
        "model_path": MODEL_PATH
    }

@app.get("/model-info")
async def model_info():
    """
    Describe the loaded model.

    Returns the model path plus, for every input and output the session
    reports, its name, its type as a string and its shape. Responds with
    HTTP 500 when no model has been loaded.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    def describe(nodes):
        # One summary dict per input/output metadata entry.
        return [
            {"name": node.name, "type": str(node.type), "shape": node.shape}
            for node in nodes
        ]

    input_summaries = describe(session.get_inputs())
    output_summaries = describe(session.get_outputs())

    return {
        "model_path": MODEL_PATH,
        "inputs": input_summaries,
        "outputs": output_summaries
    }

# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)