import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional

import numpy as np
import onnxruntime as ort
import yaml
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Define API models
class RouteRequest(BaseModel):
    query: str
    available_models: List[str]
    performance_weight: float = 1.0
    cost_weight: float = 0.5
    return_scores: bool = False


class RouteResponse(BaseModel):
    selected_model: str
    scores: Optional[Dict[str, float]] = None


# Initialize FastAPI app
app = FastAPI(
    title="Multi-Router-Bandit API",
    description="API for routing queries to the most appropriate LLM",
    version="1.0.0"
)

# Global state, populated on startup
embedding_model = None
routing_policy_session = None
irt_model_session = None
model_embeddings = None
models_config = None


def get_models_dir():
    """Get the models directory from the environment variable or the default."""
    return os.environ.get("MODELS_DIR", "models")


@app.on_event("startup")
async def startup_event():
    """Initialize models on startup."""
    global embedding_model, routing_policy_session, irt_model_session, model_embeddings, models_config

    models_dir = get_models_dir()

    # Check that every required artifact exists before loading anything
    required_files = {
        "routing_policy.onnx": "Routing policy",
        "irt_model.onnx": "IRT model",
        "model_embeddings.json": "Model embeddings",
        "models_config.yaml": "Models configuration",
    }
    for filename, description in required_files.items():
        path = os.path.join(models_dir, filename)
        if not os.path.exists(path):
            logger.error(f"{description} not found at {path}")
            sys.exit(1)

    # Load embedding model
    logger.info("Loading embedding model...")
    embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    # Load ONNX models
    logger.info("Loading ONNX models...")
    routing_policy_session = ort.InferenceSession(
        os.path.join(models_dir, "routing_policy.onnx"))
    irt_model_session = ort.InferenceSession(
        os.path.join(models_dir, "irt_model.onnx"))

    # Load model embeddings
    logger.info("Loading model embeddings...")
    with open(os.path.join(models_dir, "model_embeddings.json"), "r") as f:
        model_embeddings = json.load(f)

    # Load models configuration
    logger.info("Loading models configuration...")
    with open(os.path.join(models_dir, "models_config.yaml"), "r") as f:
        models_config = yaml.safe_load(f)

    logger.info("Startup complete")


def get_model_cost(model_name: str) -> float:
    """Get the cost of a model, averaging input and output token prices."""
    # Check API endpoints
    for provider in models_config.get('api_endpoints', {}).values():
        for model_config in provider.get('models', []):
            if model_config['name'] == model_name:
                return (model_config.get('cost_per_1k_input_tokens', 0)
                        + model_config.get('cost_per_1k_output_tokens', 0)) / 2

    # Check open source models
    for model_config in models_config.get('open_source_models', []):
        if model_config['name'] == model_name:
            return model_config.get('cost_per_1k_tokens', 0)

    # Fall back to a default cost for models missing from the config
    return 0.01


def get_prompt_embedding(query: str) -> np.ndarray:
    """Get the embedding for a query."""
    return embedding_model.encode([query], convert_to_numpy=True)[0]
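
# For reference, an illustrative sketch of the models_config.yaml layout that
# get_model_cost() above and the /models endpoint below expect. The key names
# are taken from the accessors in this file; the provider, model names, and
# prices are placeholders, not values defined anywhere in this service:
#
#   api_endpoints:
#     some_provider:
#       models:
#         - name: example-api-model
#           cost_per_1k_input_tokens: 0.001
#           cost_per_1k_output_tokens: 0.002
#           description: "Hosted model behind a paid API"
#   open_source_models:
#     - name: example-oss-model
#       cost_per_1k_tokens: 0.0002
#       description: "Self-hosted open-weights model"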
def predict_score(prompt_embedding: np.ndarray, model_embedding: np.ndarray) -> float:
    """Predict the score for a prompt-model pair using the IRT model."""
    # Prepare inputs
    ort_inputs = {
        "prompt_embedding": prompt_embedding.astype(np.float32).reshape(1, -1),
        "model_embedding": model_embedding.astype(np.float32).reshape(1, -1)
    }

    # Run inference
    ort_outputs = irt_model_session.run(None, ort_inputs)

    # Extract the scalar score from the (1, 1) output
    score = ort_outputs[0][0][0]

    return float(score)


def route_query(
    query: str,
    available_models: List[str],
    performance_weight: float = 1.0,
    cost_weight: float = 0.5,
    return_scores: bool = False
) -> Dict[str, Any]:
    """Route a query to the most appropriate LLM."""
    # Get prompt embedding
    prompt_embedding = get_prompt_embedding(query)

    # Create preference vector
    preference = np.array([[performance_weight, cost_weight]], dtype=np.float32)

    # Build an (embedding, cost, predicted score) context per candidate model
    model_contexts = []
    for model_name in available_models:
        if model_name not in model_embeddings:
            raise HTTPException(status_code=400, detail=f"Model {model_name} not found")
        model_embedding = np.array(model_embeddings[model_name], dtype=np.float32)
        model_cost = get_model_cost(model_name)
        predicted_score = predict_score(prompt_embedding, model_embedding)
        model_contexts.append((model_embedding, model_cost, predicted_score))

    # Stack per-model features for the routing policy
    model_embeddings_list = np.array([context[0] for context in model_contexts], dtype=np.float32)
    model_costs = np.array([context[1] for context in model_contexts], dtype=np.float32)
    model_scores = np.array([context[2] for context in model_contexts], dtype=np.float32)

    # Concatenate each model's embedding with its cost and predicted score
    model_contexts_concat = np.concatenate([
        model_embeddings_list,
        model_costs.reshape(-1, 1),
        model_scores.reshape(-1, 1)
    ], axis=1)

    # Prepare inputs for the routing policy ONNX model
    ort_inputs = {
        "prompt_embedding": prompt_embedding.astype(np.float32).reshape(1, -1),
        "model_contexts": model_contexts_concat.astype(np.float32),
        "preference": preference
    }

    # Run inference
    ort_outputs = routing_policy_session.run(None, ort_inputs)

    # The first output holds one action logit per candidate model
    action_logits = ort_outputs[0]

    # Pick the action with the highest logit and map it back to a model
    action = np.argmax(action_logits, axis=1)[0]
    selected_model = available_models[action]

    # Prepare response
    response = {"selected_model": selected_model}

    # Add per-model predicted scores if requested
    if return_scores:
        response["scores"] = {
            model_name: float(model_contexts[i][2])
            for i, model_name in enumerate(available_models)
        }

    return response


@app.post("/route", response_model=RouteResponse)
async def route_endpoint(request: RouteRequest):
    """Route a query to the most appropriate LLM."""
    try:
        return route_query(
            query=request.query,
            available_models=request.available_models,
            performance_weight=request.performance_weight,
            cost_weight=request.cost_weight,
            return_scores=request.return_scores
        )
    except HTTPException:
        # Preserve deliberate client errors (e.g. unknown model -> 400)
        # instead of re-wrapping them as 500s below
        raise
    except Exception as e:
        logger.error(f"Error routing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
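
# Example request against the /route endpoint. A sketch only: the module name
# in the uvicorn command, the port, and the model names are placeholders, not
# values defined in this file:
#
#   uvicorn app:app --port 8000
#
#   curl -X POST http://localhost:8000/route \
#     -H "Content-Type: application/json" \
#     -d '{"query": "Summarize this article",
#          "available_models": ["model-a", "model-b"],
#          "performance_weight": 1.0,
#          "cost_weight": 0.5,
#          "return_scores": true}'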
@app.get("/models")
async def list_models():
    """List all available models."""
    models = []

    # Models served through API endpoints
    for provider, provider_config in models_config.get('api_endpoints', {}).items():
        for model_config in provider_config.get('models', []):
            models.append({
                "name": model_config['name'],
                "provider": provider,
                "cost_per_1k_input_tokens": model_config.get('cost_per_1k_input_tokens', 0),
                "cost_per_1k_output_tokens": model_config.get('cost_per_1k_output_tokens', 0),
                "description": model_config.get('description', "")
            })

    # Open source models
    for model_config in models_config.get('open_source_models', []):
        models.append({
            "name": model_config['name'],
            "provider": "open_source",
            "cost_per_1k_tokens": model_config.get('cost_per_1k_tokens', 0),
            "description": model_config.get('description', "")
        })

    return {"models": models}


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "ok"}


# Add a simple UI for testing
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse


@app.get("/", response_class=HTMLResponse)
async def read_root():
    return """