|
|
| """ |
| CyberForge ML Inference Module |
| Backend integration for mlService.js |
| """ |
|
|
| import json |
| import time |
| import joblib |
| import numpy as np |
| from pathlib import Path |
| from typing import Dict, List, Any, Optional |
|
|
class CyberForgeInference:
    """
    ML inference service for CyberForge backend.
    Compatible with mlService.js API contract.

    Each model lives in ``models_dir/<model_name>/`` as ``model.pkl`` with an
    optional ``scaler.pkl``. An optional ``models_dir/manifest.json`` lists the
    available models under its ``"models"`` key. Loaded models are cached in
    memory for the lifetime of the instance.
    """

    def __init__(self, models_dir: str):
        self.models_dir = Path(models_dir)
        # model_name -> {"model": <estimator>, "scaler": <transformer> or None}
        self.loaded_models: Dict[str, Dict[str, Any]] = {}
        self.manifest = self._load_manifest()

    def _load_manifest(self) -> Dict:
        """Read ``manifest.json`` from the models directory.

        Returns:
            The parsed manifest, or ``{"models": {}}`` when the file is absent.
        """
        manifest_path = self.models_dir / "manifest.json"
        if manifest_path.exists():
            # JSON is UTF-8; don't depend on the platform's locale encoding.
            with open(manifest_path, encoding="utf-8") as f:
                return json.load(f)
        return {"models": {}}

    @staticmethod
    def _is_safe_model_name(model_name: str) -> bool:
        """Return True if *model_name* is a relative path that cannot leave models_dir.

        Purely lexical (no filesystem access, symlinks are not resolved):
        rejects empty names, absolute paths / drive anchors, and any ``..``
        component.
        """
        name_path = Path(model_name)
        return (
            bool(name_path.parts)
            and not name_path.anchor
            and ".." not in name_path.parts
        )

    def load_model(self, model_name: str) -> bool:
        """Load a model (and its optional scaler) into memory.

        Args:
            model_name: Sub-directory of ``models_dir`` containing ``model.pkl``.

        Returns:
            True if the model is already cached or was loaded now; False if the
            name is rejected as unsafe or ``model.pkl`` does not exist.
        """
        if model_name in self.loaded_models:
            return True

        # model_name can come directly from an HTTP request body. Refuse names
        # that would resolve outside models_dir: joblib.load unpickles the file,
        # so loading an attacker-chosen path can execute arbitrary code.
        if not self._is_safe_model_name(model_name):
            return False

        model_dir = self.models_dir / model_name
        model_path = model_dir / "model.pkl"
        scaler_path = model_dir / "scaler.pkl"

        if not model_path.exists():
            return False

        # SECURITY: joblib.load is pickle-based -- models_dir must be trusted.
        self.loaded_models[model_name] = {
            "model": joblib.load(model_path),
            "scaler": joblib.load(scaler_path) if scaler_path.exists() else None,
        }
        return True

    def predict(self, model_name: str, features: Dict) -> Dict:
        """
        Make a prediction.

        Args:
            model_name: Name of the model to use
            features: Feature dictionary

        Returns:
            Response matching mlService.js contract, or ``{"error": ...}`` when
            the model cannot be found/loaded.
        """
        if not self.load_model(model_name):
            return {"error": f"Model not found: {model_name}"}

        model_data = self.loaded_models[model_name]
        model = model_data["model"]
        scaler = model_data["scaler"]

        # One-row 2-D input. Column order is the dict's insertion order.
        # NOTE(review): nothing here enforces that this matches the column order
        # the model was trained on -- confirm the caller sends features in order.
        X = np.array([list(features.values())])

        if scaler is not None:
            X = scaler.transform(X)

        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump. The timed span covers everything returned.
        start_time = time.perf_counter()
        # int() requires an integer-coercible class label.
        prediction = int(model.predict(X)[0])

        # Models without predict_proba report a neutral 0.5 confidence.
        confidence = 0.5
        if hasattr(model, "predict_proba"):
            proba = model.predict_proba(X)[0]
            confidence = float(max(proba))
        inference_time = (time.perf_counter() - start_time) * 1000

        # NOTE(review): risk is derived from the max-class probability only, so a
        # highly confident prediction of *any* class (including a benign one)
        # maps to "critical". Confirm this is what mlService.js expects.
        risk_level = (
            "critical" if confidence >= 0.9 else
            "high" if confidence >= 0.7 else
            "medium" if confidence >= 0.5 else
            "low" if confidence >= 0.3 else "info"
        )

        return {
            "prediction": prediction,
            "confidence": confidence,
            "risk_level": risk_level,
            "model_name": model_name,
            "model_version": "1.0.0",
            "inference_time_ms": inference_time,
        }

    def batch_predict(self, model_name: str, features_list: List[Dict]) -> List[Dict]:
        """Run :meth:`predict` once per feature dict, preserving input order."""
        return [self.predict(model_name, f) for f in features_list]

    def list_models(self) -> List[str]:
        """Return the model names listed in the manifest (empty if no manifest)."""
        return list(self.manifest.get("models", {}).keys())

    def get_model_info(self, model_name: str) -> Dict:
        """Return the manifest entry for *model_name*, or ``{}`` if it is not listed."""
        return self.manifest.get("models", {}).get(model_name, {})
|
|
|
|
| |
def create_api(models_dir: str):
    """Create a FastAPI app for model serving.

    Args:
        models_dir: Directory handed to ``CyberForgeInference``.

    Returns:
        The FastAPI application, or None when fastapi/pydantic are not installed.
    """
    try:
        from fastapi import FastAPI, HTTPException
        from pydantic import BaseModel
    except ImportError:
        # Serving is optional; callers must handle a None return.
        return None

    app = FastAPI(title="CyberForge ML API", version="1.0.0")
    inference = CyberForgeInference(models_dir)

    class PredictRequest(BaseModel):
        model_name: str
        features: Dict

    # Deliberately a plain ``def``: model loading (joblib) and prediction are
    # blocking calls. FastAPI runs sync handlers in its worker threadpool,
    # whereas blocking inside an ``async def`` handler stalls the event loop
    # for every other request. (Concurrent first requests for the same model
    # may each load it; the cache write is last-wins.)
    @app.post("/predict")
    def predict(request: PredictRequest):
        result = inference.predict(request.model_name, request.features)
        if "error" in result:
            raise HTTPException(status_code=404, detail=result["error"])
        return result

    @app.get("/models")
    async def list_models():
        return {"models": inference.list_models()}

    @app.get("/models/{model_name}")
    async def get_model_info(model_name: str):
        info = inference.get_model_info(model_name)
        if not info:
            raise HTTPException(status_code=404, detail="Model not found")
        return info

    return app
|
|
|
|
if __name__ == "__main__":
    import sys

    # Optional first CLI argument selects the models directory; default is CWD.
    models_dir = "."
    if len(sys.argv) > 1:
        models_dir = sys.argv[1]

    inference = CyberForgeInference(models_dir)
    available = inference.list_models()
    print(f"Available models: {available}")
|
|