import time

# Phase 1: Minimal imports for arg parsing and TUI
import asyncio
import os
from pathlib import Path
import sys
import argparse
import logging
# --- Argument Parsing (BEFORE heavy imports) ---
parser = argparse.ArgumentParser(description="API Key Proxy Server")
parser.add_argument(
    "--host", type=str, default="0.0.0.0", help="Host to bind the server to."
)
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on.")
parser.add_argument(
    "--enable-request-logging", action="store_true", help="Enable request logging."
)
parser.add_argument(
    "--add-credential",
    action="store_true",
    help="Launch the interactive tool to add a new OAuth credential.",
)
args, _ = parser.parse_known_args()
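# Illustrative invocation (host/port values are just examples):
#   python src/proxy_app/main.py --host 127.0.0.1 --port 8080 --enable-request-logging
# Running with no arguments instead drops into the launcher TUI handled below.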
# Add the 'src' directory to the Python path
sys.path.append(str(Path(__file__).resolve().parent.parent))

# Check if we should launch TUI (no arguments = TUI mode)
if len(sys.argv) == 1:
    # TUI MODE - Load ONLY what's needed for the launcher (fast path!)
    from proxy_app.launcher_tui import run_launcher_tui

    run_launcher_tui()
    # Launcher modifies sys.argv and returns, or exits if user chose Exit.
    # If we get here, user chose "Run Proxy" and sys.argv is modified,
    # so re-parse arguments with the modified sys.argv.
    args = parser.parse_args()

# Check if credential tool mode (also doesn't need heavy proxy imports)
if args.add_credential:
    from rotator_library.credential_tool import run_credential_tool

    run_credential_tool()
    sys.exit(0)

# If we get here, we're ACTUALLY running the proxy - NOW show startup messages and start timer
_start_time = time.time()
# Load all .env files from root folder (main .env first, then any additional *.env files)
from dotenv import load_dotenv
from glob import glob

# Get the application root directory (EXE dir if frozen, else CWD)
# Inlined here to avoid triggering heavy rotator_library imports before loading screen
if getattr(sys, "frozen", False):
    _root_dir = Path(sys.executable).parent
else:
    _root_dir = Path.cwd()

# Load main .env first
load_dotenv(_root_dir / ".env")

# Load any additional .env files (e.g., antigravity_all_combined.env, gemini_cli_all_combined.env)
_env_files_found = list(_root_dir.glob("*.env"))
for _env_file in sorted(_root_dir.glob("*.env")):
    if _env_file.name != ".env":  # Skip main .env (already loaded)
        load_dotenv(_env_file, override=False)  # Don't override existing values

# Log discovered .env files for deployment verification
if _env_files_found:
    _env_names = [_ef.name for _ef in _env_files_found]
    print(f"📁 Loaded {len(_env_files_found)} .env file(s): {', '.join(_env_names)}")
# Get proxy API key for display
proxy_api_key = os.getenv("PROXY_API_KEY")
if proxy_api_key:
    key_display = f"✓ {proxy_api_key}"
else:
    key_display = "✗ Not Set (INSECURE - anyone can access!)"

print("━" * 70)
print(f"Starting proxy on {args.host}:{args.port}")
print(f"Proxy API Key: {key_display}")
print(f"GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
print("━" * 70)
print("Loading server components...")
# Phase 2: Load Rich for loading spinner (lightweight)
from rich.console import Console

_console = Console()

# Phase 3: Heavy dependencies with granular loading messages
print(" → Loading FastAPI framework...")
with _console.status("[dim]Loading FastAPI framework...", spinner="dots"):
    from contextlib import asynccontextmanager
    from fastapi import FastAPI, Request, HTTPException, Depends
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import StreamingResponse
    from fastapi.security import APIKeyHeader

print(" → Loading core dependencies...")
with _console.status("[dim]Loading core dependencies...", spinner="dots"):
    from dotenv import load_dotenv
    import colorlog
    import json
    from typing import AsyncGenerator, Any, List, Optional, Union
    from pydantic import BaseModel, Field

# --- Early Log Level Configuration ---
logging.getLogger("LiteLLM").setLevel(logging.WARNING)

print(" → Loading LiteLLM library...")
with _console.status("[dim]Loading LiteLLM library...", spinner="dots"):
    import litellm

# Phase 4: Application imports with granular loading messages
print(" → Initializing proxy core...")
with _console.status("[dim]Initializing proxy core...", spinner="dots"):
    from rotator_library import RotatingClient
    from rotator_library.credential_manager import CredentialManager
    from rotator_library.background_refresher import BackgroundRefresher
    from rotator_library.model_info_service import init_model_info_service
    from proxy_app.request_logger import log_request_to_console
    from proxy_app.batch_manager import EmbeddingBatcher
    from proxy_app.detailed_logger import DetailedLogger

print(" → Discovering provider plugins...")
# Provider lazy loading happens during import, so time it here
_provider_start = time.time()
with _console.status("[dim]Discovering provider plugins...", spinner="dots"):
    from rotator_library import (
        PROVIDER_PLUGINS,
    )  # This triggers lazy load via __getattr__

_provider_time = time.time() - _provider_start
# Get count after import (without timing to avoid double-counting)
_plugin_count = len(PROVIDER_PLUGINS)
# --- Pydantic Models ---
class EmbeddingRequest(BaseModel):
    model: str
    input: Union[str, List[str]]
    input_type: Optional[str] = None
    dimensions: Optional[int] = None
    user: Optional[str] = None


class ModelCard(BaseModel):
    """Basic model card for minimal response."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "Mirro-Proxy"


class ModelCapabilities(BaseModel):
    """Model capability flags."""

    tool_choice: bool = False
    function_calling: bool = False
    reasoning: bool = False
    vision: bool = False
    system_messages: bool = True
    prompt_caching: bool = False
    assistant_prefill: bool = False


class EnrichedModelCard(BaseModel):
    """Extended model card with pricing and capabilities."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "unknown"

    # Pricing (optional - may not be available for all models)
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None
    cache_read_input_token_cost: Optional[float] = None
    cache_creation_input_token_cost: Optional[float] = None

    # Limits (optional)
    max_input_tokens: Optional[int] = None
    max_output_tokens: Optional[int] = None
    context_window: Optional[int] = None

    # Capabilities
    mode: str = "chat"
    supported_modalities: List[str] = Field(default_factory=lambda: ["text"])
    supported_output_modalities: List[str] = Field(default_factory=lambda: ["text"])
    capabilities: Optional[ModelCapabilities] = None

    # Debug info (optional)
    _sources: Optional[List[str]] = None
    _match_type: Optional[str] = None

    class Config:
        extra = "allow"  # Allow extra fields from the service


class ModelList(BaseModel):
    """List of models response."""

    object: str = "list"
    data: List[ModelCard]


class EnrichedModelList(BaseModel):
    """List of enriched models with pricing and capabilities."""

    object: str = "list"
    data: List[EnrichedModelCard]
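# Illustrative shape of a single EnrichedModelCard once serialized (all values below are
# placeholders, not real pricing):
#   {"id": "openai/gpt-4o", "object": "model", "owned_by": "openai",
#    "input_cost_per_token": 2.5e-06, "output_cost_per_token": 1e-05,
#    "context_window": 128000, "mode": "chat",
#    "capabilities": {"function_calling": true, "vision": true}}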
# Calculate total loading time
_elapsed = time.time() - _start_time
print(
    f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)"
)

# Clear screen and reprint header for clean startup view
# This pushes loading messages up (still in scroll history) but shows a clean final screen
import os as _os_module

_os_module.system("cls" if _os_module.name == "nt" else "clear")

# Reprint header
print("━" * 70)
print(f"Starting proxy on {args.host}:{args.port}")
print(f"Proxy API Key: {key_display}")
print(f"GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
print("━" * 70)
print(
    f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)"
)
# Note: Debug logging will be added after logging configuration below
# --- Logging Configuration ---
# Import path utilities here (after loading screen) to avoid triggering heavy imports early
from rotator_library.utils.paths import get_logs_dir, get_data_file

LOG_DIR = get_logs_dir(_root_dir)

# Configure a file handler for INFO-level logs and higher
info_file_handler = logging.FileHandler(LOG_DIR / "proxy.log", encoding="utf-8")
info_file_handler.setLevel(logging.INFO)
info_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)

# Configure a dedicated file handler for all DEBUG-level logs
debug_file_handler = logging.FileHandler(LOG_DIR / "proxy_debug.log", encoding="utf-8")
debug_file_handler.setLevel(logging.DEBUG)
debug_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)


# Create a filter to ensure the debug handler ONLY gets DEBUG messages from the rotator_library
class RotatorDebugFilter(logging.Filter):
    def filter(self, record):
        return record.levelno == logging.DEBUG and record.name.startswith(
            "rotator_library"
        )


debug_file_handler.addFilter(RotatorDebugFilter())

# Configure a console handler with color (INFO and above only, no DEBUG)
console_handler = colorlog.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
formatter = colorlog.ColoredFormatter(
    "%(log_color)s%(message)s",
    log_colors={
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    },
)
console_handler.setFormatter(formatter)


# Add a filter to prevent any LiteLLM logs from cluttering the console
class NoLiteLLMLogFilter(logging.Filter):
    def filter(self, record):
        return not record.name.startswith("LiteLLM")


console_handler.addFilter(NoLiteLLMLogFilter())

# Get the root logger and set it to DEBUG to capture all messages
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)

# Add all handlers to the root logger
root_logger.addHandler(info_file_handler)
root_logger.addHandler(console_handler)
root_logger.addHandler(debug_file_handler)

# Silence other noisy loggers by setting their level higher than root
logging.getLogger("uvicorn").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

# Isolate LiteLLM's logger to prevent it from reaching the console.
# We will capture its logs via the logger_fn callback in the client instead.
litellm_logger = logging.getLogger("LiteLLM")
litellm_logger.handlers = []
litellm_logger.propagate = False

# Now that logging is configured, log the module load time to debug file only
logging.debug(f"Modules loaded in {_elapsed:.2f}s")
# Load environment variables from .env file
load_dotenv(_root_dir / ".env")

# --- Configuration ---
USE_EMBEDDING_BATCHER = False
ENABLE_REQUEST_LOGGING = args.enable_request_logging
if ENABLE_REQUEST_LOGGING:
    logging.info("Request logging is enabled.")

PROXY_API_KEY = os.getenv("PROXY_API_KEY")
# Note: PROXY_API_KEY validation moved to server startup to allow credential tool to run first

# Discover API keys from environment variables
api_keys = {}
for key, value in os.environ.items():
    if "_API_KEY" in key and key != "PROXY_API_KEY":
        provider = key.split("_API_KEY")[0].lower()
        if provider not in api_keys:
            api_keys[provider] = []
        api_keys[provider].append(value)
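# For example (illustrative variable names): GEMINI_API_KEY=... and GEMINI_API_KEY_2=...
# both split on "_API_KEY" to the provider name "gemini", so api_keys ends up as
# {"gemini": ["<key-1>", "<key-2>"]}.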
# Load model ignore lists from environment variables
ignore_models = {}
for key, value in os.environ.items():
    if key.startswith("IGNORE_MODELS_"):
        provider = key.replace("IGNORE_MODELS_", "").lower()
        models_to_ignore = [
            model.strip() for model in value.split(",") if model.strip()
        ]
        ignore_models[provider] = models_to_ignore
        logging.debug(
            f"Loaded ignore list for provider '{provider}': {models_to_ignore}"
        )

# Load model whitelist from environment variables
whitelist_models = {}
for key, value in os.environ.items():
    if key.startswith("WHITELIST_MODELS_"):
        provider = key.replace("WHITELIST_MODELS_", "").lower()
        models_to_whitelist = [
            model.strip() for model in value.split(",") if model.strip()
        ]
        whitelist_models[provider] = models_to_whitelist
        logging.debug(
            f"Loaded whitelist for provider '{provider}': {models_to_whitelist}"
        )
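# Illustrative .env entries (model names are placeholders):
#   IGNORE_MODELS_OPENAI=gpt-3.5-turbo-0301,gpt-4-0314
#   WHITELIST_MODELS_GEMINI=gemini-1.5-pro,gemini-1.5-flash
# Each comma-separated value is trimmed and stored per provider, e.g.
# ignore_models["openai"] == ["gpt-3.5-turbo-0301", "gpt-4-0314"].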
# Load max concurrent requests per key from environment variables
max_concurrent_requests_per_key = {}
for key, value in os.environ.items():
    if key.startswith("MAX_CONCURRENT_REQUESTS_PER_KEY_"):
        provider = key.replace("MAX_CONCURRENT_REQUESTS_PER_KEY_", "").lower()
        try:
            max_concurrent = int(value)
            if max_concurrent < 1:
                logging.warning(
                    f"Invalid max_concurrent value for provider '{provider}': {value}. Must be >= 1. Using default (1)."
                )
                max_concurrent = 1
            max_concurrent_requests_per_key[provider] = max_concurrent
            logging.debug(
                f"Loaded max concurrent requests for provider '{provider}': {max_concurrent}"
            )
        except ValueError:
            logging.warning(
                f"Invalid max_concurrent value for provider '{provider}': {value}. Using default (1)."
            )
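# Illustrative .env entry (the value is a placeholder): MAX_CONCURRENT_REQUESTS_PER_KEY_GEMINI=3
# allows up to three in-flight requests per Gemini credential; values below 1 or non-numeric
# values fall back to the default of 1, as the warnings above describe.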
# --- Lifespan Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the RotatingClient's lifecycle with the app's lifespan."""
    # [MODIFIED] Perform skippable OAuth initialization at startup
    skip_oauth_init = os.getenv("SKIP_OAUTH_INIT_CHECK", "false").lower() == "true"

    # The CredentialManager now handles all discovery, including .env overrides.
    # We pass all environment variables to it for this purpose.
    cred_manager = CredentialManager(os.environ)
    oauth_credentials = cred_manager.discover_and_prepare()
    if not skip_oauth_init and oauth_credentials:
        logging.info("Starting OAuth credential validation and deduplication...")
        processed_emails = {}  # email -> {provider: path}
        credentials_to_initialize = {}  # provider -> [paths]
        final_oauth_credentials = {}

        # --- Pass 1: Pre-initialization Scan & Deduplication ---
        # logging.info("Pass 1: Scanning for existing metadata to find duplicates...")
        for provider, paths in oauth_credentials.items():
            if provider not in credentials_to_initialize:
                credentials_to_initialize[provider] = []
            for path in paths:
                # Skip env-based credentials (virtual paths) - they don't have metadata files
                if path.startswith("env://"):
                    credentials_to_initialize[provider].append(path)
                    continue
                try:
                    with open(path, "r") as f:
                        data = json.load(f)
                    metadata = data.get("_proxy_metadata", {})
                    email = metadata.get("email")
                    if email:
                        if email not in processed_emails:
                            processed_emails[email] = {}
                        if provider in processed_emails[email]:
                            original_path = processed_emails[email][provider]
                            logging.warning(
                                f"Duplicate for '{email}' on '{provider}' found in pre-scan: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping."
                            )
                            continue
                        else:
                            processed_emails[email][provider] = path
                    credentials_to_initialize[provider].append(path)
                except (FileNotFoundError, json.JSONDecodeError) as e:
                    logging.warning(
                        f"Could not pre-read metadata from '{path}': {e}. Will process during initialization."
                    )
                    credentials_to_initialize[provider].append(path)

        # --- Pass 2: Parallel Initialization of Filtered Credentials ---
        # logging.info("Pass 2: Initializing unique credentials and performing final check...")
        async def process_credential(provider: str, path: str, provider_instance):
            """Process a single credential: initialize and fetch user info."""
            try:
                await provider_instance.initialize_token(path)
                if not hasattr(provider_instance, "get_user_info"):
                    return (provider, path, None, None)
                user_info = await provider_instance.get_user_info(path)
                email = user_info.get("email")
                return (provider, path, email, None)
            except Exception as e:
                logging.error(
                    f"Failed to process OAuth token for {provider} at '{path}': {e}"
                )
                return (provider, path, None, e)

        # Collect all tasks for parallel execution
        tasks = []
        for provider, paths in credentials_to_initialize.items():
            if not paths:
                continue
            provider_plugin_class = PROVIDER_PLUGINS.get(provider)
            if not provider_plugin_class:
                continue
            provider_instance = provider_plugin_class()
            for path in paths:
                tasks.append(process_credential(provider, path, provider_instance))

        # Execute all credential processing tasks in parallel
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # --- Pass 3: Sequential Deduplication and Final Assembly ---
        for result in results:
            # Handle exceptions from gather
            if isinstance(result, Exception):
                logging.error(f"Credential processing raised exception: {result}")
                continue
            provider, path, email, error = result

            # Skip if there was an error
            if error:
                continue

            # If provider doesn't support get_user_info, add directly
            if email is None:
                if provider not in final_oauth_credentials:
                    final_oauth_credentials[provider] = []
                final_oauth_credentials[provider].append(path)
                continue

            # Handle empty email
            if not email:
                logging.warning(
                    f"Could not retrieve email for '{path}'. Treating as unique."
                )
                if provider not in final_oauth_credentials:
                    final_oauth_credentials[provider] = []
                final_oauth_credentials[provider].append(path)
                continue

            # Deduplication check
            if email not in processed_emails:
                processed_emails[email] = {}
            if (
                provider in processed_emails[email]
                and processed_emails[email][provider] != path
            ):
                original_path = processed_emails[email][provider]
                logging.warning(
                    f"Duplicate for '{email}' on '{provider}' found post-init: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping."
                )
                continue
            else:
                processed_emails[email][provider] = path

            if provider not in final_oauth_credentials:
                final_oauth_credentials[provider] = []
            final_oauth_credentials[provider].append(path)

            # Update metadata (skip for env-based credentials - they don't have files)
            if not path.startswith("env://"):
                try:
                    with open(path, "r+") as f:
                        data = json.load(f)
                        metadata = data.get("_proxy_metadata", {})
                        metadata["email"] = email
                        metadata["last_check_timestamp"] = time.time()
                        data["_proxy_metadata"] = metadata
                        f.seek(0)
                        json.dump(data, f, indent=2)
                        f.truncate()
                except Exception as e:
                    logging.error(f"Failed to update metadata for '{path}': {e}")

        logging.info("OAuth credential processing complete.")
        oauth_credentials = final_oauth_credentials

    # [NEW] Load provider-specific params
    litellm_provider_params = {
        "gemini_cli": {"project_id": os.getenv("GEMINI_CLI_PROJECT_ID")}
    }

    # The client now uses the root logger configuration
    client = RotatingClient(
        api_keys=api_keys,
        oauth_credentials=oauth_credentials,  # Pass OAuth config
        configure_logging=True,
        litellm_provider_params=litellm_provider_params,
        ignore_models=ignore_models,
        whitelist_models=whitelist_models,
        enable_request_logging=ENABLE_REQUEST_LOGGING,
        max_concurrent_requests_per_key=max_concurrent_requests_per_key,
    )

    # Log loaded credentials summary (compact, always visible for deployment verification)
    # _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none"
    # _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none"
    # _total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()])
    # print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})")

    client.background_refresher.start()  # Start the background task
    app.state.rotating_client = client

    # Warn if no provider credentials are configured
    if not client.all_credentials:
        logging.warning("=" * 70)
        logging.warning("⚠️ NO PROVIDER CREDENTIALS CONFIGURED")
        logging.warning("The proxy is running but cannot serve any LLM requests.")
        logging.warning(
            "Launch the credential tool to add API keys or OAuth credentials."
        )
        logging.warning(" • Executable: Run with --add-credential flag")
        logging.warning(" • Source: python src/proxy_app/main.py --add-credential")
        logging.warning("=" * 70)

    os.environ["LITELLM_LOG"] = "ERROR"
    litellm.set_verbose = False
    litellm.drop_params = True

    if USE_EMBEDDING_BATCHER:
        batcher = EmbeddingBatcher(client=client)
        app.state.embedding_batcher = batcher
        logging.info("RotatingClient and EmbeddingBatcher initialized.")
    else:
        app.state.embedding_batcher = None
        logging.info("RotatingClient initialized (EmbeddingBatcher disabled).")

    # Start model info service in background (fetches pricing/capabilities data)
    # This runs asynchronously and doesn't block proxy startup
    model_info_service = await init_model_info_service()
    app.state.model_info_service = model_info_service
    logging.info("Model info service started (fetching pricing data in background).")

    yield

    await client.background_refresher.stop()  # Stop the background task on shutdown
    if app.state.embedding_batcher:
        await app.state.embedding_batcher.stop()
    await client.close()

    # Stop model info service
    if hasattr(app.state, "model_info_service") and app.state.model_info_service:
        await app.state.model_info_service.stop()

    if app.state.embedding_batcher:
        logging.info("RotatingClient and EmbeddingBatcher closed.")
    else:
        logging.info("RotatingClient closed.")
# --- FastAPI App Setup ---
app = FastAPI(lifespan=lifespan)

# Add CORS middleware to allow all origins, methods, and headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


def get_rotating_client(request: Request) -> RotatingClient:
    """Dependency to get the rotating client instance from the app state."""
    return request.app.state.rotating_client


def get_embedding_batcher(request: Request) -> EmbeddingBatcher:
    """Dependency to get the embedding batcher instance from the app state."""
    return request.app.state.embedding_batcher


async def verify_api_key(auth: str = Depends(api_key_header)):
    """Dependency to verify the proxy API key."""
    # If PROXY_API_KEY is not set or empty, skip verification (open access)
    if not PROXY_API_KEY:
        return auth
    if not auth or auth != f"Bearer {PROXY_API_KEY}":
        raise HTTPException(status_code=401, detail="Invalid or missing API Key")
    return auth
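# Clients therefore authenticate with the standard OpenAI-style header (the key value below
# is a placeholder): Authorization: Bearer <PROXY_API_KEY>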
async def streaming_response_wrapper(
    request: Request,
    request_data: dict,
    response_stream: AsyncGenerator[str, None],
    logger: Optional[DetailedLogger] = None,
) -> AsyncGenerator[str, None]:
    """
    Wraps a streaming response to log the full response after completion
    and ensures any errors during the stream are sent to the client.
    """
    response_chunks = []
    full_response = {}
    try:
        async for chunk_str in response_stream:
            if await request.is_disconnected():
                logging.warning("Client disconnected, stopping stream.")
                break
            yield chunk_str
            if chunk_str.strip() and chunk_str.startswith("data:"):
                content = chunk_str[len("data:") :].strip()
                if content != "[DONE]":
                    try:
                        chunk_data = json.loads(content)
                        response_chunks.append(chunk_data)
                        if logger:
                            logger.log_stream_chunk(chunk_data)
                    except json.JSONDecodeError:
                        pass
    except Exception as e:
        logging.error(f"An error occurred during the response stream: {e}")
        # Yield a final error message to the client to ensure they are not left hanging.
        error_payload = {
            "error": {
                "message": f"An unexpected error occurred during the stream: {str(e)}",
                "type": "proxy_internal_error",
                "code": 500,
            }
        }
        yield f"data: {json.dumps(error_payload)}\n\n"
        yield "data: [DONE]\n\n"
        # Also log this as a failed request
        if logger:
            logger.log_final_response(
                status_code=500, headers=None, body={"error": str(e)}
            )
        return  # Stop further processing
    finally:
        if response_chunks:
            # --- Aggregation Logic ---
            final_message = {"role": "assistant"}
            aggregated_tool_calls = {}
            usage_data = None
            finish_reason = None

            for chunk in response_chunks:
                if "choices" in chunk and chunk["choices"]:
                    choice = chunk["choices"][0]
                    delta = choice.get("delta", {})

                    # Dynamically aggregate all fields from the delta
                    for key, value in delta.items():
                        if value is None:
                            continue
                        if key == "content":
                            if "content" not in final_message:
                                final_message["content"] = ""
                            if value:
                                final_message["content"] += value
                        elif key == "tool_calls":
                            for tc_chunk in value:
                                index = tc_chunk["index"]
                                if index not in aggregated_tool_calls:
                                    aggregated_tool_calls[index] = {
                                        "type": "function",
                                        "function": {"name": "", "arguments": ""},
                                    }
                                # Ensure 'function' key exists for this index before accessing its sub-keys
                                if "function" not in aggregated_tool_calls[index]:
                                    aggregated_tool_calls[index]["function"] = {
                                        "name": "",
                                        "arguments": "",
                                    }
                                if tc_chunk.get("id"):
                                    aggregated_tool_calls[index]["id"] = tc_chunk["id"]
                                if "function" in tc_chunk:
                                    if "name" in tc_chunk["function"]:
                                        if tc_chunk["function"]["name"] is not None:
                                            aggregated_tool_calls[index]["function"][
                                                "name"
                                            ] += tc_chunk["function"]["name"]
                                    if "arguments" in tc_chunk["function"]:
                                        if tc_chunk["function"]["arguments"] is not None:
                                            aggregated_tool_calls[index]["function"][
                                                "arguments"
                                            ] += tc_chunk["function"]["arguments"]
                        elif key == "function_call":
                            if "function_call" not in final_message:
                                final_message["function_call"] = {
                                    "name": "",
                                    "arguments": "",
                                }
                            if "name" in value:
                                if value["name"] is not None:
                                    final_message["function_call"]["name"] += value[
                                        "name"
                                    ]
                            if "arguments" in value:
                                if value["arguments"] is not None:
                                    final_message["function_call"]["arguments"] += (
                                        value["arguments"]
                                    )
                        else:  # Generic key handling for other data like 'reasoning'
                            # FIX: Role should always replace, never concatenate
                            if key == "role":
                                final_message[key] = value
                            elif key not in final_message:
                                final_message[key] = value
                            elif isinstance(final_message.get(key), str):
                                final_message[key] += value
                            else:
                                final_message[key] = value

                    if "finish_reason" in choice and choice["finish_reason"]:
                        finish_reason = choice["finish_reason"]

                if "usage" in chunk and chunk["usage"]:
                    usage_data = chunk["usage"]

            # --- Final Response Construction ---
            if aggregated_tool_calls:
                final_message["tool_calls"] = list(aggregated_tool_calls.values())
                # CRITICAL FIX: Override finish_reason when tool_calls exist
                # This ensures OpenCode and other agentic systems continue the conversation loop
                finish_reason = "tool_calls"

            # Ensure standard fields are present for consistent logging
            for field in ["content", "tool_calls", "function_call"]:
                if field not in final_message:
                    final_message[field] = None

            first_chunk = response_chunks[0]
            final_choice = {
                "index": 0,
                "message": final_message,
                "finish_reason": finish_reason,
            }
            full_response = {
                "id": first_chunk.get("id"),
                "object": "chat.completion",
                "created": first_chunk.get("created"),
                "model": first_chunk.get("model"),
                "choices": [final_choice],
                "usage": usage_data,
            }
            if logger:
                logger.log_final_response(
                    status_code=200,
                    headers=None,  # Headers are not available at this stage
                    body=full_response,
                )
async def chat_completions(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
):
    """
    OpenAI-compatible endpoint powered by the RotatingClient.
    Handles both streaming and non-streaming responses and logs them.
    """
    logger = DetailedLogger() if ENABLE_REQUEST_LOGGING else None
    try:
        # Read and parse the request body only once at the beginning.
        try:
            request_data = await request.json()
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON in request body.")

        # Global temperature=0 override (controlled by .env variable, default: OFF)
        # Low temperature makes models deterministic and prone to following training data
        # instead of actual schemas, which can cause tool hallucination.
        # Modes: "remove" = delete temperature key, "set" = change to 1.0, "false" = disabled
        override_temp_zero = os.getenv("OVERRIDE_TEMPERATURE_ZERO", "false").lower()
        if (
            override_temp_zero in ("remove", "set", "true", "1", "yes")
            and "temperature" in request_data
            and request_data["temperature"] == 0
        ):
            if override_temp_zero == "remove":
                # Remove temperature key entirely
                del request_data["temperature"]
                logging.debug(
                    "OVERRIDE_TEMPERATURE_ZERO=remove: Removed temperature=0 from request"
                )
            else:
                # Set to 1.0 (for "set", "true", "1", "yes")
                request_data["temperature"] = 1.0
                logging.debug(
                    "OVERRIDE_TEMPERATURE_ZERO=set: Converting temperature=0 to temperature=1.0"
                )
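        # For example (illustrative), OVERRIDE_TEMPERATURE_ZERO=remove in .env causes a client
        # request containing "temperature": 0 to be forwarded without any temperature field.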
        # If logging is enabled, perform all logging operations using the parsed data.
        if logger:
            logger.log_request(headers=request.headers, body=request_data)

        # Extract and log specific reasoning parameters for monitoring.
        model = request_data.get("model")
        generation_cfg = (
            request_data.get("generationConfig", {})
            or request_data.get("generation_config", {})
            or {}
        )
        reasoning_effort = request_data.get("reasoning_effort") or generation_cfg.get(
            "reasoning_effort"
        )
        custom_reasoning_budget = request_data.get(
            "custom_reasoning_budget"
        ) or generation_cfg.get("custom_reasoning_budget", False)
        logging.getLogger("rotator_library").debug(
            f"Handling reasoning parameters: model={model}, reasoning_effort={reasoning_effort}, custom_reasoning_budget={custom_reasoning_budget}"
        )

        # Log basic request info to console (this is a separate, simpler logger).
        log_request_to_console(
            url=str(request.url),
            headers=dict(request.headers),
            client_info=(request.client.host, request.client.port),
            request_data=request_data,
        )

        is_streaming = request_data.get("stream", False)
        if is_streaming:
            response_generator = client.acompletion(request=request, **request_data)
            return StreamingResponse(
                streaming_response_wrapper(
                    request, request_data, response_generator, logger
                ),
                media_type="text/event-stream",
            )
        else:
            response = await client.acompletion(request=request, **request_data)
            if logger:
                # Assuming response has status_code and headers attributes;
                # this might need adjustment based on the actual response object.
                response_headers = (
                    response.headers if hasattr(response, "headers") else None
                )
                status_code = (
                    response.status_code if hasattr(response, "status_code") else 200
                )
                logger.log_final_response(
                    status_code=status_code,
                    headers=response_headers,
                    body=response.model_dump(),
                )
            return response
    except HTTPException:
        # Re-raise HTTPExceptions (e.g. the 400 for invalid JSON) unchanged
        raise
    except (
        litellm.InvalidRequestError,
        ValueError,
        litellm.ContextWindowExceededError,
    ) as e:
        raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}")
    except litellm.AuthenticationError as e:
        raise HTTPException(status_code=401, detail=f"Authentication Error: {str(e)}")
    except litellm.RateLimitError as e:
        raise HTTPException(status_code=429, detail=f"Rate Limit Exceeded: {str(e)}")
    except (litellm.ServiceUnavailableError, litellm.APIConnectionError) as e:
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {str(e)}")
    except litellm.Timeout as e:
        raise HTTPException(status_code=504, detail=f"Gateway Timeout: {str(e)}")
    except (litellm.InternalServerError, litellm.OpenAIError) as e:
        raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}")
    except Exception as e:
        logging.error(f"Request failed after all retries: {e}")
        # Optionally log the failed request
        if ENABLE_REQUEST_LOGGING:
            try:
                request_data = await request.json()
            except json.JSONDecodeError:
                request_data = {"error": "Could not parse request body"}
            if logger:
                logger.log_final_response(
                    status_code=500, headers=None, body={"error": str(e)}
                )
        raise HTTPException(status_code=500, detail=str(e))
async def embeddings(
    request: Request,
    body: EmbeddingRequest,
    client: RotatingClient = Depends(get_rotating_client),
    batcher: Optional[EmbeddingBatcher] = Depends(get_embedding_batcher),
    _=Depends(verify_api_key),
):
    """
    OpenAI-compatible endpoint for creating embeddings.

    Supports two modes based on the USE_EMBEDDING_BATCHER flag:
    - True: Uses a server-side batcher for high throughput.
    - False: Passes requests directly to the provider.
    """
    try:
        request_data = body.model_dump(exclude_none=True)
        log_request_to_console(
            url=str(request.url),
            headers=dict(request.headers),
            client_info=(request.client.host, request.client.port),
            request_data=request_data,
        )

        if USE_EMBEDDING_BATCHER and batcher:
            # --- Server-Side Batching Logic ---
            request_data = body.model_dump(exclude_none=True)
            inputs = request_data.get("input", [])
            if isinstance(inputs, str):
                inputs = [inputs]

            tasks = []
            for single_input in inputs:
                individual_request = request_data.copy()
                individual_request["input"] = single_input
                tasks.append(batcher.add_request(individual_request))

            results = await asyncio.gather(*tasks)

            all_data = []
            total_prompt_tokens = 0
            total_tokens = 0
            for i, result in enumerate(results):
                result["data"][0]["index"] = i
                all_data.extend(result["data"])
                total_prompt_tokens += result["usage"]["prompt_tokens"]
                total_tokens += result["usage"]["total_tokens"]

            final_response_data = {
                "object": "list",
                "model": results[0]["model"],
                "data": all_data,
                "usage": {
                    "prompt_tokens": total_prompt_tokens,
                    "total_tokens": total_tokens,
                },
            }
            response = litellm.EmbeddingResponse(**final_response_data)
        else:
            # --- Direct Pass-Through Logic ---
            request_data = body.model_dump(exclude_none=True)
            if isinstance(request_data.get("input"), str):
                request_data["input"] = [request_data["input"]]
            response = await client.aembedding(request=request, **request_data)

        return response
    except HTTPException as e:
        # Re-raise HTTPException to ensure it's not caught by the generic Exception handler
        raise e
    except (
        litellm.InvalidRequestError,
        ValueError,
        litellm.ContextWindowExceededError,
    ) as e:
        raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}")
    except litellm.AuthenticationError as e:
        raise HTTPException(status_code=401, detail=f"Authentication Error: {str(e)}")
    except litellm.RateLimitError as e:
        raise HTTPException(status_code=429, detail=f"Rate Limit Exceeded: {str(e)}")
    except (litellm.ServiceUnavailableError, litellm.APIConnectionError) as e:
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {str(e)}")
    except litellm.Timeout as e:
        raise HTTPException(status_code=504, detail=f"Gateway Timeout: {str(e)}")
    except (litellm.InternalServerError, litellm.OpenAIError) as e:
        raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}")
    except Exception as e:
        logging.error(f"Embedding request failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def read_root():
    return {"Status": "API Key Proxy is running"}


async def list_models(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
    enriched: bool = True,
):
    """
    Returns a list of available models in the OpenAI-compatible format.

    Query Parameters:
        enriched: If True (default), returns detailed model info with pricing and capabilities.
                  If False, returns minimal OpenAI-compatible response.
    """
    model_ids = await client.get_all_available_models(grouped=False)

    if enriched and hasattr(request.app.state, "model_info_service"):
        model_info_service = request.app.state.model_info_service
        if model_info_service.is_ready:
            # Return enriched model data
            enriched_data = model_info_service.enrich_model_list(model_ids)
            return {"object": "list", "data": enriched_data}

    # Fallback to basic model cards
    model_cards = [
        {
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "Mirro-Proxy",
        }
        for model_id in model_ids
    ]
    return {"object": "list", "data": model_cards}
async def get_model(
    model_id: str,
    request: Request,
    _=Depends(verify_api_key),
):
    """
    Returns detailed information about a specific model.

    Path Parameters:
        model_id: The model ID (e.g., "anthropic/claude-3-opus", "openrouter/openai/gpt-4")
    """
    if hasattr(request.app.state, "model_info_service"):
        model_info_service = request.app.state.model_info_service
        if model_info_service.is_ready:
            info = model_info_service.get_model_info(model_id)
            if info:
                return info.to_dict()

    # Return basic info if service not ready or model not found
    return {
        "id": model_id,
        "object": "model",
        "created": int(time.time()),
        "owned_by": model_id.split("/")[0] if "/" in model_id else "unknown",
    }


async def model_info_stats(
    request: Request,
    _=Depends(verify_api_key),
):
    """
    Returns statistics about the model info service (for monitoring/debugging).
    """
    if hasattr(request.app.state, "model_info_service"):
        return request.app.state.model_info_service.get_stats()
    return {"error": "Model info service not initialized"}


async def list_providers(_=Depends(verify_api_key)):
    """
    Returns a list of all available providers.
    """
    return list(PROVIDER_PLUGINS.keys())
async def get_quota_stats(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
    provider: str = None,
):
    """
    Returns quota and usage statistics for all credentials.

    This returns cached data from the proxy without making external API calls.
    Use POST to reload from disk or force refresh from external APIs.

    Query Parameters:
        provider: Optional filter to return stats for a specific provider only

    Returns:
        {
            "providers": {
                "provider_name": {
                    "credential_count": int,
                    "active_count": int,
                    "on_cooldown_count": int,
                    "exhausted_count": int,
                    "total_requests": int,
                    "tokens": {...},
                    "approx_cost": float | null,
                    "quota_groups": {...},  // For Antigravity
                    "credentials": [...]
                }
            },
            "summary": {...},
            "data_source": "cache",
            "timestamp": float
        }
    """
    try:
        stats = await client.get_quota_stats(provider_filter=provider)
        return stats
    except Exception as e:
        logging.error(f"Failed to get quota stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def refresh_quota_stats(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
):
    """
    Refresh quota and usage statistics.

    Request body:
        {
            "action": "reload" | "force_refresh",
            "scope": "all" | "provider" | "credential",
            "provider": "antigravity",                 // required if scope != "all"
            "credential": "antigravity_oauth_1.json"   // required if scope == "credential"
        }

    Actions:
        - reload: Re-read data from disk (no external API calls)
        - force_refresh: For Antigravity, fetch live quota from API.
          For other providers, same as reload.

    Returns:
        Same as GET, plus a "refresh_result" field with operation details.
    """
    try:
        data = await request.json()
        action = data.get("action", "reload")
        scope = data.get("scope", "all")
        provider = data.get("provider")
        credential = data.get("credential")

        # Validate parameters
        if action not in ("reload", "force_refresh"):
            raise HTTPException(
                status_code=400,
                detail="action must be 'reload' or 'force_refresh'",
            )
        if scope not in ("all", "provider", "credential"):
            raise HTTPException(
                status_code=400,
                detail="scope must be 'all', 'provider', or 'credential'",
            )
        if scope in ("provider", "credential") and not provider:
            raise HTTPException(
                status_code=400,
                detail="'provider' is required when scope is 'provider' or 'credential'",
            )
        if scope == "credential" and not credential:
            raise HTTPException(
                status_code=400,
                detail="'credential' is required when scope is 'credential'",
            )

        refresh_result = {
            "action": action,
            "scope": scope,
            "provider": provider,
            "credential": credential,
        }

        if action == "reload":
            # Just reload from disk
            start_time = time.time()
            await client.reload_usage_from_disk()
            refresh_result["duration_ms"] = int((time.time() - start_time) * 1000)
            refresh_result["success"] = True
            refresh_result["message"] = "Reloaded usage data from disk"
        elif action == "force_refresh":
            # Force refresh from external API (for supported providers like Antigravity)
            result = await client.force_refresh_quota(
                provider=provider if scope in ("provider", "credential") else None,
                credential=credential if scope == "credential" else None,
            )
            refresh_result.update(result)
            refresh_result["success"] = result["failed_count"] == 0

        # Get updated stats
        stats = await client.get_quota_stats(provider_filter=provider)
        stats["refresh_result"] = refresh_result
        stats["data_source"] = "refreshed"
        return stats
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Failed to refresh quota stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def token_count(
    request: Request,
    client: RotatingClient = Depends(get_rotating_client),
    _=Depends(verify_api_key),
):
    """
    Calculates the token count for a given list of messages and a model.
    """
    try:
        data = await request.json()
        model = data.get("model")
        messages = data.get("messages")
        if not model or not messages:
            raise HTTPException(
                status_code=400, detail="'model' and 'messages' are required."
            )
        count = client.token_count(**data)
        return {"token_count": count}
    except HTTPException:
        # Re-raise the 400 above instead of converting it into a 500
        raise
    except Exception as e:
        logging.error(f"Token count failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def cost_estimate(request: Request, _=Depends(verify_api_key)):
    """
    Estimates the cost for a request based on token counts and model pricing.

    Request body:
        {
            "model": "anthropic/claude-3-opus",
            "prompt_tokens": 1000,
            "completion_tokens": 500,
            "cache_read_tokens": 0,       # optional
            "cache_creation_tokens": 0    # optional
        }

    Returns:
        {
            "model": "anthropic/claude-3-opus",
            "cost": 0.0525,
            "currency": "USD",
            "pricing": {
                "input_cost_per_token": 0.000015,
                "output_cost_per_token": 0.000075
            },
            "source": "model_info_service"  # or "litellm_fallback"
        }
    """
    try:
        data = await request.json()
        model = data.get("model")
        prompt_tokens = data.get("prompt_tokens", 0)
        completion_tokens = data.get("completion_tokens", 0)
        cache_read_tokens = data.get("cache_read_tokens", 0)
        cache_creation_tokens = data.get("cache_creation_tokens", 0)

        if not model:
            raise HTTPException(status_code=400, detail="'model' is required.")

        result = {
            "model": model,
            "cost": None,
            "currency": "USD",
            "pricing": {},
            "source": None,
        }

        # Try the model info service first
        if hasattr(request.app.state, "model_info_service"):
            model_info_service = request.app.state.model_info_service
            if model_info_service.is_ready:
                cost = model_info_service.calculate_cost(
                    model,
                    prompt_tokens,
                    completion_tokens,
                    cache_read_tokens,
                    cache_creation_tokens,
                )
                if cost is not None:
                    cost_info = model_info_service.get_cost_info(model)
                    result["cost"] = cost
                    result["pricing"] = cost_info or {}
                    result["source"] = "model_info_service"
                    return result

        # Fallback to litellm
        try:
            import litellm

            # Look up per-token pricing from litellm's model registry
            model_info = litellm.get_model_info(model)
            input_cost = model_info.get("input_cost_per_token", 0)
            output_cost = model_info.get("output_cost_per_token", 0)
            if input_cost or output_cost:
                cost = (prompt_tokens * input_cost) + (completion_tokens * output_cost)
                result["cost"] = cost
                result["pricing"] = {
                    "input_cost_per_token": input_cost,
                    "output_cost_per_token": output_cost,
                }
                result["source"] = "litellm_fallback"
                return result
        except Exception:
            pass

        result["source"] = "unknown"
        result["error"] = "Pricing data not available for this model"
        return result
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Cost estimate failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
| if __name__ == "__main__": | |
| # Define ENV_FILE for onboarding checks using centralized path | |
| ENV_FILE = get_data_file(".env") | |
| # Check if launcher TUI should be shown (no arguments provided) | |
| if len(sys.argv) == 1: | |
| # No arguments - show launcher TUI (lazy import) | |
| from proxy_app.launcher_tui import run_launcher_tui | |
| run_launcher_tui() | |
| # Launcher modifies sys.argv and returns, or exits if user chose Exit | |
| # If we get here, user chose "Run Proxy" and sys.argv is modified | |
| # Re-parse arguments with modified sys.argv | |
| args = parser.parse_args() | |
| def needs_onboarding() -> bool: | |
| """ | |
| Check if the proxy needs onboarding (first-time setup). | |
| Returns True if onboarding is needed, False otherwise. | |
| """ | |
| # Only check if .env file exists | |
| # PROXY_API_KEY is optional (will show warning if not set) | |
| if not ENV_FILE.is_file(): | |
| return True | |
| return False | |
| def show_onboarding_message(): | |
| """Display clear explanatory message for why onboarding is needed.""" | |
| os.system( | |
| "cls" if os.name == "nt" else "clear" | |
| ) # Clear terminal for clean presentation | |
| console.print( | |
| Panel.fit( | |
| "[bold cyan]🚀 LLM API Key Proxy - First Time Setup[/bold cyan]", | |
| border_style="cyan", | |
| ) | |
| ) | |
| console.print("[bold yellow]⚠️ Configuration Required[/bold yellow]\n") | |
| console.print("The proxy needs initial configuration:") | |
| console.print(" [red]❌ No .env file found[/red]") | |
| console.print("\n[bold]Why this matters:[/bold]") | |
| console.print(" • The .env file stores your credentials and settings") | |
| console.print(" • PROXY_API_KEY protects your proxy from unauthorized access") | |
| console.print(" • Provider API keys enable LLM access") | |
| console.print("\n[bold]What happens next:[/bold]") | |
| console.print(" 1. We'll create a .env file with PROXY_API_KEY") | |
| console.print(" 2. You can add LLM provider credentials (API keys or OAuth)") | |
| console.print(" 3. The proxy will then start normally") | |
| console.print( | |
| "\n[bold yellow]⚠️ Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." | |
| ) | |
| console.print(" You can remove it later if you want an unsecured proxy.\n") | |
| console.input( | |
| "[bold green]Press Enter to launch the credential setup tool...[/bold green]" | |
| ) | |
| # Check if user explicitly wants to add credentials | |
| if args.add_credential: | |
| # Import and call ensure_env_defaults to create .env and PROXY_API_KEY if needed | |
| from rotator_library.credential_tool import ensure_env_defaults | |
| ensure_env_defaults() | |
| # Reload environment variables after ensure_env_defaults creates/updates .env | |
| load_dotenv(ENV_FILE, override=True) | |
| run_credential_tool() | |
| else: | |
| # Check if onboarding is needed | |
| if needs_onboarding(): | |
| # Import console from rich for better messaging | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| console = Console() | |
| # Show clear explanatory message | |
| show_onboarding_message() | |
| # Launch credential tool automatically | |
| from rotator_library.credential_tool import ensure_env_defaults | |
| ensure_env_defaults() | |
| load_dotenv(ENV_FILE, override=True) | |
| run_credential_tool() | |
| # After credential tool exits, reload and re-check | |
| load_dotenv(ENV_FILE, override=True) | |
| # Re-read PROXY_API_KEY from environment | |
| PROXY_API_KEY = os.getenv("PROXY_API_KEY") | |
| # Verify onboarding is complete | |
| if needs_onboarding(): | |
| console.print("\n[bold red]❌ Configuration incomplete.[/bold red]") | |
| console.print( | |
| "The proxy still cannot start. Please ensure PROXY_API_KEY is set in .env\n" | |
| ) | |
| sys.exit(1) | |
| else: | |
| console.print("\n[bold green]✅ Configuration complete![/bold green]") | |
| console.print("\nStarting proxy server...\n") | |
| import uvicorn | |
| uvicorn.run(app, host=args.host, port=args.port) | |