"""
Session management for API keys and user configuration.

This module provides in-memory session management with automatic timeout to
securely handle API keys and user-specific configuration. It also provides
validators for indicator parameters, model names, data providers, API keys and
complete configuration dictionaries.
"""

import time  # NOTE(review): not used by this module's visible code
from datetime import datetime, timedelta
from typing import Any, Dict, Optional


class SessionManager:
    """Manages user sessions with API keys and configuration.

    Sessions are kept in the in-process dict ``self.sessions`` and expire once
    ``timeout_minutes`` have passed since their last access. Expired sessions
    are removed lazily when looked up via :meth:`get_session`, or in bulk via
    :meth:`cleanup_expired_sessions`.
    """

    def __init__(self, timeout_minutes: int = 60):
        """
        Initialize session manager.

        Args:
            timeout_minutes: Session timeout in minutes (default: 60)
        """
        self.timeout_minutes = timeout_minutes
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def create_session(
        self, session_id: str, api_keys: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Create a new session, replacing any existing session with the same id.

        Args:
            session_id: Unique session identifier
            api_keys: Optional API keys to store. The mapping is copied, so
                later merges or clearing of the session's keys never modify
                the caller's dict. If omitted, an empty dict is stored.

        Returns:
            Session data dictionary with ``created_at``, ``last_accessed``,
            ``api_keys`` and ``config`` entries.
        """
        now = datetime.now()  # one timestamp so created_at == last_accessed
        session_data: Dict[str, Any] = {
            "created_at": now,
            "last_accessed": now,
            # Copy: delete_session() clears this dict and update_session()
            # merges into it; neither should touch the caller's object.
            "api_keys": dict(api_keys) if api_keys else {},
            "config": {},
        }
        self.sessions[session_id] = session_data
        return session_data

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session data if it exists and hasn't expired.

        A successful lookup refreshes the session's ``last_accessed`` time; an
        expired session is deleted as a side effect.

        Args:
            session_id: Session identifier

        Returns:
            Session data or None if expired/not found
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None

        if self._is_expired(session):
            self.delete_session(session_id)
            return None

        # Sliding expiry: each access extends the session's lifetime.
        session["last_accessed"] = datetime.now()
        return session

    def update_session(self, session_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update session data.

        ``api_keys`` and ``config`` updates are merged into the existing
        dicts (a ``None`` value is treated as "nothing to merge"); any other
        key replaces the stored value.

        Args:
            session_id: Session identifier
            updates: Dictionary of updates to apply

        Returns:
            True if successful, False if session not found/expired
        """
        session = self.get_session(session_id)
        if session is None:
            return False

        for key, value in updates.items():
            if key in ("api_keys", "config"):
                # Merge rather than replace; `or {}` avoids dict.update(None).
                session[key].update(value or {})
            else:
                session[key] = value

        session["last_accessed"] = datetime.now()
        return True

    def delete_session(self, session_id: str) -> bool:
        """
        Delete session and clear all stored data.

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        session = self.sessions.pop(session_id, None)
        if session is None:
            return False

        # Clear sensitive data so any outstanding reference to this session
        # dict no longer exposes the keys.
        api_keys = session.get("api_keys")
        if isinstance(api_keys, dict):
            api_keys.clear()
        return True

    def cleanup_expired_sessions(self) -> int:
        """
        Remove all expired sessions.

        Returns:
            Number of sessions cleaned up
        """
        # Materialize the ids first: deleting while iterating the dict raises.
        expired_sessions = [
            session_id
            for session_id, session in self.sessions.items()
            if self._is_expired(session)
        ]

        for session_id in expired_sessions:
            self.delete_session(session_id)

        return len(expired_sessions)

    def _is_expired(self, session: Dict[str, Any]) -> bool:
        """
        Check if session has expired.

        Args:
            session: Session data dictionary

        Returns:
            True if expired (or missing ``last_accessed``), False otherwise
        """
        last_accessed = session.get("last_accessed")
        if last_accessed is None:
            return True

        expiry_time = last_accessed + timedelta(minutes=self.timeout_minutes)
        return datetime.now() > expiry_time

    def get_active_session_count(self) -> int:
        """
        Get count of active (non-expired) sessions.

        Expired sessions are counted out but not removed.

        Returns:
            Number of active sessions
        """
        return sum(1 for s in self.sessions.values() if not self._is_expired(s))


# Global session manager instance, created lazily by get_session_manager().
_session_manager: Optional[SessionManager] = None


def get_session_manager(timeout_minutes: int = 60) -> SessionManager:
    """
    Get global session manager instance.

    Args:
        timeout_minutes: Session timeout in minutes. Only used when the
            global instance is first created; later calls return the existing
            instance and ignore this argument.

    Returns:
        SessionManager instance
    """
    global _session_manager
    if _session_manager is None:
        _session_manager = SessionManager(timeout_minutes)
    return _session_manager


# Convenience functions for common operations (all use the global manager)


def create_user_session(
    user_id: str, api_keys: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """Create a new user session."""
    manager = get_session_manager()
    return manager.create_session(user_id, api_keys)


def get_user_session(user_id: str) -> Optional[Dict[str, Any]]:
    """Get user session data."""
    manager = get_session_manager()
    return manager.get_session(user_id)


def update_user_session(user_id: str, updates: Dict[str, Any]) -> bool:
    """Update user session data."""
    manager = get_session_manager()
    return manager.update_session(user_id, updates)


def delete_user_session(user_id: str) -> bool:
    """Delete user session."""
    manager = get_session_manager()
    return manager.delete_session(user_id)


def cleanup_sessions() -> int:
    """Clean up expired sessions."""
    manager = get_session_manager()
    return manager.cleanup_expired_sessions()


# Configuration validation functions

# Integer period rules: (parameter key, label used in the error message,
# inclusive minimum, inclusive maximum). Checks run in this order and the
# first failure is reported.
_INDICATOR_PERIOD_RULES = (
    ("rsi_period", "RSI period", 2, 100),
    ("macd_fast", "MACD fast period", 2, 50),
    ("macd_slow", "MACD slow period", 2, 100),
    ("macd_signal", "MACD signal period", 2, 50),
    ("stoch_k_period", "Stochastic K period", 2, 50),
    ("stoch_d_period", "Stochastic D period", 2, 20),
)


def validate_indicator_parameters(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Validate indicator parameters.

    Each known period must be an ``int`` within its inclusive range, and when
    both MACD periods are given the slow period must exceed the fast one.
    Unknown keys are ignored.

    Args:
        params: Indicator parameters dictionary

    Returns:
        Tuple of (is_valid, error_message)
    """
    for key, label, low, high in _INDICATOR_PERIOD_RULES:
        if key not in params:
            continue
        value = params[key]
        if not isinstance(value, int) or not low <= value <= high:
            return False, f"{label} must be an integer between {low} and {high}"
        # macd_fast is checked before macd_slow, so it is a valid int here.
        if key == "macd_slow" and "macd_fast" in params and value <= params["macd_fast"]:
            return False, "MACD slow period must be greater than fast period"

    return True, None


def validate_model_name(provider: str, model: str) -> tuple[bool, Optional[str]]:
    """
    Validate LLM model name for a given provider.

    Args:
        provider: LLM provider name (openai, anthropic, qwen)
        model: Model name

    Returns:
        Tuple of (is_valid, error_message)
    """
    valid_models = {
        "openai": [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4-turbo-preview",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ],
        "anthropic": [
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
            "claude-2.1",
            "claude-2.0",
        ],
        "qwen": ["qwen-turbo", "qwen-plus", "qwen-max"],
    }

    if provider not in valid_models:
        return (
            False,
            f"Unknown provider: {provider}. Valid providers: {', '.join(valid_models.keys())}",
        )

    if model not in valid_models[provider]:
        return (
            False,
            f"Invalid model '{model}' for provider '{provider}'. Valid models: {', '.join(valid_models[provider])}",
        )

    return True, None


def validate_data_provider(provider: str) -> tuple[bool, Optional[str]]:
    """
    Validate data provider name.

    Args:
        provider: Data provider name

    Returns:
        Tuple of (is_valid, error_message)
    """
    valid_providers = ["yfinance", "alpha_vantage"]

    if provider not in valid_providers:
        return (
            False,
            f"Invalid data provider: {provider}. Valid providers: {', '.join(valid_providers)}",
        )

    return True, None


def validate_hf_token(token: Optional[str]) -> tuple[bool, Optional[str]]:
    """
    Validate HuggingFace API token format.

    A token must be a non-empty string that starts with ``hf_``, is at least
    20 characters long and contains only ASCII letters, digits and
    underscores.

    Args:
        token: HuggingFace API token

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not token:
        return False, "HuggingFace token cannot be empty"

    # Guard before calling str methods: non-strings would raise AttributeError.
    if not isinstance(token, str):
        return False, "HuggingFace token must be a string"

    # HF tokens typically start with "hf_" and are alphanumeric
    if not token.startswith("hf_"):
        return (
            False,
            "HuggingFace token should start with 'hf_'. Get your token from https://huggingface.co/settings/tokens",
        )

    # Check minimum length (HF tokens are typically 30-40 characters)
    if len(token) < 20:
        return False, "HuggingFace token appears too short. Please check your token."

    # Only ASCII alphanumerics and underscores; str.isalnum() alone would also
    # accept non-ASCII Unicode letters and digits.
    if not all((c.isascii() and c.isalnum()) or c == "_" for c in token):
        return False, "HuggingFace token contains invalid characters"

    return True, None


def validate_api_keys(api_keys: Dict[str, str]) -> tuple[bool, Optional[str]]:
    """
    Validate API keys for various providers.

    Only providers present in ``api_keys`` are checked; the first invalid key
    is reported.

    Args:
        api_keys: Dictionary of API keys by provider

    Returns:
        Tuple of (is_valid, error_message)
    """
    # Validate HuggingFace token if provided (either key name is accepted)
    if "huggingface" in api_keys or "hf_token" in api_keys:
        token = api_keys.get("huggingface") or api_keys.get("hf_token")
        is_valid, error = validate_hf_token(token)
        if not is_valid:
            return False, f"Invalid HuggingFace token: {error}"

    # Add validation for other providers as needed
    # OpenAI keys typically start with "sk-"
    if "openai" in api_keys:
        openai_key = api_keys["openai"]
        if not isinstance(openai_key, str) or not openai_key.startswith("sk-"):
            return False, "OpenAI API key should start with 'sk-'"

    # Anthropic keys typically start with "sk-ant-"
    if "anthropic" in api_keys:
        anthropic_key = api_keys["anthropic"]
        if not isinstance(anthropic_key, str) or not anthropic_key.startswith("sk-ant-"):
            return False, "Anthropic API key should start with 'sk-ant-'"

    return True, None


def validate_configuration(config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Validate complete configuration object.

    Sections that are absent are skipped; the first failing section's error
    is returned, prefixed with the section it came from.

    Args:
        config: Configuration dictionary

    Returns:
        Tuple of (is_valid, error_message)
    """
    # Validate indicator parameters
    if "indicator_parameters" in config:
        is_valid, error = validate_indicator_parameters(config["indicator_parameters"])
        if not is_valid:
            return False, f"Invalid indicator parameters: {error}"

    # Validate LLM provider
    # NOTE(review): "huggingface" is accepted here but validate_model_name()
    # has no model list for it — confirm which set is authoritative.
    if "llm_provider" in config:
        provider = config["llm_provider"]
        if provider not in ["openai", "anthropic", "huggingface", "qwen"]:
            return False, f"Invalid LLM provider: {provider}"

    # Validate API keys if present
    if "api_keys" in config:
        is_valid, error = validate_api_keys(config["api_keys"])
        if not is_valid:
            return False, f"Invalid API keys: {error}"

    # Validate data providers
    if "data_providers" in config:
        providers = config["data_providers"]

        if "ohlc_primary" in providers:
            is_valid, error = validate_data_provider(providers["ohlc_primary"])
            if not is_valid:
                return False, f"Invalid OHLC primary provider: {error}"

        if "fundamentals_primary" in providers:
            is_valid, error = validate_data_provider(providers["fundamentals_primary"])
            if not is_valid:
                return False, f"Invalid fundamentals provider: {error}"

    return True, None