| import os |
| import logging |
| from typing import List, Dict |
| from contextlib import asynccontextmanager |
| from datetime import datetime, timezone |
|
|
| from bs4 import BeautifulSoup |
| from fastapi import FastAPI, Request, BackgroundTasks |
| from fastapi.templating import Jinja2Templates |
| from fastapi.responses import JSONResponse |
| import asyncio |
| import aiohttp |
| import uvicorn |
| import pandas as pd |
| from datasets import Dataset, load_dataset |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
|
|
| IN_SPACE = os.getenv("SPACE_REPO_NAME") is not None |
|
|
| if not IN_SPACE: |
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| |
| DATASET_REPO_NAME = os.getenv("DATASET_REPO_NAME", "nbroad/hf-inference-providers-data") |
| HF_TOKEN = os.getenv("HF_TOKEN") |
|
|
| |
| DATA_COLLECTION_INTERVAL = 1800 |
|
|
| |
| data_collection_task = None |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle.

    On startup, launch the periodic data-collection task; on shutdown,
    cancel it and await it so cancellation actually completes before the
    event loop closes.
    """
    global data_collection_task
    data_collection_task = asyncio.create_task(timed_data_collection())
    logger.info("Started hourly data collection task")
    yield
    # Shutdown: cancel the collector and wait for it to finish unwinding.
    # Without the await, the loop can close with the task still pending.
    if data_collection_task:
        data_collection_task.cancel()
        try:
            await data_collection_task
        except asyncio.CancelledError:
            pass
        logger.info("Stopped hourly data collection task")
|
|
app = FastAPI(title="Inference Provider Dashboard", lifespan=lifespan)


# HuggingFace organization names of the inference providers tracked by the dashboard.
PROVIDERS = [
    "togethercomputer",
    "fireworks-ai",
    "nebius",
    "fal",
    "groq",
    "cerebras",
    "sambanovasystems",
    "replicate",
    "novita",
    "Hyperbolic",
    "featherless-ai",
    "CohereLabs",
    "nscale",
]


# Maps the HF organization name to the `inference_provider` value accepted by
# the HF models API. Some orgs differ from their API identifier; identical
# mappings are listed explicitly so every PROVIDERS entry has a key here.
PROVIDER_TO_INFERENCE_NAME = {
    "togethercomputer": "together",
    "fal": "fal-ai",
    "sambanovasystems": "sambanova",
    "Hyperbolic": "hyperbolic",
    "CohereLabs": "cohere",

    "fireworks-ai": "fireworks-ai",
    "nebius": "nebius",
    "groq": "groq",
    "cerebras": "cerebras",
    "replicate": "replicate",
    "novita": "novita",
    "featherless-ai": "featherless-ai",
    "nscale": "nscale",
}


# Jinja2 templates for the HTML dashboard (templates/dashboard.html).
templates = Jinja2Templates(directory="templates")
|
|
async def get_monthly_requests(session: aiohttp.ClientSession, provider: str) -> Dict[str, str]:
    """Get monthly requests for a provider by scraping its HuggingFace page.

    Args:
        session: shared aiohttp session used for the request.
        provider: HF organization name (e.g. "groq").

    Returns:
        Dict with keys:
            provider: the org name passed in.
            monthly_requests: display string as scraped, or "N/A" on failure.
            monthly_requests_int: parsed integer, 0 when unparseable or on error.
    """
    url = f"https://huggingface.co/{provider}"
    # Single fallback payload instead of three identical literals.
    fallback = {
        "provider": provider,
        "monthly_requests": "N/A",
        "monthly_requests_int": 0
    }
    try:
        async with session.get(url) as response:
            html = await response.text()
            soup = BeautifulSoup(html, 'html.parser')
            # `string=` replaces `text=`, which is deprecated since
            # BeautifulSoup 4.4 and emits a DeprecationWarning.
            request_div = soup.find('div', string=lambda t: t and 'monthly requests' in t.lower())
            if request_div:
                # First whitespace-separated token is the count, e.g. "1,234 monthly requests".
                requests_text = request_div.text.split()[0].replace(',', '')
                return {
                    "provider": provider,
                    "monthly_requests": requests_text,
                    "monthly_requests_int": int(requests_text) if requests_text.isdigit() else 0
                }
            return fallback
    except Exception as e:
        # Best-effort scrape: network/parse errors degrade to the fallback row.
        logger.error(f"Error fetching {provider}: {e}")
        return fallback
|
|
async def get_provider_models(session: aiohttp.ClientSession, provider: str) -> List[str]:
    """Fetch the model IDs a provider serves, via the HF models API.

    Returns an empty list when no token is configured, the provider has no
    API-name mapping, the request fails, or the API responds non-200.
    """
    if not HF_TOKEN:
        return []

    inference_provider = PROVIDER_TO_INFERENCE_NAME.get(provider)
    if not inference_provider:
        logger.warning(f"No inference provider mapping found for {provider}")
        return []

    url = (
        "https://huggingface.co/api/models"
        f"?inference_provider={inference_provider}&limit=50&sort=downloads&direction=-1"
    )
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    try:
        async with session.get(url, headers=headers) as response:
            if response.status != 200:
                logger.warning(f"Failed to fetch models for {provider} (inference_provider={inference_provider}): {response.status}")
                return []
            models_data = await response.json()
            # Keep only entries that actually carry a non-empty id.
            return [entry.get('id', '') for entry in models_data if entry.get('id')]
    except Exception as e:
        logger.error(f"Error fetching models for {provider} (inference_provider={inference_provider}): {e}")
        return []
|
|
async def collect_and_store_data() -> None:
    """Collect current data and store it in the dataset.

    Scrapes every provider's monthly-request count, appends the rows to the
    existing HF dataset (creating it if absent), de-duplicates, and pushes
    the result back to the Hub. No-op without HF_TOKEN; all errors are
    logged, never raised (this runs inside the background loop).
    """
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN found, skipping data storage")
        return

    try:
        logger.info("Collecting data for storage...")

        # Scrape all providers concurrently with one shared session.
        async with aiohttp.ClientSession() as session:
            tasks = [get_monthly_requests(session, provider) for provider in PROVIDERS]
            results = await asyncio.gather(*tasks)

        # One timestamp for the whole collection cycle, in UTC.
        timestamp = datetime.now(timezone.utc).isoformat()
        data_rows = []

        for result in results:
            data_rows.append({
                "timestamp": timestamp,
                "provider": result["provider"],
                "monthly_requests": result["monthly_requests"],
                "monthly_requests_int": result["monthly_requests_int"]
            })

        new_df = pd.DataFrame(data_rows)

        # Append to the existing dataset; fall back to just the new rows when
        # the dataset doesn't exist yet (first run) or loading fails.
        try:
            existing_dataset = load_dataset(DATASET_REPO_NAME, split="train")
            existing_df = existing_dataset.to_pandas()
            combined_df = pd.concat([existing_df, new_df], ignore_index=True)
        except Exception as e:
            logger.info(f"Creating new dataset (existing not found): {e}")
            combined_df = new_df

        combined_df['timestamp'] = pd.to_datetime(combined_df['timestamp'])
        combined_df = combined_df.sort_values('timestamp')

        # Keep only the earliest row per (provider, count): consecutive scrapes
        # with an unchanged count collapse to the first time it was observed.
        deduplicated_df = combined_df.groupby(['provider', 'monthly_requests_int']).first().reset_index()

        # Serialize timestamps back to ISO-like strings before pushing.
        # NOTE(review): '%z' renders "+0000" while datetime.isoformat() emits
        # "+00:00" — both parse with pd.to_datetime, but formats differ.
        deduplicated_df['timestamp'] = deduplicated_df['timestamp'].dt.strftime('%Y-%m-%dT%H:%M:%S.%f%z')

        logger.info(f"De-duplicated dataset: {len(combined_df)} -> {len(deduplicated_df)} records")

        # Overwrite the dataset on the Hub with the de-duplicated history.
        new_dataset = Dataset.from_pandas(deduplicated_df)
        new_dataset.push_to_hub(DATASET_REPO_NAME, token=HF_TOKEN, private=False)

        logger.info(f"Successfully stored data for {len(results)} providers")

    except Exception as e:
        logger.error(f"Error collecting and storing data: {e}")
|
|
async def timed_data_collection():
    """Background loop: collect provider data every DATA_COLLECTION_INTERVAL seconds."""
    while True:
        try:
            await collect_and_store_data()
            await asyncio.sleep(DATA_COLLECTION_INTERVAL)
        except asyncio.CancelledError:
            # Raised when the lifespan handler cancels us at shutdown.
            logger.info("Data collection task cancelled")
            return
        except Exception as e:
            logger.error(f"Error in hourly data collection: {e}")
            # Back off briefly after an unexpected failure, then retry.
            await asyncio.sleep(300)
|
|
@app.get("/")
async def dashboard(request: Request):
    """Render the main dashboard page from the Jinja2 template."""
    context = {"request": request}
    return templates.TemplateResponse("dashboard.html", context)
|
|
@app.get("/api/providers")
async def get_providers_data():
    """Return current monthly-request stats for every provider, busiest first."""
    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *(get_monthly_requests(session, p) for p in PROVIDERS)
        )

    # Rank by parsed request count, descending.
    ordered = sorted(results, key=lambda item: item["monthly_requests_int"], reverse=True)

    response = JSONResponse({
        "providers": ordered,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "total_providers": len(ordered)
    })

    # Short client-side cache: data only changes on scrape.
    response.headers["Cache-Control"] = "public, max-age=30"
    return response
|
|
@app.get("/api/providers/{provider}")
async def get_provider_data(provider: str):
    """Return current monthly-request stats for a single known provider."""
    if provider not in PROVIDERS:
        return {"error": "Provider not found"}

    async with aiohttp.ClientSession() as session:
        provider_stats = await get_monthly_requests(session, provider)

    return {
        "provider_data": provider_stats,
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
|
|
@app.get("/api/historical")
async def get_historical_data():
    """Return per-provider historical time series for the line chart.

    Loads the HF dataset, downsamples when large, and returns Chart.js-style
    {x, y} points per provider. If the dataset doesn't exist yet, triggers an
    initial collection. Requires HF_TOKEN; otherwise returns an error payload.
    """
    if not HF_TOKEN:
        logger.warning("No HF_TOKEN available for historical data")
        response = JSONResponse({
            "error": "Historical data not available - no HF token",
            "historical_data": {},
            "message": "Historical data collection requires HuggingFace token"
        })
        response.headers["Cache-Control"] = "public, max-age=60"
        return response

    try:
        dataset = load_dataset(DATASET_REPO_NAME, split="train")
        df = dataset.to_pandas()

        logger.info(f"Loaded dataset with {len(df)} total records")

        if df.empty:
            logger.info("Dataset is empty - no historical data available yet")
            response = JSONResponse({
                "historical_data": {},
                "message": "No historical data available yet. Data collection is running - check back in 30 minutes.",
                "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            })
            response.headers["Cache-Control"] = "public, max-age=60"
            return response

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp')

        df_filtered = df.copy()

        logger.info(f"Using all {len(df_filtered)} records for full historical view")

        # Downsample evenly per provider when the dataset grows large, to keep
        # payload size and chart rendering manageable.
        max_points_per_provider = 500
        if len(df_filtered) > max_points_per_provider * len(PROVIDERS):
            df_filtered = df_filtered.groupby('provider').apply(
                lambda x: x.iloc[::max(1, len(x) // max_points_per_provider)]
            ).reset_index(drop=True)
            logger.info(f"Sampled down to {len(df_filtered)} records for performance")

        historical_data = {}
        total_data_points = 0

        for provider in PROVIDERS:
            provider_data = df_filtered[df_filtered['provider'] == provider].copy()
            if not provider_data.empty:
                historical_data[provider] = [
                    {
                        "x": row['timestamp'].isoformat(),
                        # Cast numpy int64 -> int: json.dumps (used by
                        # JSONResponse) cannot serialize numpy scalar types.
                        "y": int(row['monthly_requests_int'])
                    }
                    for _, row in provider_data.iterrows()
                ]
                total_data_points += len(historical_data[provider])
            else:
                historical_data[provider] = []

        logger.info(f"Returning {total_data_points} total data points across {len([p for p in historical_data.values() if p])} providers")

        if not df_filtered.empty:
            earliest_date = df_filtered['timestamp'].min().strftime('%Y-%m-%d %H:%M')
            latest_date = df_filtered['timestamp'].max().strftime('%Y-%m-%d %H:%M')
            date_range = f"From {earliest_date} to {latest_date}"
        else:
            date_range = "No data"

        # JSONResponse is already imported at module level — no local re-import.
        response = JSONResponse({
            "historical_data": historical_data,
            "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            "total_data_points": total_data_points,
            "data_range": date_range,
            "earliest_date": df_filtered['timestamp'].min().isoformat() if not df_filtered.empty else None,
            "latest_date": df_filtered['timestamp'].max().isoformat() if not df_filtered.empty else None
        })

        response.headers["Cache-Control"] = "public, max-age=120"
        return response

    except Exception as e:
        logger.error(f"Error fetching historical data: {e}")

        # First-run case: dataset repo doesn't exist yet, so kick off an
        # initial collection and tell the client to check back later.
        if "does not exist" in str(e).lower() or "not found" in str(e).lower():
            logger.info("Dataset doesn't exist yet, triggering initial data collection")
            try:
                await collect_and_store_data()
                response = JSONResponse({
                    "historical_data": {},
                    "message": "Dataset created! Historical data will appear after a few data collection cycles.",
                    "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                })
                response.headers["Cache-Control"] = "public, max-age=60"
                return response
            except Exception as create_error:
                logger.error(f"Failed to create initial dataset: {create_error}")

        response = JSONResponse({
            "error": f"Failed to fetch historical data: {str(e)}",
            "historical_data": {},
            "message": "Historical data temporarily unavailable"
        })
        response.headers["Cache-Control"] = "public, max-age=30"
        return response
|
|
@app.get("/api/models")
async def get_provider_models_data():
    """Build the provider x model support matrix for all providers."""
    if not HF_TOKEN:
        return {"error": "HF_TOKEN required for models data", "matrix": [], "providers": PROVIDERS}

    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *(get_provider_models(session, p) for p in PROVIDERS)
        )

    # Per-provider model sets plus the union of everything seen.
    provider_models = {p: set(models) for p, models in zip(PROVIDERS, results)}
    all_models = set().union(*provider_models.values())

    # Rank models by how many providers serve them (desc), then by id (asc).
    model_popularity = sorted(
        (
            (model, sum(model in provider_models[p] for p in PROVIDERS))
            for model in all_models
        ),
        key=lambda pair: (-pair[1], pair[0]),
    )

    # One matrix row per model: support flag for every provider.
    matrix = [
        {
            "model_id": model_id,
            "total_providers": popularity,
            "providers": {p: model_id in provider_models[p] for p in PROVIDERS},
        }
        for model_id, popularity in model_popularity
    ]

    provider_totals = {p: len(provider_models[p]) for p in PROVIDERS}

    response = JSONResponse({
        "matrix": matrix,
        "providers": PROVIDERS,
        "provider_totals": provider_totals,
        "provider_mapping": PROVIDER_TO_INFERENCE_NAME,
        "total_models": len(all_models),
        "last_updated": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    })

    response.headers["Cache-Control"] = "public, max-age=300"
    return response
|
|
@app.post("/api/collect-now")
async def trigger_data_collection(background_tasks: BackgroundTasks):
    """Kick off a data-collection run in the background, returning immediately."""
    background_tasks.add_task(collect_and_store_data)
    return {
        "message": "Data collection triggered",
        "timestamp": datetime.now().isoformat()
    }
|
|
if __name__ == "__main__":
    # Local/dev entry point; 7860 is the port HF Spaces expects apps to bind.
    uvicorn.run(app, host="0.0.0.0", port=7860)