| | import os |
| | import sys |
| | import json |
| | import yaml |
| | import numpy as np |
| | from typing import List, Dict, Optional, Any |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | import onnxruntime as ort |
| | from sentence_transformers import SentenceTransformer |
| | import logging |
| |
|
| | |
# Configure root logging once at import time so every module logger
# inherits the timestamped format below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
| |
|
| | |
class RouteRequest(BaseModel):
    """Request payload for the /route endpoint."""

    query: str  # User prompt to be routed.
    available_models: List[str]  # Candidate model names to choose between.
    performance_weight: float = 1.0  # Relative weight of predicted quality.
    cost_weight: float = 0.5  # Relative weight of model cost.
    return_scores: bool = False  # If True, include per-model IRT scores in the response.
| |
|
class RouteResponse(BaseModel):
    """Response payload for the /route endpoint."""

    selected_model: str  # Name of the model chosen by the routing policy.
    scores: Optional[Dict[str, float]] = None  # Per-model scores (only when requested).
| |
|
| | |
# FastAPI application instance; all endpoints below are registered on it.
app = FastAPI(
    title="Multi-Router-Bandit API",
    description="API for routing queries to the most appropriate LLM",
    version="1.0.0"
)
| |
|
| | |
# Shared state populated by the startup hook; each stays None until startup completes.
embedding_model = None  # SentenceTransformer used to embed incoming queries.
routing_policy_session = None  # ONNX session for the routing policy network.
irt_model_session = None  # ONNX session for the IRT score predictor.
model_embeddings = None  # Mapping of model name -> embedding vector (loaded from JSON).
models_config = None  # Parsed models_config.yaml (providers, costs, descriptions).
| |
|
def get_models_dir():
    """Return the directory holding model artifacts.

    Reads the ``MODELS_DIR`` environment variable, falling back to the
    relative path ``models`` when it is unset.
    """
    return os.getenv("MODELS_DIR", "models")
| |
|
@app.on_event("startup")
async def startup_event():
    """Load all model artifacts and configuration into module-level globals.

    Verifies that every required artifact exists in the models directory
    before loading anything; logs an error and exits the process
    (``sys.exit(1)``) on the first missing file, matching the original
    fail-fast behavior.
    """
    global embedding_model, routing_policy_session, irt_model_session, model_embeddings, models_config

    models_dir = get_models_dir()

    # The four check-and-exit stanzas were identical except for the file
    # name and label, so they are data-driven here.
    required_files = {
        "routing_policy.onnx": "Routing policy",
        "irt_model.onnx": "IRT model",
        "model_embeddings.json": "Model embeddings",
        "models_config.yaml": "Models configuration",
    }
    for filename, label in required_files.items():
        path = os.path.join(models_dir, filename)
        if not os.path.exists(path):
            logger.error(f"{label} not found at {path}")
            sys.exit(1)

    logger.info("Loading embedding model...")
    embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    logger.info("Loading ONNX models...")
    routing_policy_session = ort.InferenceSession(os.path.join(models_dir, "routing_policy.onnx"))
    irt_model_session = ort.InferenceSession(os.path.join(models_dir, "irt_model.onnx"))

    logger.info("Loading model embeddings...")
    with open(os.path.join(models_dir, "model_embeddings.json"), "r") as f:
        model_embeddings = json.load(f)

    logger.info("Loading models configuration...")
    with open(os.path.join(models_dir, "models_config.yaml"), "r") as f:
        models_config = yaml.safe_load(f)

    logger.info("Startup complete")
| |
|
def get_model_cost(model_name: str) -> float:
    """Return the per-1k-token cost of *model_name* from ``models_config``.

    API-hosted models report separate input and output prices, which are
    averaged; open-source models carry a single price. Unknown models fall
    back to a nominal default of 0.01.
    """
    # API-hosted models: average the input/output per-1k-token prices.
    for provider_cfg in models_config.get('api_endpoints', {}).values():
        for entry in provider_cfg.get('models', []):
            if entry['name'] != model_name:
                continue
            input_cost = entry.get('cost_per_1k_input_tokens', 0)
            output_cost = entry.get('cost_per_1k_output_tokens', 0)
            return (input_cost + output_cost) / 2

    # Open-source models: a single per-1k-token price.
    for entry in models_config.get('open_source_models', []):
        if entry['name'] == model_name:
            return entry.get('cost_per_1k_tokens', 0)

    # Model not listed anywhere: use a nominal non-zero default.
    return 0.01
| |
|
def get_prompt_embedding(query: str) -> np.ndarray:
    """Embed *query* with the shared sentence-transformer.

    Encodes a single-element batch and returns its only row as a 1-D
    numpy vector.
    """
    batch = embedding_model.encode([query], convert_to_numpy=True)
    return batch[0]
| |
|
def predict_score(prompt_embedding: np.ndarray, model_embedding: np.ndarray) -> float:
    """Score a prompt/model pair with the ONNX IRT model.

    Both vectors are cast to float32 and batched to shape (1, -1) before
    inference; the scalar prediction is returned as a Python float.
    """
    feeds = {
        "prompt_embedding": np.reshape(prompt_embedding.astype(np.float32), (1, -1)),
        "model_embedding": np.reshape(model_embedding.astype(np.float32), (1, -1)),
    }
    outputs = irt_model_session.run(None, feeds)
    # First output tensor, first batch row, first (only) column.
    return float(outputs[0][0][0])
| |
|
def route_query(
    query: str,
    available_models: List[str],
    performance_weight: float = 1.0,
    cost_weight: float = 0.5,
    return_scores: bool = False
) -> Dict[str, Any]:
    """Select the best model for *query* from *available_models*.

    Builds a per-model context row (embedding | cost | IRT score), feeds
    the contexts, the prompt embedding, and the (performance, cost)
    preference vector to the ONNX routing policy, and returns the argmax
    choice. When *return_scores* is set, the per-model IRT scores are
    included in the result.

    Raises HTTPException(400) if a requested model has no stored embedding.
    """
    prompt_embedding = get_prompt_embedding(query)

    # Preference vector steering the policy's performance/cost trade-off.
    preference = np.array([[performance_weight, cost_weight]], dtype=np.float32)

    # Gather one (embedding, cost, predicted score) triple per candidate.
    embeddings = []
    costs = []
    predicted_scores = []
    for model_name in available_models:
        if model_name not in model_embeddings:
            raise HTTPException(status_code=400, detail=f"Model {model_name} not found")
        emb = np.array(model_embeddings[model_name], dtype=np.float32)
        embeddings.append(emb)
        costs.append(get_model_cost(model_name))
        predicted_scores.append(predict_score(prompt_embedding, emb))

    # Context matrix: one row per model, columns [embedding | cost | score].
    context_matrix = np.concatenate(
        [
            np.array(embeddings, dtype=np.float32),
            np.array(costs, dtype=np.float32).reshape(-1, 1),
            np.array(predicted_scores, dtype=np.float32).reshape(-1, 1),
        ],
        axis=1,
    )

    feeds = {
        "prompt_embedding": prompt_embedding.astype(np.float32).reshape(1, -1),
        "model_contexts": context_matrix.astype(np.float32),
        "preference": preference,
    }
    action_logits = routing_policy_session.run(None, feeds)[0]

    # Greedy selection: the highest-logit candidate wins.
    chosen = np.argmax(action_logits, axis=1)[0]

    result: Dict[str, Any] = {"selected_model": available_models[chosen]}

    if return_scores:
        result["scores"] = {
            name: float(score)
            for name, score in zip(available_models, predicted_scores)
        }

    return result
| |
|
@app.post("/route", response_model=RouteResponse)
async def route_endpoint(request: RouteRequest):
    """Route a query to the most appropriate LLM.

    Delegates to route_query and returns its result. Client errors that
    route_query raises (e.g. HTTPException 400 for an unknown model) are
    propagated unchanged; any other failure is logged and surfaced as a
    500.
    """
    try:
        return route_query(
            query=request.query,
            available_models=request.available_models,
            performance_weight=request.performance_weight,
            cost_weight=request.cost_weight,
            return_scores=request.return_scores
        )
    except HTTPException:
        # Bug fix: the previous broad handler swallowed route_query's
        # HTTPException(400) and re-raised it as a generic 500. Re-raise
        # it untouched so clients see the intended status code.
        raise
    except Exception as e:
        logger.error(f"Error routing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
@app.get("/models")
async def list_models():
    """List every configured model with its provider and pricing details."""
    catalog = []

    # API-hosted models, tagged with the provider they belong to.
    for provider, provider_config in models_config.get('api_endpoints', {}).items():
        for entry in provider_config.get('models', []):
            catalog.append({
                "name": entry['name'],
                "provider": provider,
                "cost_per_1k_input_tokens": entry.get('cost_per_1k_input_tokens', 0),
                "cost_per_1k_output_tokens": entry.get('cost_per_1k_output_tokens', 0),
                "description": entry.get('description', ""),
            })

    # Locally-hosted open-source models use a single per-1k-token price.
    for entry in models_config.get('open_source_models', []):
        catalog.append({
            "name": entry['name'],
            "provider": "open_source",
            "cost_per_1k_tokens": entry.get('cost_per_1k_tokens', 0),
            "description": entry.get('description', ""),
        })

    return {"models": catalog}
| |
|
@app.get("/health")
async def health_check():
    """Liveness probe: always reports an OK status."""
    payload = {"status": "ok"}
    return payload
| |
|
| | |
| | from fastapi.staticfiles import StaticFiles |
| | from fastapi.responses import HTMLResponse |
| |
|
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve a self-contained HTML test page for the /route endpoint.

    The page is a single inline string (no templates or static assets): a
    form collecting the query, candidate models, and weights, plus a
    script that POSTs the form as JSON to /route and renders the
    response (or error detail) in the #result div.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Multi-Router-Bandit API</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
            }
            h1 {
                color: #333;
            }
            .form-group {
                margin-bottom: 15px;
            }
            label {
                display: block;
                margin-bottom: 5px;
            }
            input[type="text"], textarea, select {
                width: 100%;
                padding: 8px;
                box-sizing: border-box;
            }
            button {
                background-color: #4CAF50;
                color: white;
                padding: 10px 15px;
                border: none;
                cursor: pointer;
            }
            button:hover {
                background-color: #45a049;
            }
            #result {
                margin-top: 20px;
                padding: 10px;
                border: 1px solid #ddd;
                background-color: #f9f9f9;
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
        <h1>Multi-Router-Bandit API</h1>
        <div class="form-group">
            <label for="query">Query:</label>
            <textarea id="query" rows="4" placeholder="Enter your query here"></textarea>
        </div>
        <div class="form-group">
            <label for="models">Available Models (comma-separated):</label>
            <input type="text" id="models" placeholder="gpt-4,llama-3-70b,mistral-7b-instruct,phi-3-mini">
        </div>
        <div class="form-group">
            <label for="performance-weight">Performance Weight:</label>
            <input type="number" id="performance-weight" value="1.0" step="0.1" min="0">
        </div>
        <div class="form-group">
            <label for="cost-weight">Cost Weight:</label>
            <input type="number" id="cost-weight" value="0.5" step="0.1" min="0">
        </div>
        <div class="form-group">
            <label for="return-scores">Return Scores:</label>
            <select id="return-scores">
                <option value="true">Yes</option>
                <option value="false">No</option>
            </select>
        </div>
        <button onclick="routeQuery()">Route Query</button>
        <div id="result"></div>

        <script>
            async function routeQuery() {
                const query = document.getElementById('query').value;
                const modelsStr = document.getElementById('models').value;
                const performanceWeight = parseFloat(document.getElementById('performance-weight').value);
                const costWeight = parseFloat(document.getElementById('cost-weight').value);
                const returnScores = document.getElementById('return-scores').value === 'true';

                const models = modelsStr.split(',').map(m => m.trim()).filter(m => m);

                if (!query) {
                    alert('Please enter a query');
                    return;
                }

                if (models.length === 0) {
                    alert('Please enter at least one model');
                    return;
                }

                const resultDiv = document.getElementById('result');
                resultDiv.textContent = 'Loading...';

                try {
                    const response = await fetch('/route', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json'
                        },
                        body: JSON.stringify({
                            query,
                            available_models: models,
                            performance_weight: performanceWeight,
                            cost_weight: costWeight,
                            return_scores: returnScores
                        })
                    });

                    const data = await response.json();

                    if (response.ok) {
                        resultDiv.textContent = JSON.stringify(data, null, 2);
                    } else {
                        resultDiv.textContent = `Error: ${data.detail}`;
                    }
                } catch (error) {
                    resultDiv.textContent = `Error: ${error.message}`;
                }
            }
        </script>
    </body>
    </html>
    """
| |
|