trading-tools / web /config /api_keys.py
Deploy Bot
Deploy Trading Analysis Platform to HuggingFace Spaces
a1bf219
"""
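Session management for API keys and user configuration.
This module provides in-memory session management with an inactivity timeout
to handle API keys and user-specific configuration, plus helpers that validate
indicator parameters, LLM model names, data providers, API key formats, and
complete configuration dictionaries.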
"""
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
class SessionManager:
    """Manages user sessions with API keys and configuration.

    Sessions are kept only in process memory (``self.sessions``) and expire
    after ``timeout_minutes`` without access, measured from each session's
    ``last_accessed`` timestamp.
    """

    def __init__(self, timeout_minutes: int = 60):
        """
        Initialize session manager.

        Args:
            timeout_minutes: Inactivity timeout in minutes (default: 60)
        """
        self.timeout_minutes = timeout_minutes
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def create_session(
        self, session_id: str, api_keys: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Create a new session, replacing any existing one with the same id.

        Args:
            session_id: Unique session identifier
            api_keys: Optional API keys to store (default: no keys). A shallow
                copy is stored, so later updates to or deletion of the session
                never mutate the caller's dictionary.

        Returns:
            Session data dictionary with keys "created_at", "last_accessed",
            "api_keys" and "config".
        """
        now = datetime.now()
        session_data = {
            # One timestamp so a fresh session has created_at == last_accessed.
            "created_at": now,
            "last_accessed": now,
            # Copy: delete_session() clears this dict in place, which would
            # otherwise wipe the caller's own dictionary as well.
            "api_keys": dict(api_keys) if api_keys else {},
            "config": {},
        }
        self.sessions[session_id] = session_data
        return session_data

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session data if it exists and hasn't expired.

        A successful lookup refreshes ``last_accessed``, so reading a session
        extends its lifetime. An expired session is deleted on lookup.

        Args:
            session_id: Session identifier

        Returns:
            Session data or None if expired/not found
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None

        if self._is_expired(session):
            self.delete_session(session_id)
            return None

        session["last_accessed"] = datetime.now()
        return session

    def update_session(self, session_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update session data.

        "api_keys" and "config" entries are merged into the existing nested
        dictionaries (a None value merges nothing); any other key replaces the
        stored value.

        Args:
            session_id: Session identifier
            updates: Dictionary of updates to apply

        Returns:
            True if successful, False if session not found/expired
        """
        session = self.get_session(session_id)
        if session is None:
            return False

        for key, value in updates.items():
            if key in ("api_keys", "config"):
                # Merge rather than replace; dict.update(None) would raise.
                if value is not None:
                    session[key].update(value)
            else:
                session[key] = value

        session["last_accessed"] = datetime.now()
        return True

    def delete_session(self, session_id: str) -> bool:
        """
        Delete session and clear all stored data.

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        session = self.sessions.pop(session_id, None)
        if session is None:
            return False

        # Wipe keys in place so any outstanding reference to this session
        # dict (e.g. one returned by get_session) no longer exposes them.
        api_keys = session.get("api_keys")
        if api_keys is not None:
            api_keys.clear()
        return True

    def cleanup_expired_sessions(self) -> int:
        """
        Remove all expired sessions.

        Returns:
            Number of sessions cleaned up
        """
        # Collect first: deleting while iterating self.sessions would fail.
        expired_sessions = [
            session_id
            for session_id, session in self.sessions.items()
            if self._is_expired(session)
        ]
        for session_id in expired_sessions:
            self.delete_session(session_id)
        return len(expired_sessions)

    def _is_expired(self, session: Dict[str, Any]) -> bool:
        """
        Check if session has expired.

        Args:
            session: Session data dictionary

        Returns:
            True if expired (or missing "last_accessed"), False otherwise
        """
        last_accessed = session.get("last_accessed")
        if last_accessed is None:
            return True

        # NOTE(review): naive local time; a DST/clock change can shorten or
        # extend sessions. Kept naive because callers may read created_at.
        expiry_time = last_accessed + timedelta(minutes=self.timeout_minutes)
        return datetime.now() > expiry_time

    def get_active_session_count(self) -> int:
        """
        Get count of active (non-expired) sessions.

        Expired sessions are counted out but not removed.

        Returns:
            Number of active sessions
        """
        return sum(1 for s in self.sessions.values() if not self._is_expired(s))
# Global session manager instance
_session_manager: Optional[SessionManager] = None


def get_session_manager(timeout_minutes: int = 60) -> SessionManager:
    """
    Return the process-wide SessionManager, creating it on first use.

    Args:
        timeout_minutes: Session timeout in minutes. Only honoured by the call
            that actually creates the manager; once it exists, later calls
            return it unchanged and ignore this argument.

    Returns:
        SessionManager instance
    """
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = SessionManager(timeout_minutes)
    return _session_manager
# Convenience functions for common operations
def create_user_session(
    user_id: str, api_keys: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """Create a new user session in the global session manager.

    Args:
        user_id: Identifier used as the session id.
        api_keys: Optional API keys to store with the session.

    Returns:
        The newly created session data dictionary.
    """
    return get_session_manager().create_session(user_id, api_keys)
def get_user_session(user_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a user's session from the global session manager.

    Returns:
        The session data, or None when it does not exist or has expired.
    """
    return get_session_manager().get_session(user_id)
def update_user_session(user_id: str, updates: Dict[str, Any]) -> bool:
    """Apply ``updates`` to a user's session in the global session manager.

    Returns:
        True on success, False when the session is missing or expired.
    """
    return get_session_manager().update_session(user_id, updates)
def delete_user_session(user_id: str) -> bool:
    """Remove a user's session from the global session manager.

    Returns:
        True if a session was deleted, False if none existed.
    """
    return get_session_manager().delete_session(user_id)
def cleanup_sessions() -> int:
    """Purge expired sessions from the global session manager.

    Returns:
        How many sessions were removed.
    """
    return get_session_manager().cleanup_expired_sessions()
# Configuration validation functions
def validate_indicator_parameters(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Check indicator parameters against their allowed integer ranges.

    Only keys present in ``params`` are checked; unknown keys are ignored.
    Checks run in a fixed order and the first failure is reported.

    Args:
        params: Indicator parameters dictionary

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    # (key, inclusive minimum, inclusive maximum, error message), in check order.
    rules = (
        ("rsi_period", 2, 100, "RSI period must be an integer between 2 and 100"),
        ("macd_fast", 2, 50, "MACD fast period must be an integer between 2 and 50"),
        ("macd_slow", 2, 100, "MACD slow period must be an integer between 2 and 100"),
        ("macd_signal", 2, 50, "MACD signal period must be an integer between 2 and 50"),
        ("stoch_k_period", 2, 50, "Stochastic K period must be an integer between 2 and 50"),
        ("stoch_d_period", 2, 20, "Stochastic D period must be an integer between 2 and 20"),
    )
    for name, low, high, message in rules:
        if name not in params:
            continue
        value = params[name]
        in_range = isinstance(value, int) and low <= value <= high
        if not in_range:
            return False, message
        # macd_fast (if present) was range-checked earlier in this loop, so
        # the comparison below is between two validated integers.
        if name == "macd_slow" and "macd_fast" in params and value <= params["macd_fast"]:
            return False, "MACD slow period must be greater than fast period"
    return True, None
def validate_model_name(provider: str, model: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``model`` is a known model for the LLM ``provider``.

    Args:
        provider: LLM provider name (openai, anthropic, qwen)
        model: Model name

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    # NOTE(review): validate_configuration accepts "huggingface" as an LLM
    # provider, but it has no entry here, so any model for it is rejected.
    known_models = {
        "openai": [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4-turbo-preview",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ],
        "anthropic": [
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
            "claude-2.1",
            "claude-2.0",
        ],
        "qwen": ["qwen-turbo", "qwen-plus", "qwen-max"],
    }

    allowed = known_models.get(provider)
    if allowed is None:
        provider_list = ", ".join(known_models)
        return (
            False,
            f"Unknown provider: {provider}. Valid providers: {provider_list}",
        )

    if model in allowed:
        return True, None

    model_list = ", ".join(allowed)
    return (
        False,
        f"Invalid model '{model}' for provider '{provider}'. Valid models: {model_list}",
    )
def validate_data_provider(provider: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``provider`` is one of the supported market-data providers.

    Args:
        provider: Data provider name

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    supported = ("yfinance", "alpha_vantage")
    if provider in supported:
        return True, None
    return (
        False,
        f"Invalid data provider: {provider}. Valid providers: {', '.join(supported)}",
    )
def validate_hf_token(token: str) -> tuple[bool, Optional[str]]:
    """
    Validate HuggingFace API token format.

    This is a format check only; it does not contact HuggingFace to confirm
    the token is live.

    Args:
        token: HuggingFace API token

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if not token:
        return False, "HuggingFace token cannot be empty"

    # Guard before calling str methods so bad input yields an error tuple
    # instead of an AttributeError.
    if not isinstance(token, str):
        return False, "HuggingFace token must be a string"

    # HF tokens typically start with "hf_" and are alphanumeric
    if not token.startswith("hf_"):
        return (
            False,
            "HuggingFace token should start with 'hf_'. Get your token from https://huggingface.co/settings/tokens",
        )

    # Check minimum length (HF tokens are typically 30-40 characters)
    if len(token) < 20:
        return False, "HuggingFace token appears too short. Please check your token."

    # Allow only ASCII letters, digits and underscores. str.isalnum() alone is
    # Unicode-aware and would accept characters such as "é" or Arabic-Indic
    # digits, so ASCII is enforced explicitly.
    if not all((c.isascii() and c.isalnum()) or c == "_" for c in token):
        return False, "HuggingFace token contains invalid characters"

    return True, None
def validate_api_keys(api_keys: Dict[str, str]) -> tuple[bool, Optional[str]]:
    """
    Validate the format of API keys for the supported providers.

    Only providers present in ``api_keys`` are checked. The checks are
    format-only (prefix / character set); no network call is made.

    Args:
        api_keys: Dictionary of API keys by provider. Recognised keys are
            "huggingface" / "hf_token", "openai" and "anthropic".

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if "huggingface" in api_keys or "hf_token" in api_keys:
        # "huggingface" is preferred when both are present and non-empty.
        token = api_keys.get("huggingface") or api_keys.get("hf_token")
        is_valid, error = validate_hf_token(token)
        if not is_valid:
            return False, f"Invalid HuggingFace token: {error}"

    # (provider key, required prefix, error message), checked in this order.
    prefix_rules = (
        ("openai", "sk-", "OpenAI API key should start with 'sk-'"),
        ("anthropic", "sk-ant-", "Anthropic API key should start with 'sk-ant-'"),
    )
    for provider, prefix, message in prefix_rules:
        if provider not in api_keys:
            continue
        key = api_keys[provider]
        # Non-string values (e.g. None) are reported as invalid rather than
        # raising AttributeError on .startswith().
        if not isinstance(key, str) or not key.startswith(prefix):
            return False, message

    return True, None
def validate_configuration(config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Validate a complete configuration dictionary section by section.

    Sections are optional; each present section is checked in the order
    indicator_parameters, llm_provider, api_keys, data_providers. The first
    failure is returned with a section-specific prefix.

    Args:
        config: Configuration dictionary

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if "indicator_parameters" in config:
        ok, err = validate_indicator_parameters(config["indicator_parameters"])
        if not ok:
            return False, f"Invalid indicator parameters: {err}"

    # NOTE(review): "huggingface" is accepted here but validate_model_name has
    # no model list for it; confirm which one is intended.
    llm_providers = ("openai", "anthropic", "huggingface", "qwen")
    if "llm_provider" in config and config["llm_provider"] not in llm_providers:
        return False, f"Invalid LLM provider: {config['llm_provider']}"

    if "api_keys" in config:
        ok, err = validate_api_keys(config["api_keys"])
        if not ok:
            return False, f"Invalid API keys: {err}"

    if "data_providers" not in config:
        return True, None

    data_providers = config["data_providers"]
    if "ohlc_primary" in data_providers:
        ok, err = validate_data_provider(data_providers["ohlc_primary"])
        if not ok:
            return False, f"Invalid OHLC primary provider: {err}"
    if "fundamentals_primary" in data_providers:
        ok, err = validate_data_provider(data_providers["fundamentals_primary"])
        if not ok:
            return False, f"Invalid fundamentals provider: {err}"

    return True, None