# Phase 1: minimal standard-library imports needed for argument parsing and the
# TUI launcher. Heavy dependencies (FastAPI, LiteLLM, provider plugins) are
# imported further down, only once it is known that the proxy itself will run.
import argparse
import asyncio
import logging
import os
import sys
import time
from pathlib import Path

# Add the 'src' directory (the parent of this file's package directory) to the
# import path so that `proxy_app` and `rotator_library` can be imported.
_src_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(_src_dir))

# --- Argument Parsing (BEFORE heavy imports) ---
parser = argparse.ArgumentParser(description="API Key Proxy Server")
for _flag, _spec in (
    ("--host", dict(type=str, default="0.0.0.0", help="Host to bind the server to.")),
    ("--port", dict(type=int, default=8000, help="Port to run the server on.")),
    (
        "--enable-request-logging",
        dict(action="store_true", help="Enable request logging."),
    ),
    (
        "--add-credential",
        dict(
            action="store_true",
            help="Launch the interactive tool to add a new OAuth credential.",
        ),
    ),
):
    parser.add_argument(_flag, **_spec)

# Unrecognised arguments are tolerated here (they land in `_`); only the TUI
# path below re-parses strictly.
args, _ = parser.parse_known_args()

# Launched with no arguments at all -> interactive launcher (TUI) mode. Only the
# lightweight launcher module is imported on this fast path.
if len(sys.argv) == 1:
    from proxy_app.launcher_tui import run_launcher_tui

    # The launcher either exits (user chose Exit) or returns after rewriting
    # sys.argv for "Run Proxy"; the rewritten argv is then parsed again.
    run_launcher_tui()
    args = parser.parse_args()

# Credential-tool mode also avoids the heavy proxy imports.
if args.add_credential:
    from rotator_library.credential_tool import run_credential_tool

    run_credential_tool()
    sys.exit(0)

# Past this point the proxy is really starting: begin timing the startup.
_start_time = time.time()

# Load all .env files from root folder (main .env first, then any additional *.env files)
from dotenv import load_dotenv
from glob import glob

# Application root: the executable's folder for a frozen build, otherwise the
# current working directory. Computed inline (not via rotator_library) so no
# heavy imports are triggered before the loading screen.
_root_dir = (
    Path(sys.executable).parent if getattr(sys, "frozen", False) else Path.cwd()
)

# Load main .env first
load_dotenv(_root_dir / ".env")

# Load any additional .env files (e.g., antigravity_all_combined.env,
# gemini_cli_all_combined.env). The folder is scanned once, and the sorted list
# drives both loading and the summary printed below, so the printed order is
# the load order. The main ".env" can match the pattern too; the loop skips it
# because it was already loaded above.
_env_files_found = sorted(_root_dir.glob("*.env"))
for _env_file in _env_files_found:
    if _env_file.name != ".env":  # Skip main .env (already loaded)
        load_dotenv(_env_file, override=False)  # Don't override existing values

# Log discovered .env files for deployment verification
if _env_files_found:
    _env_names = [_ef.name for _ef in _env_files_found]
    print(f"📁 Loaded {len(_env_files_found)} .env file(s): {', '.join(_env_names)}")

# Get proxy API key for display.
# NOTE(review): the full key is echoed to stdout; if stdout is collected by a
# log aggregator, consider masking it.
proxy_api_key = os.getenv("PROXY_API_KEY")
if proxy_api_key:
    key_display = f"✓ {proxy_api_key}"
else:
    key_display = "✗ Not Set (INSECURE - anyone can access!)"


def _print_startup_header():
    """Print the startup banner: bind address, proxy key status, project link."""
    print("━" * 70)
    print(f"Starting proxy on {args.host}:{args.port}")
    print(f"Proxy API Key: {key_display}")
    print("GitHub: https://github.com/Mirrowel/LLM-API-Key-Proxy")
    print("━" * 70)


_print_startup_header()
print("Loading server components...")

# Phase 2: Load Rich for loading spinner (lightweight)
from rich.console import Console

_console = Console()

# Phase 3: Heavy dependencies with granular loading messages
print(" → Loading FastAPI framework...")
with _console.status("[dim]Loading FastAPI framework...", spinner="dots"):
    from contextlib import asynccontextmanager
    from fastapi import FastAPI, Request, HTTPException, Depends
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import StreamingResponse
    from fastapi.security import APIKeyHeader

print(" → Loading core dependencies...")
with _console.status("[dim]Loading core dependencies...", spinner="dots"):
    from dotenv import load_dotenv
    import colorlog
    import json
    from typing import AsyncGenerator, Any, List, Optional, Union
    from pydantic import BaseModel, ConfigDict, Field

# --- Early Log Level Configuration ---
logging.getLogger("LiteLLM").setLevel(logging.WARNING)

print(" → Loading LiteLLM library...")
with _console.status("[dim]Loading LiteLLM library...", spinner="dots"):
    import litellm

# Phase 4: Application imports with granular loading messages
print(" → Initializing proxy core...")
with _console.status("[dim]Initializing proxy core...", spinner="dots"):
    from rotator_library import RotatingClient
    from rotator_library.credential_manager import CredentialManager
    from rotator_library.background_refresher import BackgroundRefresher
    from rotator_library.model_info_service import init_model_info_service
    from proxy_app.request_logger import log_request_to_console
    from proxy_app.batch_manager import EmbeddingBatcher
    from proxy_app.detailed_logger import DetailedLogger

print(" → Discovering provider plugins...")
# Provider lazy loading happens during import, so time it here
_provider_start = time.time()
with _console.status("[dim]Discovering provider plugins...", spinner="dots"):
    from rotator_library import (
        PROVIDER_PLUGINS,
    )  # This triggers lazy load via __getattr__
_provider_time = time.time() - _provider_start

# Get count after import (without timing to avoid double-counting)
_plugin_count = len(PROVIDER_PLUGINS)


# --- Pydantic Models ---
class EmbeddingRequest(BaseModel):
    """Request body accepted by the /v1/embeddings endpoint."""

    model: str
    input: Union[str, List[str]]
    input_type: Optional[str] = None
    dimensions: Optional[int] = None
    user: Optional[str] = None


class ModelCard(BaseModel):
    """Basic model card for minimal response."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "Mirro-Proxy"


class ModelCapabilities(BaseModel):
    """Model capability flags."""

    tool_choice: bool = False
    function_calling: bool = False
    reasoning: bool = False
    vision: bool = False
    system_messages: bool = True
    prompt_caching: bool = False
    assistant_prefill: bool = False


class EnrichedModelCard(BaseModel):
    """Extended model card with pricing and capabilities."""

    # Allow extra fields from the service (pydantic v2 config style; the
    # class-based `Config` is deprecated in v2).
    model_config = ConfigDict(extra="allow")

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "unknown"
    # Pricing (optional - may not be available for all models)
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None
    cache_read_input_token_cost: Optional[float] = None
    cache_creation_input_token_cost: Optional[float] = None
    # Limits (optional)
    max_input_tokens: Optional[int] = None
    max_output_tokens: Optional[int] = None
    context_window: Optional[int] = None
    # Capabilities
    mode: str = "chat"
    supported_modalities: List[str] = Field(default_factory=lambda: ["text"])
    supported_output_modalities: List[str] = Field(default_factory=lambda: ["text"])
    capabilities: Optional[ModelCapabilities] = None
    # Debug info (optional).
    # NOTE(review): in pydantic v2, underscore-prefixed annotations become
    # private attributes rather than fields, so these are not serialized by
    # model_dump() - confirm that is intended.
    _sources: Optional[List[str]] = None
    _match_type: Optional[str] = None


class ModelList(BaseModel):
    """List of models response."""

    object: str = "list"
    data: List[ModelCard]


class EnrichedModelList(BaseModel):
    """List of enriched models with pricing and capabilities."""

    object: str = "list"
    data: List[EnrichedModelCard]


# Calculate total loading time
_elapsed = time.time() - _start_time
_ready_message = (
    f"✓ Server ready in {_elapsed:.2f}s "
    f"({_plugin_count} providers discovered in {_provider_time:.2f}s)"
)
print(_ready_message)

# Clear screen and reprint header for clean startup view.
# This pushes loading messages up (still in scroll history) but shows a clean
# final screen. It is only done on an interactive terminal: when stdout is
# redirected (container logs, service manager, file) there is nothing to clear,
# "clear" can complain about an unset TERM, and the reprint would just
# duplicate lines in the log. sys.stdout can be None in windowless builds.
import os as _os_module

_stdout_is_tty = sys.stdout is not None and sys.stdout.isatty()
if _stdout_is_tty:
    _os_module.system("cls" if _os_module.name == "nt" else "clear")
    _print_startup_header()
    print(_ready_message)

# Note: Debug logging will be added after logging configuration below

# --- Logging Configuration ---
# Import path utilities here (after loading screen) to avoid triggering heavy imports early
from rotator_library.utils.paths import get_logs_dir, get_data_file

LOG_DIR = get_logs_dir(_root_dir)

# Configure a file handler for INFO-level logs and higher
info_file_handler = logging.FileHandler(LOG_DIR / "proxy.log", encoding="utf-8")
info_file_handler.setLevel(logging.INFO)
info_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)

# Configure a dedicated file handler for DEBUG-level logs
debug_file_handler = logging.FileHandler(LOG_DIR / "proxy_debug.log", encoding="utf-8")
debug_file_handler.setLevel(logging.DEBUG)
debug_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)


class RotatorDebugFilter(logging.Filter):
    """Accept only DEBUG-level records from `rotator_library*` loggers.

    Attached to the debug file handler so proxy_debug.log receives exactly the
    rotator's debug trace and nothing else.
    """

    def filter(self, record):
        return record.levelno == logging.DEBUG and record.name.startswith(
            "rotator_library"
        )


debug_file_handler.addFilter(RotatorDebugFilter())

# Configure a colored console handler (INFO and above only, no DEBUG).
# Built exactly once; `console_handler` and `formatter` stay module-level names.
console_handler = colorlog.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
formatter = colorlog.ColoredFormatter(
    "%(log_color)s%(message)s",
    log_colors={
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    },
)
console_handler.setFormatter(formatter)


class NoLiteLLMLogFilter(logging.Filter):
    """Reject records from `LiteLLM*` loggers so they never clutter the console."""

    def filter(self, record):
        return not record.name.startswith("LiteLLM")


console_handler.addFilter(NoLiteLLMLogFilter())

# Get the root logger and set it to DEBUG to capture all messages
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG) # Add all handlers to the root logger root_logger.addHandler(info_file_handler) root_logger.addHandler(console_handler) root_logger.addHandler(debug_file_handler) # Silence other noisy loggers by setting their level higher than root logging.getLogger("uvicorn").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING) # Isolate LiteLLM's logger to prevent it from reaching the console. # We will capture its logs via the logger_fn callback in the client instead. litellm_logger = logging.getLogger("LiteLLM") litellm_logger.handlers = [] litellm_logger.propagate = False # Now that logging is configured, log the module load time to debug file only logging.debug(f"Modules loaded in {_elapsed:.2f}s") # Load environment variables from .env file load_dotenv(_root_dir / ".env") # --- Configuration --- USE_EMBEDDING_BATCHER = False ENABLE_REQUEST_LOGGING = args.enable_request_logging if ENABLE_REQUEST_LOGGING: logging.info("Request logging is enabled.") PROXY_API_KEY = os.getenv("PROXY_API_KEY") # Note: PROXY_API_KEY validation moved to server startup to allow credential tool to run first # Discover API keys from environment variables api_keys = {} for key, value in os.environ.items(): if "_API_KEY" in key and key != "PROXY_API_KEY": provider = key.split("_API_KEY")[0].lower() if provider not in api_keys: api_keys[provider] = [] api_keys[provider].append(value) # Load model ignore lists from environment variables ignore_models = {} for key, value in os.environ.items(): if key.startswith("IGNORE_MODELS_"): provider = key.replace("IGNORE_MODELS_", "").lower() models_to_ignore = [ model.strip() for model in value.split(",") if model.strip() ] ignore_models[provider] = models_to_ignore logging.debug( f"Loaded ignore list for provider '{provider}': {models_to_ignore}" ) # Load model whitelist from environment variables whitelist_models = {} for key, value in os.environ.items(): if key.startswith("WHITELIST_MODELS_"): 
provider = key.replace("WHITELIST_MODELS_", "").lower() models_to_whitelist = [ model.strip() for model in value.split(",") if model.strip() ] whitelist_models[provider] = models_to_whitelist logging.debug( f"Loaded whitelist for provider '{provider}': {models_to_whitelist}" ) # Load max concurrent requests per key from environment variables max_concurrent_requests_per_key = {} for key, value in os.environ.items(): if key.startswith("MAX_CONCURRENT_REQUESTS_PER_KEY_"): provider = key.replace("MAX_CONCURRENT_REQUESTS_PER_KEY_", "").lower() try: max_concurrent = int(value) if max_concurrent < 1: logging.warning( f"Invalid max_concurrent value for provider '{provider}': {value}. Must be >= 1. Using default (1)." ) max_concurrent = 1 max_concurrent_requests_per_key[provider] = max_concurrent logging.debug( f"Loaded max concurrent requests for provider '{provider}': {max_concurrent}" ) except ValueError: logging.warning( f"Invalid max_concurrent value for provider '{provider}': {value}. Using default (1)." ) # --- Lifespan Management --- @asynccontextmanager async def lifespan(app: FastAPI): """Manage the RotatingClient's lifecycle with the app's lifespan.""" # [MODIFIED] Perform skippable OAuth initialization at startup skip_oauth_init = os.getenv("SKIP_OAUTH_INIT_CHECK", "false").lower() == "true" # The CredentialManager now handles all discovery, including .env overrides. # We pass all environment variables to it for this purpose. 
cred_manager = CredentialManager(os.environ) oauth_credentials = cred_manager.discover_and_prepare() if not skip_oauth_init and oauth_credentials: logging.info("Starting OAuth credential validation and deduplication...") processed_emails = {} # email -> {provider: path} credentials_to_initialize = {} # provider -> [paths] final_oauth_credentials = {} # --- Pass 1: Pre-initialization Scan & Deduplication --- # logging.info("Pass 1: Scanning for existing metadata to find duplicates...") for provider, paths in oauth_credentials.items(): if provider not in credentials_to_initialize: credentials_to_initialize[provider] = [] for path in paths: # Skip env-based credentials (virtual paths) - they don't have metadata files if path.startswith("env://"): credentials_to_initialize[provider].append(path) continue try: with open(path, "r") as f: data = json.load(f) metadata = data.get("_proxy_metadata", {}) email = metadata.get("email") if email: if email not in processed_emails: processed_emails[email] = {} if provider in processed_emails[email]: original_path = processed_emails[email][provider] logging.warning( f"Duplicate for '{email}' on '{provider}' found in pre-scan: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping." ) continue else: processed_emails[email][provider] = path credentials_to_initialize[provider].append(path) except (FileNotFoundError, json.JSONDecodeError) as e: logging.warning( f"Could not pre-read metadata from '{path}': {e}. Will process during initialization." 
) credentials_to_initialize[provider].append(path) # --- Pass 2: Parallel Initialization of Filtered Credentials --- # logging.info("Pass 2: Initializing unique credentials and performing final check...") async def process_credential(provider: str, path: str, provider_instance): """Process a single credential: initialize and fetch user info.""" try: await provider_instance.initialize_token(path) if not hasattr(provider_instance, "get_user_info"): return (provider, path, None, None) user_info = await provider_instance.get_user_info(path) email = user_info.get("email") return (provider, path, email, None) except Exception as e: logging.error( f"Failed to process OAuth token for {provider} at '{path}': {e}" ) return (provider, path, None, e) # Collect all tasks for parallel execution tasks = [] for provider, paths in credentials_to_initialize.items(): if not paths: continue provider_plugin_class = PROVIDER_PLUGINS.get(provider) if not provider_plugin_class: continue provider_instance = provider_plugin_class() for path in paths: tasks.append(process_credential(provider, path, provider_instance)) # Execute all credential processing tasks in parallel results = await asyncio.gather(*tasks, return_exceptions=True) # --- Pass 3: Sequential Deduplication and Final Assembly --- for result in results: # Handle exceptions from gather if isinstance(result, Exception): logging.error(f"Credential processing raised exception: {result}") continue provider, path, email, error = result # Skip if there was an error if error: continue # If provider doesn't support get_user_info, add directly if email is None: if provider not in final_oauth_credentials: final_oauth_credentials[provider] = [] final_oauth_credentials[provider].append(path) continue # Handle empty email if not email: logging.warning( f"Could not retrieve email for '{path}'. Treating as unique." 
) if provider not in final_oauth_credentials: final_oauth_credentials[provider] = [] final_oauth_credentials[provider].append(path) continue # Deduplication check if email not in processed_emails: processed_emails[email] = {} if ( provider in processed_emails[email] and processed_emails[email][provider] != path ): original_path = processed_emails[email][provider] logging.warning( f"Duplicate for '{email}' on '{provider}' found post-init: '{Path(path).name}'. Original: '{Path(original_path).name}'. Skipping." ) continue else: processed_emails[email][provider] = path if provider not in final_oauth_credentials: final_oauth_credentials[provider] = [] final_oauth_credentials[provider].append(path) # Update metadata (skip for env-based credentials - they don't have files) if not path.startswith("env://"): try: with open(path, "r+") as f: data = json.load(f) metadata = data.get("_proxy_metadata", {}) metadata["email"] = email metadata["last_check_timestamp"] = time.time() data["_proxy_metadata"] = metadata f.seek(0) json.dump(data, f, indent=2) f.truncate() except Exception as e: logging.error(f"Failed to update metadata for '{path}': {e}") logging.info("OAuth credential processing complete.") oauth_credentials = final_oauth_credentials # [NEW] Load provider-specific params litellm_provider_params = { "gemini_cli": {"project_id": os.getenv("GEMINI_CLI_PROJECT_ID")} } # The client now uses the root logger configuration client = RotatingClient( api_keys=api_keys, oauth_credentials=oauth_credentials, # Pass OAuth config configure_logging=True, litellm_provider_params=litellm_provider_params, ignore_models=ignore_models, whitelist_models=whitelist_models, enable_request_logging=ENABLE_REQUEST_LOGGING, max_concurrent_requests_per_key=max_concurrent_requests_per_key, ) # Log loaded credentials summary (compact, always visible for deployment verification) # _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none" # _oauth_summary = ', 
'.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none" # _total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()]) # print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})") client.background_refresher.start() # Start the background task app.state.rotating_client = client # Warn if no provider credentials are configured if not client.all_credentials: logging.warning("=" * 70) logging.warning("⚠️ NO PROVIDER CREDENTIALS CONFIGURED") logging.warning("The proxy is running but cannot serve any LLM requests.") logging.warning( "Launch the credential tool to add API keys or OAuth credentials." ) logging.warning(" • Executable: Run with --add-credential flag") logging.warning(" • Source: python src/proxy_app/main.py --add-credential") logging.warning("=" * 70) os.environ["LITELLM_LOG"] = "ERROR" litellm.set_verbose = False litellm.drop_params = True if USE_EMBEDDING_BATCHER: batcher = EmbeddingBatcher(client=client) app.state.embedding_batcher = batcher logging.info("RotatingClient and EmbeddingBatcher initialized.") else: app.state.embedding_batcher = None logging.info("RotatingClient initialized (EmbeddingBatcher disabled).") # Start model info service in background (fetches pricing/capabilities data) # This runs asynchronously and doesn't block proxy startup model_info_service = await init_model_info_service() app.state.model_info_service = model_info_service logging.info("Model info service started (fetching pricing data in background).") yield await client.background_refresher.stop() # Stop the background task on shutdown if app.state.embedding_batcher: await app.state.embedding_batcher.stop() await client.close() # Stop model info service if hasattr(app.state, "model_info_service") and app.state.model_info_service: await app.state.model_info_service.stop() if app.state.embedding_batcher: logging.info("RotatingClient and EmbeddingBatcher closed.") else: 
logging.info("RotatingClient closed.") # --- FastAPI App Setup --- app = FastAPI(lifespan=lifespan) # Add CORS middleware to allow all origins, methods, and headers app.add_middleware( CORSMiddleware, allow_origins=["*"], # Allows all origins allow_credentials=True, allow_methods=["*"], # Allows all methods allow_headers=["*"], # Allows all headers ) api_key_header = APIKeyHeader(name="Authorization", auto_error=False) def get_rotating_client(request: Request) -> RotatingClient: """Dependency to get the rotating client instance from the app state.""" return request.app.state.rotating_client def get_embedding_batcher(request: Request) -> EmbeddingBatcher: """Dependency to get the embedding batcher instance from the app state.""" return request.app.state.embedding_batcher async def verify_api_key(auth: str = Depends(api_key_header)): """Dependency to verify the proxy API key.""" # If PROXY_API_KEY is not set or empty, skip verification (open access) if not PROXY_API_KEY: return auth if not auth or auth != f"Bearer {PROXY_API_KEY}": raise HTTPException(status_code=401, detail="Invalid or missing API Key") return auth async def streaming_response_wrapper( request: Request, request_data: dict, response_stream: AsyncGenerator[str, None], logger: Optional[DetailedLogger] = None, ) -> AsyncGenerator[str, None]: """ Wraps a streaming response to log the full response after completion and ensures any errors during the stream are sent to the client. 
""" response_chunks = [] full_response = {} try: async for chunk_str in response_stream: if await request.is_disconnected(): logging.warning("Client disconnected, stopping stream.") break yield chunk_str if chunk_str.strip() and chunk_str.startswith("data:"): content = chunk_str[len("data:") :].strip() if content != "[DONE]": try: chunk_data = json.loads(content) response_chunks.append(chunk_data) if logger: logger.log_stream_chunk(chunk_data) except json.JSONDecodeError: pass except Exception as e: logging.error(f"An error occurred during the response stream: {e}") # Yield a final error message to the client to ensure they are not left hanging. error_payload = { "error": { "message": f"An unexpected error occurred during the stream: {str(e)}", "type": "proxy_internal_error", "code": 500, } } yield f"data: {json.dumps(error_payload)}\n\n" yield "data: [DONE]\n\n" # Also log this as a failed request if logger: logger.log_final_response( status_code=500, headers=None, body={"error": str(e)} ) return # Stop further processing finally: if response_chunks: # --- Aggregation Logic --- final_message = {"role": "assistant"} aggregated_tool_calls = {} usage_data = None finish_reason = None for chunk in response_chunks: if "choices" in chunk and chunk["choices"]: choice = chunk["choices"][0] delta = choice.get("delta", {}) # Dynamically aggregate all fields from the delta for key, value in delta.items(): if value is None: continue if key == "content": if "content" not in final_message: final_message["content"] = "" if value: final_message["content"] += value elif key == "tool_calls": for tc_chunk in value: index = tc_chunk["index"] if index not in aggregated_tool_calls: aggregated_tool_calls[index] = { "type": "function", "function": {"name": "", "arguments": ""}, } # Ensure 'function' key exists for this index before accessing its sub-keys if "function" not in aggregated_tool_calls[index]: aggregated_tool_calls[index]["function"] = { "name": "", "arguments": "", } if 
tc_chunk.get("id"): aggregated_tool_calls[index]["id"] = tc_chunk["id"] if "function" in tc_chunk: if "name" in tc_chunk["function"]: if tc_chunk["function"]["name"] is not None: aggregated_tool_calls[index]["function"][ "name" ] += tc_chunk["function"]["name"] if "arguments" in tc_chunk["function"]: if ( tc_chunk["function"]["arguments"] is not None ): aggregated_tool_calls[index]["function"][ "arguments" ] += tc_chunk["function"]["arguments"] elif key == "function_call": if "function_call" not in final_message: final_message["function_call"] = { "name": "", "arguments": "", } if "name" in value: if value["name"] is not None: final_message["function_call"]["name"] += value[ "name" ] if "arguments" in value: if value["arguments"] is not None: final_message["function_call"]["arguments"] += ( value["arguments"] ) else: # Generic key handling for other data like 'reasoning' # FIX: Role should always replace, never concatenate if key == "role": final_message[key] = value elif key not in final_message: final_message[key] = value elif isinstance(final_message.get(key), str): final_message[key] += value else: final_message[key] = value if "finish_reason" in choice and choice["finish_reason"]: finish_reason = choice["finish_reason"] if "usage" in chunk and chunk["usage"]: usage_data = chunk["usage"] # --- Final Response Construction --- if aggregated_tool_calls: final_message["tool_calls"] = list(aggregated_tool_calls.values()) # CRITICAL FIX: Override finish_reason when tool_calls exist # This ensures OpenCode and other agentic systems continue the conversation loop finish_reason = "tool_calls" # Ensure standard fields are present for consistent logging for field in ["content", "tool_calls", "function_call"]: if field not in final_message: final_message[field] = None first_chunk = response_chunks[0] final_choice = { "index": 0, "message": final_message, "finish_reason": finish_reason, } full_response = { "id": first_chunk.get("id"), "object": "chat.completion", "created": 
first_chunk.get("created"), "model": first_chunk.get("model"), "choices": [final_choice], "usage": usage_data, } if logger: logger.log_final_response( status_code=200, headers=None, # Headers are not available at this stage body=full_response, ) @app.post("/v1/chat/completions") async def chat_completions( request: Request, client: RotatingClient = Depends(get_rotating_client), _=Depends(verify_api_key), ): """ OpenAI-compatible endpoint powered by the RotatingClient. Handles both streaming and non-streaming responses and logs them. """ logger = DetailedLogger() if ENABLE_REQUEST_LOGGING else None try: # Read and parse the request body only once at the beginning. try: request_data = await request.json() except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid JSON in request body.") # Global temperature=0 override (controlled by .env variable, default: OFF) # Low temperature makes models deterministic and prone to following training data # instead of actual schemas, which can cause tool hallucination # Modes: "remove" = delete temperature key, "set" = change to 1.0, "false" = disabled override_temp_zero = os.getenv("OVERRIDE_TEMPERATURE_ZERO", "false").lower() if ( override_temp_zero in ("remove", "set", "true", "1", "yes") and "temperature" in request_data and request_data["temperature"] == 0 ): if override_temp_zero == "remove": # Remove temperature key entirely del request_data["temperature"] logging.debug( "OVERRIDE_TEMPERATURE_ZERO=remove: Removed temperature=0 from request" ) else: # Set to 1.0 (for "set", "true", "1", "yes") request_data["temperature"] = 1.0 logging.debug( "OVERRIDE_TEMPERATURE_ZERO=set: Converting temperature=0 to temperature=1.0" ) # If logging is enabled, perform all logging operations using the parsed data. if logger: logger.log_request(headers=request.headers, body=request_data) # Extract and log specific reasoning parameters for monitoring. 
model = request_data.get("model") generation_cfg = ( request_data.get("generationConfig", {}) or request_data.get("generation_config", {}) or {} ) reasoning_effort = request_data.get("reasoning_effort") or generation_cfg.get( "reasoning_effort" ) custom_reasoning_budget = request_data.get( "custom_reasoning_budget" ) or generation_cfg.get("custom_reasoning_budget", False) logging.getLogger("rotator_library").debug( f"Handling reasoning parameters: model={model}, reasoning_effort={reasoning_effort}, custom_reasoning_budget={custom_reasoning_budget}" ) # Log basic request info to console (this is a separate, simpler logger). log_request_to_console( url=str(request.url), headers=dict(request.headers), client_info=(request.client.host, request.client.port), request_data=request_data, ) is_streaming = request_data.get("stream", False) if is_streaming: response_generator = client.acompletion(request=request, **request_data) return StreamingResponse( streaming_response_wrapper( request, request_data, response_generator, logger ), media_type="text/event-stream", ) else: response = await client.acompletion(request=request, **request_data) if logger: # Assuming response has status_code and headers attributes # This might need adjustment based on the actual response object response_headers = ( response.headers if hasattr(response, "headers") else None ) status_code = ( response.status_code if hasattr(response, "status_code") else 200 ) logger.log_final_response( status_code=status_code, headers=response_headers, body=response.model_dump(), ) return response except ( litellm.InvalidRequestError, ValueError, litellm.ContextWindowExceededError, ) as e: raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}") except litellm.AuthenticationError as e: raise HTTPException(status_code=401, detail=f"Authentication Error: {str(e)}") except litellm.RateLimitError as e: raise HTTPException(status_code=429, detail=f"Rate Limit Exceeded: {str(e)}") except 
(litellm.ServiceUnavailableError, litellm.APIConnectionError) as e: raise HTTPException(status_code=503, detail=f"Service Unavailable: {str(e)}") except litellm.Timeout as e: raise HTTPException(status_code=504, detail=f"Gateway Timeout: {str(e)}") except (litellm.InternalServerError, litellm.OpenAIError) as e: raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}") except Exception as e: logging.error(f"Request failed after all retries: {e}") # Optionally log the failed request if ENABLE_REQUEST_LOGGING: try: request_data = await request.json() except json.JSONDecodeError: request_data = {"error": "Could not parse request body"} if logger: logger.log_final_response( status_code=500, headers=None, body={"error": str(e)} ) raise HTTPException(status_code=500, detail=str(e)) @app.post("/v1/embeddings") async def embeddings( request: Request, body: EmbeddingRequest, client: RotatingClient = Depends(get_rotating_client), batcher: Optional[EmbeddingBatcher] = Depends(get_embedding_batcher), _=Depends(verify_api_key), ): """ OpenAI-compatible endpoint for creating embeddings. Supports two modes based on the USE_EMBEDDING_BATCHER flag: - True: Uses a server-side batcher for high throughput. - False: Passes requests directly to the provider. 
""" try: request_data = body.model_dump(exclude_none=True) log_request_to_console( url=str(request.url), headers=dict(request.headers), client_info=(request.client.host, request.client.port), request_data=request_data, ) if USE_EMBEDDING_BATCHER and batcher: # --- Server-Side Batching Logic --- request_data = body.model_dump(exclude_none=True) inputs = request_data.get("input", []) if isinstance(inputs, str): inputs = [inputs] tasks = [] for single_input in inputs: individual_request = request_data.copy() individual_request["input"] = single_input tasks.append(batcher.add_request(individual_request)) results = await asyncio.gather(*tasks) all_data = [] total_prompt_tokens = 0 total_tokens = 0 for i, result in enumerate(results): result["data"][0]["index"] = i all_data.extend(result["data"]) total_prompt_tokens += result["usage"]["prompt_tokens"] total_tokens += result["usage"]["total_tokens"] final_response_data = { "object": "list", "model": results[0]["model"], "data": all_data, "usage": { "prompt_tokens": total_prompt_tokens, "total_tokens": total_tokens, }, } response = litellm.EmbeddingResponse(**final_response_data) else: # --- Direct Pass-Through Logic --- request_data = body.model_dump(exclude_none=True) if isinstance(request_data.get("input"), str): request_data["input"] = [request_data["input"]] response = await client.aembedding(request=request, **request_data) return response except HTTPException as e: # Re-raise HTTPException to ensure it's not caught by the generic Exception handler raise e except ( litellm.InvalidRequestError, ValueError, litellm.ContextWindowExceededError, ) as e: raise HTTPException(status_code=400, detail=f"Invalid Request: {str(e)}") except litellm.AuthenticationError as e: raise HTTPException(status_code=401, detail=f"Authentication Error: {str(e)}") except litellm.RateLimitError as e: raise HTTPException(status_code=429, detail=f"Rate Limit Exceeded: {str(e)}") except (litellm.ServiceUnavailableError, 
litellm.APIConnectionError) as e: raise HTTPException(status_code=503, detail=f"Service Unavailable: {str(e)}") except litellm.Timeout as e: raise HTTPException(status_code=504, detail=f"Gateway Timeout: {str(e)}") except (litellm.InternalServerError, litellm.OpenAIError) as e: raise HTTPException(status_code=502, detail=f"Bad Gateway: {str(e)}") except Exception as e: logging.error(f"Embedding request failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/") def read_root(): return {"Status": "API Key Proxy is running"} @app.get("/v1/models") async def list_models( request: Request, client: RotatingClient = Depends(get_rotating_client), _=Depends(verify_api_key), enriched: bool = True, ): """ Returns a list of available models in the OpenAI-compatible format. Query Parameters: enriched: If True (default), returns detailed model info with pricing and capabilities. If False, returns minimal OpenAI-compatible response. """ model_ids = await client.get_all_available_models(grouped=False) if enriched and hasattr(request.app.state, "model_info_service"): model_info_service = request.app.state.model_info_service if model_info_service.is_ready: # Return enriched model data enriched_data = model_info_service.enrich_model_list(model_ids) return {"object": "list", "data": enriched_data} # Fallback to basic model cards model_cards = [ { "id": model_id, "object": "model", "created": int(time.time()), "owned_by": "Mirro-Proxy", } for model_id in model_ids ] return {"object": "list", "data": model_cards} @app.get("/v1/models/{model_id:path}") async def get_model( model_id: str, request: Request, _=Depends(verify_api_key), ): """ Returns detailed information about a specific model. 
Path Parameters: model_id: The model ID (e.g., "anthropic/claude-3-opus", "openrouter/openai/gpt-4") """ if hasattr(request.app.state, "model_info_service"): model_info_service = request.app.state.model_info_service if model_info_service.is_ready: info = model_info_service.get_model_info(model_id) if info: return info.to_dict() # Return basic info if service not ready or model not found return { "id": model_id, "object": "model", "created": int(time.time()), "owned_by": model_id.split("/")[0] if "/" in model_id else "unknown", } @app.get("/v1/model-info/stats") async def model_info_stats( request: Request, _=Depends(verify_api_key), ): """ Returns statistics about the model info service (for monitoring/debugging). """ if hasattr(request.app.state, "model_info_service"): return request.app.state.model_info_service.get_stats() return {"error": "Model info service not initialized"} @app.get("/v1/providers") async def list_providers(_=Depends(verify_api_key)): """ Returns a list of all available providers. """ return list(PROVIDER_PLUGINS.keys()) @app.get("/v1/quota-stats") async def get_quota_stats( request: Request, client: RotatingClient = Depends(get_rotating_client), _=Depends(verify_api_key), provider: str = None, ): """ Returns quota and usage statistics for all credentials. This returns cached data from the proxy without making external API calls. Use POST to reload from disk or force refresh from external APIs. Query Parameters: provider: Optional filter to return stats for a specific provider only Returns: { "providers": { "provider_name": { "credential_count": int, "active_count": int, "on_cooldown_count": int, "exhausted_count": int, "total_requests": int, "tokens": {...}, "approx_cost": float | null, "quota_groups": {...}, // For Antigravity "credentials": [...] 
} }, "summary": {...}, "data_source": "cache", "timestamp": float } """ try: stats = await client.get_quota_stats(provider_filter=provider) return stats except Exception as e: logging.error(f"Failed to get quota stats: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/v1/quota-stats") async def refresh_quota_stats( request: Request, client: RotatingClient = Depends(get_rotating_client), _=Depends(verify_api_key), ): """ Refresh quota and usage statistics. Request body: { "action": "reload" | "force_refresh", "scope": "all" | "provider" | "credential", "provider": "antigravity", // required if scope != "all" "credential": "antigravity_oauth_1.json" // required if scope == "credential" } Actions: - reload: Re-read data from disk (no external API calls) - force_refresh: For Antigravity, fetch live quota from API. For other providers, same as reload. Returns: Same as GET, plus a "refresh_result" field with operation details. """ try: data = await request.json() action = data.get("action", "reload") scope = data.get("scope", "all") provider = data.get("provider") credential = data.get("credential") # Validate parameters if action not in ("reload", "force_refresh"): raise HTTPException( status_code=400, detail="action must be 'reload' or 'force_refresh'", ) if scope not in ("all", "provider", "credential"): raise HTTPException( status_code=400, detail="scope must be 'all', 'provider', or 'credential'", ) if scope in ("provider", "credential") and not provider: raise HTTPException( status_code=400, detail="'provider' is required when scope is 'provider' or 'credential'", ) if scope == "credential" and not credential: raise HTTPException( status_code=400, detail="'credential' is required when scope is 'credential'", ) refresh_result = { "action": action, "scope": scope, "provider": provider, "credential": credential, } if action == "reload": # Just reload from disk start_time = time.time() await client.reload_usage_from_disk() 
refresh_result["duration_ms"] = int((time.time() - start_time) * 1000) refresh_result["success"] = True refresh_result["message"] = "Reloaded usage data from disk" elif action == "force_refresh": # Force refresh from external API (for supported providers like Antigravity) result = await client.force_refresh_quota( provider=provider if scope in ("provider", "credential") else None, credential=credential if scope == "credential" else None, ) refresh_result.update(result) refresh_result["success"] = result["failed_count"] == 0 # Get updated stats stats = await client.get_quota_stats(provider_filter=provider) stats["refresh_result"] = refresh_result stats["data_source"] = "refreshed" return stats except HTTPException: raise except Exception as e: logging.error(f"Failed to refresh quota stats: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/v1/token-count") async def token_count( request: Request, client: RotatingClient = Depends(get_rotating_client), _=Depends(verify_api_key), ): """ Calculates the token count for a given list of messages and a model. """ try: data = await request.json() model = data.get("model") messages = data.get("messages") if not model or not messages: raise HTTPException( status_code=400, detail="'model' and 'messages' are required." ) count = client.token_count(**data) return {"token_count": count} except Exception as e: logging.error(f"Token count failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/v1/cost-estimate") async def cost_estimate(request: Request, _=Depends(verify_api_key)): """ Estimates the cost for a request based on token counts and model pricing. 
Request body: { "model": "anthropic/claude-3-opus", "prompt_tokens": 1000, "completion_tokens": 500, "cache_read_tokens": 0, # optional "cache_creation_tokens": 0 # optional } Returns: { "model": "anthropic/claude-3-opus", "cost": 0.0375, "currency": "USD", "pricing": { "input_cost_per_token": 0.000015, "output_cost_per_token": 0.000075 }, "source": "model_info_service" # or "litellm_fallback" } """ try: data = await request.json() model = data.get("model") prompt_tokens = data.get("prompt_tokens", 0) completion_tokens = data.get("completion_tokens", 0) cache_read_tokens = data.get("cache_read_tokens", 0) cache_creation_tokens = data.get("cache_creation_tokens", 0) if not model: raise HTTPException(status_code=400, detail="'model' is required.") result = { "model": model, "cost": None, "currency": "USD", "pricing": {}, "source": None, } # Try model info service first if hasattr(request.app.state, "model_info_service"): model_info_service = request.app.state.model_info_service if model_info_service.is_ready: cost = model_info_service.calculate_cost( model, prompt_tokens, completion_tokens, cache_read_tokens, cache_creation_tokens, ) if cost is not None: cost_info = model_info_service.get_cost_info(model) result["cost"] = cost result["pricing"] = cost_info or {} result["source"] = "model_info_service" return result # Fallback to litellm try: import litellm # Create a mock response for cost calculation model_info = litellm.get_model_info(model) input_cost = model_info.get("input_cost_per_token", 0) output_cost = model_info.get("output_cost_per_token", 0) if input_cost or output_cost: cost = (prompt_tokens * input_cost) + (completion_tokens * output_cost) result["cost"] = cost result["pricing"] = { "input_cost_per_token": input_cost, "output_cost_per_token": output_cost, } result["source"] = "litellm_fallback" return result except Exception: pass result["source"] = "unknown" result["error"] = "Pricing data not available for this model" return result except 
HTTPException: raise except Exception as e: logging.error(f"Cost estimate failed: {e}") raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": # Define ENV_FILE for onboarding checks using centralized path ENV_FILE = get_data_file(".env") # Check if launcher TUI should be shown (no arguments provided) if len(sys.argv) == 1: # No arguments - show launcher TUI (lazy import) from proxy_app.launcher_tui import run_launcher_tui run_launcher_tui() # Launcher modifies sys.argv and returns, or exits if user chose Exit # If we get here, user chose "Run Proxy" and sys.argv is modified # Re-parse arguments with modified sys.argv args = parser.parse_args() def needs_onboarding() -> bool: """ Check if the proxy needs onboarding (first-time setup). Returns True if onboarding is needed, False otherwise. """ # Only check if .env file exists # PROXY_API_KEY is optional (will show warning if not set) if not ENV_FILE.is_file(): return True return False def show_onboarding_message(): """Display clear explanatory message for why onboarding is needed.""" os.system( "cls" if os.name == "nt" else "clear" ) # Clear terminal for clean presentation console.print( Panel.fit( "[bold cyan]🚀 LLM API Key Proxy - First Time Setup[/bold cyan]", border_style="cyan", ) ) console.print("[bold yellow]⚠️ Configuration Required[/bold yellow]\n") console.print("The proxy needs initial configuration:") console.print(" [red]❌ No .env file found[/red]") console.print("\n[bold]Why this matters:[/bold]") console.print(" • The .env file stores your credentials and settings") console.print(" • PROXY_API_KEY protects your proxy from unauthorized access") console.print(" • Provider API keys enable LLM access") console.print("\n[bold]What happens next:[/bold]") console.print(" 1. We'll create a .env file with PROXY_API_KEY") console.print(" 2. You can add LLM provider credentials (API keys or OAuth)") console.print(" 3. 
The proxy will then start normally") console.print( "\n[bold yellow]⚠️ Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." ) console.print(" You can remove it later if you want an unsecured proxy.\n") console.input( "[bold green]Press Enter to launch the credential setup tool...[/bold green]" ) # Check if user explicitly wants to add credentials if args.add_credential: # Import and call ensure_env_defaults to create .env and PROXY_API_KEY if needed from rotator_library.credential_tool import ensure_env_defaults ensure_env_defaults() # Reload environment variables after ensure_env_defaults creates/updates .env load_dotenv(ENV_FILE, override=True) run_credential_tool() else: # Check if onboarding is needed if needs_onboarding(): # Import console from rich for better messaging from rich.console import Console from rich.panel import Panel console = Console() # Show clear explanatory message show_onboarding_message() # Launch credential tool automatically from rotator_library.credential_tool import ensure_env_defaults ensure_env_defaults() load_dotenv(ENV_FILE, override=True) run_credential_tool() # After credential tool exits, reload and re-check load_dotenv(ENV_FILE, override=True) # Re-read PROXY_API_KEY from environment PROXY_API_KEY = os.getenv("PROXY_API_KEY") # Verify onboarding is complete if needs_onboarding(): console.print("\n[bold red]❌ Configuration incomplete.[/bold red]") console.print( "The proxy still cannot start. Please ensure PROXY_API_KEY is set in .env\n" ) sys.exit(1) else: console.print("\n[bold green]✅ Configuration complete![/bold green]") console.print("\nStarting proxy server...\n") import uvicorn uvicorn.run(app, host=args.host, port=args.port)