# phi4_fullrouter / app.py
import os
import sys
import json
import yaml
import numpy as np
from typing import List, Dict, Optional, Any
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import onnxruntime as ort
from sentence_transformers import SentenceTransformer
import logging
# Configure logging
# Basic config: INFO level to stderr; format carries timestamp, logger name, and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module, per the logging convention.
logger = logging.getLogger(__name__)
# Define API models
class RouteRequest(BaseModel):
    """Request body for POST /route."""
    query: str                       # user query to be routed
    available_models: List[str]      # candidate model names; each must exist in model_embeddings
    performance_weight: float = 1.0  # weight on predicted quality in the preference vector
    cost_weight: float = 0.5         # weight on cost in the preference vector
    return_scores: bool = False      # when True, per-model predicted scores are returned too
class RouteResponse(BaseModel):
    """Response body for POST /route."""
    selected_model: str  # name of the model chosen by the routing policy
    scores: Optional[Dict[str, float]] = None  # per-model predicted scores (only when requested)
# Initialize FastAPI app
app = FastAPI(
    title="Multi-Router-Bandit API",
    description="API for routing queries to the most appropriate LLM",
    version="1.0.0"
)
# Global variables — populated once by startup_event(); None until startup completes.
embedding_model = None        # SentenceTransformer used to embed incoming queries
routing_policy_session = None # ONNX Runtime session for the routing policy network
irt_model_session = None      # ONNX Runtime session for the IRT score predictor
model_embeddings = None       # dict: model name -> embedding vector (model_embeddings.json)
models_config = None          # parsed models_config.yaml (providers, costs, descriptions)
def get_models_dir():
    """Return the models directory: MODELS_DIR env var if set, else "models"."""
    return os.getenv("MODELS_DIR", "models")
@app.on_event("startup")
async def startup_event():
    """Load all model artifacts into the module-level globals at startup.

    Exits the process (sys.exit(1)) if any required artifact is missing
    from the models directory.
    """
    global embedding_model, routing_policy_session, irt_model_session, model_embeddings, models_config

    models_dir = get_models_dir()

    # Verify every required artifact exists before loading anything.
    # Order matters: the first missing file determines the error logged.
    required_files = [
        ("routing_policy.onnx", "Routing policy"),
        ("irt_model.onnx", "IRT model"),
        ("model_embeddings.json", "Model embeddings"),
        ("models_config.yaml", "Models configuration"),
    ]
    for filename, label in required_files:
        artifact_path = os.path.join(models_dir, filename)
        if not os.path.exists(artifact_path):
            logger.error(f"{label} not found at {artifact_path}")
            sys.exit(1)

    # Sentence embedding model for query encoding.
    logger.info("Loading embedding model...")
    embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    # ONNX inference sessions: routing policy and IRT score predictor.
    logger.info("Loading ONNX models...")
    routing_policy_session = ort.InferenceSession(os.path.join(models_dir, "routing_policy.onnx"))
    irt_model_session = ort.InferenceSession(os.path.join(models_dir, "irt_model.onnx"))

    # Per-model embedding vectors keyed by model name.
    logger.info("Loading model embeddings...")
    with open(os.path.join(models_dir, "model_embeddings.json"), "r") as f:
        model_embeddings = json.load(f)

    # Provider / cost / description configuration.
    logger.info("Loading models configuration...")
    with open(os.path.join(models_dir, "models_config.yaml"), "r") as f:
        models_config = yaml.safe_load(f)

    logger.info("Startup complete")
def get_model_cost(model_name: str) -> float:
    """Look up a model's per-1k-token cost in the global models_config.

    API-endpoint models use the mean of their input and output per-1k-token
    costs; open source models use their flat cost_per_1k_tokens; unknown
    models fall back to a default of 0.01.
    """
    # API-hosted models: average the input and output token costs.
    for provider_cfg in models_config.get('api_endpoints', {}).values():
        for cfg in provider_cfg.get('models', []):
            if cfg['name'] != model_name:
                continue
            input_cost = cfg.get('cost_per_1k_input_tokens', 0)
            output_cost = cfg.get('cost_per_1k_output_tokens', 0)
            return (input_cost + output_cost) / 2

    # Open source models: a single flat per-1k-token cost.
    for cfg in models_config.get('open_source_models', []):
        if cfg['name'] == model_name:
            return cfg.get('cost_per_1k_tokens', 0)

    # Model absent from the configuration entirely.
    return 0.01
def get_prompt_embedding(query: str) -> np.ndarray:
    """Embed a single query string with the global SentenceTransformer model."""
    # encode() takes a batch; wrap the query and unwrap the single result.
    embeddings = embedding_model.encode([query], convert_to_numpy=True)
    return embeddings[0]
def predict_score(prompt_embedding: np.ndarray, model_embedding: np.ndarray) -> float:
    """Predict the score for one (prompt, model) embedding pair via the IRT ONNX model."""
    # Both inputs are fed as float32 row vectors of batch size 1.
    inputs = {
        "prompt_embedding": prompt_embedding.astype(np.float32).reshape(1, -1),
        "model_embedding": model_embedding.astype(np.float32).reshape(1, -1),
    }
    outputs = irt_model_session.run(None, inputs)
    # First output, first batch row, first element: the scalar score.
    return float(outputs[0][0][0])
def route_query(
    query: str,
    available_models: List[str],
    performance_weight: float = 1.0,
    cost_weight: float = 0.5,
    return_scores: bool = False
) -> Dict[str, Any]:
    """Route a query to the most appropriate LLM.

    Args:
        query: The user query to route.
        available_models: Candidate model names; each must be present in the
            global model_embeddings mapping.
        performance_weight: Weight on predicted quality in the preference vector.
        cost_weight: Weight on cost in the preference vector.
        return_scores: When True, include per-model predicted scores in the result.

    Returns:
        Dict with "selected_model" and, when return_scores is True, "scores"
        mapping each candidate name to its IRT-predicted score.

    Raises:
        HTTPException: 400 if available_models is empty or contains an
            unknown model name.
    """
    # Fail fast with a client error instead of an obscure numpy
    # concatenate/argmax failure (which would surface as a 500).
    if not available_models:
        raise HTTPException(status_code=400, detail="available_models must not be empty")

    # Get prompt embedding
    prompt_embedding = get_prompt_embedding(query)

    # (performance, cost) preference vector fed to the routing policy.
    preference = np.array([[performance_weight, cost_weight]], dtype=np.float32)

    # Per-candidate context: (embedding, cost, IRT-predicted score).
    model_contexts = []
    for model_name in available_models:
        if model_name not in model_embeddings:
            raise HTTPException(status_code=400, detail=f"Model {model_name} not found")
        model_embedding = np.array(model_embeddings[model_name], dtype=np.float32)
        model_cost = get_model_cost(model_name)
        predicted_score = predict_score(prompt_embedding, model_embedding)
        model_contexts.append((model_embedding, model_cost, predicted_score))

    # Stack contexts into one (n_models, emb_dim + 2) matrix: [embedding | cost | score].
    model_embeddings_list = np.array([ctx[0] for ctx in model_contexts], dtype=np.float32)
    model_costs = np.array([ctx[1] for ctx in model_contexts], dtype=np.float32)
    model_scores = np.array([ctx[2] for ctx in model_contexts], dtype=np.float32)
    model_contexts_concat = np.concatenate([
        model_embeddings_list,
        model_costs.reshape(-1, 1),
        model_scores.reshape(-1, 1)
    ], axis=1)

    # Prepare inputs for the routing policy ONNX model.
    ort_inputs = {
        "prompt_embedding": prompt_embedding.astype(np.float32).reshape(1, -1),
        "model_contexts": model_contexts_concat.astype(np.float32),
        "preference": preference
    }
    ort_outputs = routing_policy_session.run(None, ort_inputs)

    # Greedy action: index of the candidate with the highest logit.
    action_logits = ort_outputs[0]
    action = np.argmax(action_logits, axis=1)[0]
    selected_model = available_models[action]

    response = {"selected_model": selected_model}
    if return_scores:
        response["scores"] = {
            name: float(model_contexts[i][2])
            for i, name in enumerate(available_models)
        }
    return response
@app.post("/route", response_model=RouteResponse)
async def route_endpoint(request: RouteRequest):
    """Route a query to the most appropriate LLM.

    Delegates to route_query(); deliberate client errors (HTTPException,
    e.g. 400 for an unknown model) pass through unchanged, while any other
    failure is logged and returned as a 500.
    """
    try:
        result = route_query(
            query=request.query,
            available_models=request.available_models,
            performance_weight=request.performance_weight,
            cost_weight=request.cost_weight,
            return_scores=request.return_scores
        )
        return result
    except HTTPException:
        # Bug fix: the bare `except Exception` below used to swallow the
        # 400 raised by route_query for unknown models and re-wrap it as
        # a 500. Re-raise HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Error routing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def list_models():
    """List every model known to models_config, with provider and cost info."""
    models = []

    # API-hosted models: one entry per model, tagged with its provider name.
    api_endpoints = models_config.get('api_endpoints', {})
    for provider, provider_config in api_endpoints.items():
        for cfg in provider_config.get('models', []):
            entry = {
                "name": cfg['name'],
                "provider": provider,
                "cost_per_1k_input_tokens": cfg.get('cost_per_1k_input_tokens', 0),
                "cost_per_1k_output_tokens": cfg.get('cost_per_1k_output_tokens', 0),
                "description": cfg.get('description', ""),
            }
            models.append(entry)

    # Open source models: flat per-1k-token cost, fixed provider label.
    for cfg in models_config.get('open_source_models', []):
        models.append({
            "name": cfg['name'],
            "provider": "open_source",
            "cost_per_1k_tokens": cfg.get('cost_per_1k_tokens', 0),
            "description": cfg.get('description', ""),
        })

    return {"models": models}
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Liveness probe only: returns a static payload and does not verify
    that the models were actually loaded.
    """
    return {"status": "ok"}
# Add a simple UI for testing
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve an inline single-page HTML UI for manually exercising POST /route."""
    return """
<!DOCTYPE html>
<html>
<head>
<title>Multi-Router-Bandit API</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}
h1 {
color: #333;
}
.form-group {
margin-bottom: 15px;
}
label {
display: block;
margin-bottom: 5px;
}
input[type="text"], textarea, select {
width: 100%;
padding: 8px;
box-sizing: border-box;
}
button {
background-color: #4CAF50;
color: white;
padding: 10px 15px;
border: none;
cursor: pointer;
}
button:hover {
background-color: #45a049;
}
#result {
margin-top: 20px;
padding: 10px;
border: 1px solid #ddd;
background-color: #f9f9f9;
white-space: pre-wrap;
}
</style>
</head>
<body>
<h1>Multi-Router-Bandit API</h1>
<div class="form-group">
<label for="query">Query:</label>
<textarea id="query" rows="4" placeholder="Enter your query here"></textarea>
</div>
<div class="form-group">
<label for="models">Available Models (comma-separated):</label>
<input type="text" id="models" placeholder="gpt-4,llama-3-70b,mistral-7b-instruct,phi-3-mini">
</div>
<div class="form-group">
<label for="performance-weight">Performance Weight:</label>
<input type="number" id="performance-weight" value="1.0" step="0.1" min="0">
</div>
<div class="form-group">
<label for="cost-weight">Cost Weight:</label>
<input type="number" id="cost-weight" value="0.5" step="0.1" min="0">
</div>
<div class="form-group">
<label for="return-scores">Return Scores:</label>
<select id="return-scores">
<option value="true">Yes</option>
<option value="false">No</option>
</select>
</div>
<button onclick="routeQuery()">Route Query</button>
<div id="result"></div>
<script>
async function routeQuery() {
const query = document.getElementById('query').value;
const modelsStr = document.getElementById('models').value;
const performanceWeight = parseFloat(document.getElementById('performance-weight').value);
const costWeight = parseFloat(document.getElementById('cost-weight').value);
const returnScores = document.getElementById('return-scores').value === 'true';
const models = modelsStr.split(',').map(m => m.trim()).filter(m => m);
if (!query) {
alert('Please enter a query');
return;
}
if (models.length === 0) {
alert('Please enter at least one model');
return;
}
const resultDiv = document.getElementById('result');
resultDiv.textContent = 'Loading...';
try {
const response = await fetch('/route', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
query,
available_models: models,
performance_weight: performanceWeight,
cost_weight: costWeight,
return_scores: returnScores
})
});
const data = await response.json();
if (response.ok) {
resultDiv.textContent = JSON.stringify(data, null, 2);
} else {
resultDiv.textContent = `Error: ${data.detail}`;
}
} catch (error) {
resultDiv.textContent = `Error: ${error.message}`;
}
}
</script>
</body>
</html>
"""