Spaces:
Sleeping
Sleeping
| """ | |
| Session management for API keys and user configuration. | |
| This module provides in-memory session management with automatic timeout | |
| to securely handle API keys and user-specific configuration. | |
| """ | |
| import time | |
| from datetime import datetime, timedelta | |
| from typing import Any, Dict, Optional | |
class SessionManager:
    """Manage in-memory user sessions holding API keys and configuration.

    A session expires once ``timeout_minutes`` have passed since its
    ``last_accessed`` timestamp. Expired sessions are removed lazily when
    accessed through ``get_session`` or in bulk by
    ``cleanup_expired_sessions``; nothing runs in the background.
    """

    def __init__(self, timeout_minutes: int = 60):
        """
        Initialize session manager.

        Args:
            timeout_minutes: Inactivity timeout in minutes (default: 60)
        """
        self.timeout_minutes = timeout_minutes
        # session_id -> session data dict (keys described in create_session).
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def create_session(
        self, session_id: str, api_keys: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Create a new session, replacing any existing session with the same id.

        Args:
            session_id: Unique session identifier
            api_keys: Optional API keys to store. A shallow copy is stored, so
                later merges or deletion of the session never mutate the
                caller's dict. If omitted, an empty dict is stored.

        Returns:
            Session data dictionary with keys ``created_at``,
            ``last_accessed``, ``api_keys`` and ``config``.
        """
        # One timestamp so created_at and last_accessed agree exactly.
        now = datetime.now()
        session_data = {
            "created_at": now,
            "last_accessed": now,
            # Copy: delete_session() clears this dict in place and
            # update_session() merges into it; neither should touch the
            # caller's object.
            "api_keys": dict(api_keys) if api_keys else {},
            "config": {},
        }
        self.sessions[session_id] = session_data
        return session_data

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session data if it exists and hasn't expired.

        Accessing a live session refreshes its ``last_accessed`` time; an
        expired session is deleted as a side effect.

        Args:
            session_id: Session identifier

        Returns:
            Session data or None if expired/not found
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None

        if self._is_expired(session):
            self.delete_session(session_id)
            return None

        session["last_accessed"] = datetime.now()
        return session

    def update_session(self, session_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update session data.

        ``api_keys`` and ``config`` entries in *updates* are merged into the
        existing dicts; every other key overwrites the stored value.

        Args:
            session_id: Session identifier
            updates: Dictionary of updates to apply

        Returns:
            True if successful, False if session not found/expired
        """
        session = self.get_session(session_id)
        if session is None:
            return False

        for key, value in updates.items():
            if key in ("api_keys", "config"):
                session[key].update(value)
            else:
                session[key] = value

        # Set last so an explicit "last_accessed" in updates cannot stick.
        session["last_accessed"] = datetime.now()
        return True

    def delete_session(self, session_id: str) -> bool:
        """
        Delete session and clear its stored API keys.

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        if session_id not in self.sessions:
            return False

        session = self.sessions.pop(session_id)
        # Wipe keys in place so any lingering reference to this session
        # dict no longer exposes them.
        if "api_keys" in session:
            session["api_keys"].clear()
        return True

    def cleanup_expired_sessions(self) -> int:
        """
        Remove all expired sessions.

        Returns:
            Number of sessions cleaned up
        """
        # Collect first: deleting while iterating the dict would raise.
        expired_sessions = [
            session_id
            for session_id, session in self.sessions.items()
            if self._is_expired(session)
        ]
        for session_id in expired_sessions:
            self.delete_session(session_id)
        return len(expired_sessions)

    def _is_expired(self, session: Dict[str, Any]) -> bool:
        """
        Check if session has expired.

        A session without a ``last_accessed`` timestamp counts as expired.

        Args:
            session: Session data dictionary

        Returns:
            True if expired, False otherwise
        """
        last_accessed = session.get("last_accessed")
        if last_accessed is None:
            return True
        # NOTE(review): naive local wall-clock time; a system clock change
        # shifts expiry accordingly.
        expiry_time = last_accessed + timedelta(minutes=self.timeout_minutes)
        return datetime.now() > expiry_time

    def get_active_session_count(self) -> int:
        """
        Get count of active (non-expired) sessions.

        Expired sessions are counted out but not removed.

        Returns:
            Number of active sessions
        """
        return sum(1 for s in self.sessions.values() if not self._is_expired(s))
# Global session manager instance, created lazily by get_session_manager().
_session_manager: Optional[SessionManager] = None


def get_session_manager(timeout_minutes: int = 60) -> SessionManager:
    """
    Return the shared SessionManager, creating it on first use.

    Args:
        timeout_minutes: Session timeout in minutes. Only used when the
            manager is first created; later calls return the existing
            instance unchanged regardless of the value passed.

    Returns:
        The shared SessionManager instance.
    """
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = SessionManager(timeout_minutes)
    return _session_manager
# Convenience functions for common operations
def create_user_session(
    user_id: str, api_keys: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """Create a session keyed by ``user_id`` on the shared session manager."""
    return get_session_manager().create_session(user_id, api_keys)
def get_user_session(user_id: str) -> Optional[Dict[str, Any]]:
    """Return the live session for ``user_id``, or None if missing/expired."""
    return get_session_manager().get_session(user_id)
def update_user_session(user_id: str, updates: Dict[str, Any]) -> bool:
    """Apply ``updates`` to the session for ``user_id``; False if not found."""
    return get_session_manager().update_session(user_id, updates)
def delete_user_session(user_id: str) -> bool:
    """Delete the session for ``user_id``; True if one was removed."""
    return get_session_manager().delete_session(user_id)
def cleanup_sessions() -> int:
    """Purge expired sessions from the shared manager; return how many."""
    return get_session_manager().cleanup_expired_sessions()
# Configuration validation functions
def validate_indicator_parameters(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Check technical-indicator parameters against their allowed integer ranges.

    Absent keys are skipped. Checks run in a fixed order (RSI, MACD
    fast/slow/signal, stochastic K/D) and the first failure is reported.
    When both MACD fast and slow periods are present, slow must exceed fast.

    Args:
        params: Mapping of indicator parameter names to values.

    Returns:
        ``(True, None)`` if every present parameter is valid, otherwise
        ``(False, error_message)``.
    """
    # (key, inclusive minimum, inclusive maximum, error message), in check order.
    range_rules = (
        ("rsi_period", 2, 100, "RSI period must be an integer between 2 and 100"),
        ("macd_fast", 2, 50, "MACD fast period must be an integer between 2 and 50"),
        ("macd_slow", 2, 100, "MACD slow period must be an integer between 2 and 100"),
        ("macd_signal", 2, 50, "MACD signal period must be an integer between 2 and 50"),
        ("stoch_k_period", 2, 50, "Stochastic K period must be an integer between 2 and 50"),
        ("stoch_d_period", 2, 20, "Stochastic D period must be an integer between 2 and 20"),
    )

    for name, lowest, highest, message in range_rules:
        if name not in params:
            continue
        value = params[name]
        if not isinstance(value, int) or not lowest <= value <= highest:
            return False, message
        # Cross-field rule, checked immediately after macd_slow passes its
        # own range check.
        if name == "macd_slow" and "macd_fast" in params and value <= params["macd_fast"]:
            return False, "MACD slow period must be greater than fast period"

    return True, None
def validate_model_name(provider: str, model: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``model`` is one of the hard-coded models for ``provider``.

    Args:
        provider: LLM provider name ("openai", "anthropic" or "qwen").
        model: Model name to check.

    Returns:
        ``(True, None)`` when the pair is known, otherwise
        ``(False, error_message)`` listing the accepted providers or models.
    """
    # NOTE(review): this catalogue is static and may lag the providers'
    # current models; "huggingface" is accepted as an LLM provider by
    # validate_configuration but has no entry here.
    models_by_provider = {
        "openai": [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4-turbo-preview",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ],
        "anthropic": [
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
            "claude-2.1",
            "claude-2.0",
        ],
        "qwen": ["qwen-turbo", "qwen-plus", "qwen-max"],
    }

    known_models = models_by_provider.get(provider)
    if known_models is None:
        provider_list = ", ".join(models_by_provider.keys())
        return (
            False,
            f"Unknown provider: {provider}. Valid providers: {provider_list}",
        )

    if model in known_models:
        return True, None

    model_list = ", ".join(known_models)
    return (
        False,
        f"Invalid model '{model}' for provider '{provider}'. Valid models: {model_list}",
    )
def validate_data_provider(provider: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``provider`` names a supported data provider.

    Args:
        provider: Data provider name to check.

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, error_message)``
        listing the accepted names.
    """
    supported = ["yfinance", "alpha_vantage"]
    if provider in supported:
        return True, None
    return (
        False,
        f"Invalid data provider: {provider}. Valid providers: {', '.join(supported)}",
    )
def validate_hf_token(token: str) -> tuple[bool, Optional[str]]:
    """
    Validate HuggingFace API token format.

    Only local format checks are performed (prefix, minimum length,
    character set); the token is not verified with HuggingFace.

    Args:
        token: HuggingFace API token. Falsy values (None, "") are reported
            as empty; other non-string values are rejected.

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if not token:
        return False, "HuggingFace token cannot be empty"

    # Guard before calling str methods, so e.g. an int yields an error
    # tuple instead of an AttributeError.
    if not isinstance(token, str):
        return False, "HuggingFace token must be a string"

    # Expected shape: "hf_" followed by ASCII letters/digits/underscores.
    if not token.startswith("hf_"):
        return (
            False,
            "HuggingFace token should start with 'hf_'. Get your token from https://huggingface.co/settings/tokens",
        )

    # Loose lower bound to catch obviously truncated tokens.
    if len(token) < 20:
        return False, "HuggingFace token appears too short. Please check your token."

    # str.isalnum() alone accepts any Unicode letter/digit (e.g. "é", "٣"),
    # so also require ASCII.
    if not all(c == "_" or (c.isascii() and c.isalnum()) for c in token):
        return False, "HuggingFace token contains invalid characters"

    return True, None
def validate_api_keys(api_keys: Dict[str, str]) -> tuple[bool, Optional[str]]:
    """
    Validate the format of API keys for known providers.

    Only cheap prefix/format checks are done; a key that passes may still
    be rejected by the provider itself.

    Args:
        api_keys: Mapping of provider name to API key. Recognized entries are
            "huggingface" / "hf_token", "openai" and "anthropic"; other
            entries are ignored.

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    # HuggingFace: "huggingface" wins over "hf_token" unless it is falsy.
    if "huggingface" in api_keys or "hf_token" in api_keys:
        token = api_keys.get("huggingface") or api_keys.get("hf_token")
        is_valid, error = validate_hf_token(token)
        if not is_valid:
            return False, f"Invalid HuggingFace token: {error}"

    # (provider key, required prefix, error message), checked in this order.
    prefix_rules = (
        ("openai", "sk-", "OpenAI API key should start with 'sk-'"),
        ("anthropic", "sk-ant-", "Anthropic API key should start with 'sk-ant-'"),
    )
    for provider, prefix, message in prefix_rules:
        if provider not in api_keys:
            continue
        key = api_keys[provider]
        # Non-string values (e.g. None) are invalid rather than crashing
        # on .startswith().
        if not isinstance(key, str) or not key.startswith(prefix):
            return False, message

    return True, None
def validate_configuration(config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Validate a configuration dict section by section.

    Sections are checked in this order, each only if present:
    ``indicator_parameters``, ``llm_provider``, ``api_keys``, and the
    ``ohlc_primary`` / ``fundamentals_primary`` entries of
    ``data_providers``. The first failure is returned with a section prefix.

    Args:
        config: Configuration dictionary

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if "indicator_parameters" in config:
        ok, err = validate_indicator_parameters(config["indicator_parameters"])
        if not ok:
            return False, f"Invalid indicator parameters: {err}"

    if "llm_provider" in config:
        llm_provider = config["llm_provider"]
        if llm_provider not in ("openai", "anthropic", "huggingface", "qwen"):
            return False, f"Invalid LLM provider: {llm_provider}"

    if "api_keys" in config:
        ok, err = validate_api_keys(config["api_keys"])
        if not ok:
            return False, f"Invalid API keys: {err}"

    if "data_providers" in config:
        data_providers = config["data_providers"]
        # (entry name, label used in the error message), in check order.
        for slot, label in (
            ("ohlc_primary", "OHLC primary provider"),
            ("fundamentals_primary", "fundamentals provider"),
        ):
            if slot not in data_providers:
                continue
            ok, err = validate_data_provider(data_providers[slot])
            if not ok:
                return False, f"Invalid {label}: {err}"

    return True, None