ai-capabilities / app.py
Evan Shipman
Added provider_id
3869e21
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import List, Optional, Set

from fastapi import FastAPI, Query, HTTPException
# Application object; the route handlers below register themselves on it.
app = FastAPI()
# Path of the SQLite database file (a relative path, resolved against the
# process working directory).
DATABASE_NAME = "data.db"
# Metadata reported verbatim by the root ("/") endpoint.
LAST_UPDATED = '2025-01-21'
AUTHOR = 'Evan Shipman'
LICENSE = 'MIT'
# Database initialization and context manager
def init_db():
    """Create the database schema in ``DATABASE_NAME`` if it is missing.

    Creates the ``providers``, ``models``, ``model_providers``, ``endpoints``
    and ``pricing_tiers`` tables. Every statement uses ``IF NOT EXISTS``, so
    existing tables and their rows are left untouched and the function can be
    called repeatedly.

    The ``/models`` query reads ``pricing_tiers.batch_size``; a
    ``pricing_tiers`` table that lacks this column gets it added as a
    nullable column so the query does not fail with "no such column".
    """
    conn = sqlite3.connect(DATABASE_NAME)
    try:
        # ``with conn`` only commits (or rolls back on error); it does not
        # close the connection, hence the surrounding try/finally.
        with conn:
            conn.executescript('''
                CREATE TABLE IF NOT EXISTS providers (
                    id TEXT PRIMARY KEY,
                    name TEXT NOT NULL,
                    location TEXT NOT NULL
                );
                CREATE TABLE IF NOT EXISTS models (
                    id TEXT PRIMARY KEY,
                    name TEXT NOT NULL,
                    type TEXT NOT NULL DEFAULT 'standard'
                );
                CREATE TABLE IF NOT EXISTS model_providers (
                    model_id TEXT NOT NULL,
                    provider_id TEXT NOT NULL,
                    FOREIGN KEY (model_id) REFERENCES models (id),
                    FOREIGN KEY (provider_id) REFERENCES providers (id),
                    PRIMARY KEY (model_id, provider_id)
                );
                CREATE TABLE IF NOT EXISTS endpoints (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_id TEXT NOT NULL,
                    provider_id TEXT NOT NULL,
                    context_size INTEGER NOT NULL,
                    output_size INTEGER NOT NULL,
                    max_batch_size INTEGER,
                    max_batch_tokens INTEGER,
                    features TEXT,
                    FOREIGN KEY (model_id, provider_id) REFERENCES model_providers (model_id, provider_id)
                );
                CREATE TABLE IF NOT EXISTS pricing_tiers (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    endpoint_id INTEGER NOT NULL,
                    price_type TEXT NOT NULL,
                    min_tokens INTEGER NOT NULL DEFAULT 0,
                    max_tokens INTEGER, -- NULL means no upper limit
                    price_per_million REAL NOT NULL,
                    batch_size INTEGER, -- optional
                    FOREIGN KEY (endpoint_id) REFERENCES endpoints (id),
                    UNIQUE (endpoint_id, price_type, min_tokens)
                );
            ''')
            # CREATE TABLE IF NOT EXISTS never alters an existing table, so
            # add the column explicitly when an older table is found.
            # PRAGMA table_info rows are (cid, name, type, ...).
            existing_columns = {
                info[1] for info in conn.execute("PRAGMA table_info(pricing_tiers)")
            }
            if 'batch_size' not in existing_columns:
                conn.execute("ALTER TABLE pricing_tiers ADD COLUMN batch_size INTEGER")
    finally:
        conn.close()
@contextmanager
def get_db():
    """Yield a connection to ``DATABASE_NAME`` and close it afterwards.

    Rows come back as ``sqlite3.Row`` objects, so columns can be read by
    name. The connection is closed when the ``with`` block ends, whether it
    finishes normally or raises.
    """
    connection = sqlite3.connect(DATABASE_NAME)
    try:
        connection.row_factory = sqlite3.Row
        yield connection
    finally:
        connection.close()
# FastAPI routes
@app.get("/")
def get_description():
    """Describe the API: its purpose, its main routes and file metadata."""
    # (path, description) pairs, listed in the order they are reported.
    route_docs = (
        ("/models", "Retrieve a list of models with optional filtering by IDs."),
        ("/providers", "Retrieve a list of providers with optional filtering by IDs."),
        ("/features", "Retrieve a list of all unique model features/capabilities."),
    )
    payload = {
        "description": "This API provides information about LLM capabilities and their providers.",
        "endpoints": [
            {"path": path, "description": text} for path, text in route_docs
        ],
        "author": AUTHOR,
        "license": LICENSE,
        "last_updated": LAST_UPDATED,
    }
    return payload
@app.get("/features")
def get_features():
    """Return the sorted list of distinct feature names across all endpoints.

    Every non-NULL ``endpoints.features`` value is decoded as JSON. Values
    that fail to decode, or that decode to something other than a list, are
    skipped. Only string entries are collected: non-string entries could be
    unhashable or make the final sort fail on mixed types.

    Returns:
        A list of unique feature strings in ascending order.
    """
    with get_db() as conn:
        rows = conn.execute("""
            SELECT features
            FROM endpoints
            WHERE features IS NOT NULL
        """).fetchall()

    all_features: Set[str] = set()
    for row in rows:
        try:
            features = json.loads(row['features'])
        except (json.JSONDecodeError, TypeError):
            # Best effort: one malformed row must not break the listing.
            continue
        if isinstance(features, list):
            all_features.update(
                feature for feature in features if isinstance(feature, str)
            )
    return sorted(all_features)
@app.get("/models")
def get_models(
    ids: Optional[List[str]] = Query(None),
    features: Optional[List[str]] = Query(None),
    provider_id: Optional[str] = Query(None, description="Filter models by provider ID")
):
    """Retrieve models with their endpoints, pricing tiers and batch limits.

    Args:
        ids: Only return models whose ID is in this list.
        features: Only return endpoints whose stored feature array contains
            every one of these features.
        provider_id: Only return endpoints offered by this provider.

    Returns:
        A list of model dicts with keys ``id``, ``name``, ``type`` and
        ``endpoints``. Each endpoint has ``provider``, ``context_size``,
        ``output_size`` and ``pricing`` (price type -> list of tiers with
        ``min_tokens``, ``max_tokens``, ``price`` and, when set,
        ``batch_size``), plus ``features`` and ``batch_support`` when
        available. The query uses inner joins, so a model without an
        endpoint or without any pricing tier is not returned.
    """
    query, params = _models_query(ids, features, provider_id)
    with get_db() as conn:
        rows = conn.execute(query, params).fetchall()
    return _group_model_rows(rows)


def _models_query(ids, features, provider_id):
    """Build the SQL text and bound parameters for the ``/models`` listing."""
    # Note: the query reads pricing_tiers.batch_size, so that column must
    # exist in the database.
    query = """
        SELECT
            m.id,
            m.name,
            m.type,
            mp.provider_id,
            e.context_size,
            e.output_size,
            e.max_batch_size,
            e.max_batch_tokens,
            e.features,
            pt.price_type,
            pt.min_tokens,
            pt.max_tokens,
            pt.price_per_million,
            pt.batch_size
        FROM models m
        JOIN model_providers mp ON m.id = mp.model_id
        JOIN endpoints e ON m.id = e.model_id AND mp.provider_id = e.provider_id
        JOIN pricing_tiers pt ON e.id = pt.endpoint_id
    """
    conditions = []
    params = []

    if ids:
        conditions.append("m.id IN (" + ",".join("?" * len(ids)) + ")")
        params.extend(ids)

    if provider_id:
        conditions.append("mp.provider_id = ?")
        params.append(provider_id)

    if features:
        # A repeated query value (?features=a&features=a) must count once,
        # otherwise the number of matches can never reach the target.
        wanted = list(dict.fromkeys(features))
        placeholders = ",".join("?" * len(wanted))
        # json_valid guards json_each, which aborts the whole statement on
        # malformed JSON; NULL or invalid feature data simply never matches.
        # DISTINCT keeps duplicates inside a stored array from inflating
        # the count.
        conditions.append(
            "(CASE WHEN json_valid(e.features) THEN ("
            " SELECT COUNT(DISTINCT value)"
            " FROM json_each(e.features)"
            " WHERE value IN (" + placeholders + ")"
            ") ELSE 0 END) = ?"
        )
        params.extend(wanted)
        params.append(len(wanted))

    if conditions:
        query += " WHERE " + " AND ".join(conditions)
    return query, params


def _endpoint_from_row(row):
    """Create an endpoint entry (with empty pricing) from one result row."""
    endpoint = {
        'provider': row['provider_id'],
        'context_size': row['context_size'],
        'output_size': row['output_size'],
        'pricing': {}
    }
    if row['features']:
        try:
            endpoint['features'] = json.loads(row['features'])
        except (json.JSONDecodeError, TypeError):
            # Best effort: leave 'features' out when the stored value is
            # not decodable.
            pass
    if row['max_batch_size'] or row['max_batch_tokens']:
        endpoint['batch_support'] = {
            'max_batch_size': row['max_batch_size'],
            'max_batch_tokens': row['max_batch_tokens']
        }
    return endpoint


def _group_model_rows(rows):
    """Fold flat (model, endpoint, pricing tier) rows into nested model dicts."""
    models = {}
    # (model id, provider id) -> endpoint dict. Rows sharing both values are
    # merged into one endpoint whose attributes come from the first row seen.
    endpoints = {}
    for row in rows:
        model_id = row['id']
        model = models.get(model_id)
        if model is None:
            model = models[model_id] = {
                'id': model_id,
                'name': row['name'],
                'type': row['type'],
                'endpoints': []
            }

        endpoint_key = (model_id, row['provider_id'])
        endpoint = endpoints.get(endpoint_key)
        if endpoint is None:
            endpoint = endpoints[endpoint_key] = _endpoint_from_row(row)
            model['endpoints'].append(endpoint)

        price_entry = {
            'min_tokens': row['min_tokens'],
            'max_tokens': row['max_tokens'],
            'price': row['price_per_million']
        }
        if row['batch_size'] is not None:
            price_entry['batch_size'] = row['batch_size']
        endpoint['pricing'].setdefault(row['price_type'], []).append(price_entry)
    return list(models.values())
@app.get("/providers")
def get_providers(ids: Optional[List[str]] = Query(None, description="Filter by a list of provider IDs")):
    """List providers together with the IDs of the models they offer.

    When ``ids`` is given, only providers whose ID is in that list are
    returned. A provider without any model gets an empty ``models`` list.
    """
    sql = """
        SELECT
            p.id,
            p.name,
            p.location,
            GROUP_CONCAT(mp.model_id) as models
        FROM providers p
        LEFT JOIN model_providers mp ON p.id = mp.provider_id
    """
    bound = []
    if ids:
        placeholders = ",".join("?" for _ in ids)
        sql += " WHERE p.id IN (" + placeholders + ")"
        bound.extend(ids)
    sql += " GROUP BY p.id"

    with get_db() as conn:
        rows = conn.execute(sql, bound).fetchall()

    # GROUP_CONCAT yields NULL for a provider with no models.
    return [
        {
            'id': row['id'],
            'name': row['name'],
            'location': row['location'],
            'models': row['models'].split(',') if row['models'] else []
        }
        for row in rows
    ]
if __name__ == '__main__':
    # Executed as a script: serve the app with uvicorn. The import is local,
    # so merely importing this module does not need uvicorn.
    import uvicorn

    # '0.0.0.0' listens on all network interfaces.
    bind_host, bind_port = '0.0.0.0', 8000
    uvicorn.run(app, host=bind_host, port=bind_port)