File size: 13,362 Bytes
a1bf219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
"""
Session management for API keys and user configuration.

This module provides in-memory session management with automatic timeout
to securely handle API keys and user-specific configuration.
"""

import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional


class SessionManager:
    """Manages user sessions with API keys and configuration.

    Sessions live in an in-memory dict keyed by session id. A session expires
    once ``timeout_minutes`` have elapsed since it was last accessed; expired
    sessions are dropped when looked up via ``get_session`` or in bulk via
    ``cleanup_expired_sessions``.
    """

    def __init__(self, timeout_minutes: int = 60):
        """
        Initialize session manager.

        Args:
            timeout_minutes: Inactivity timeout in minutes (default: 60)
        """
        self.timeout_minutes = timeout_minutes
        # session_id -> {"created_at", "last_accessed", "api_keys", "config", ...}
        self.sessions: Dict[str, Dict[str, Any]] = {}

    def create_session(
        self, session_id: str, api_keys: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Create a new session, replacing any existing session with the same id.

        Args:
            session_id: Unique session identifier
            api_keys: Optional API keys to store (default: no keys). The
                mapping is copied, so the caller's dict is never modified by
                later ``update_session`` merges or by the wipe performed in
                ``delete_session``.

        Returns:
            Session data dictionary (the same object stored in the manager)
        """
        # One timestamp so created_at and last_accessed agree exactly.
        now = datetime.now()
        session_data = {
            "created_at": now,
            "last_accessed": now,
            # Copy: delete_session() clears this dict in place, which must not
            # wipe the caller's own mapping.
            "api_keys": dict(api_keys) if api_keys else {},
            "config": {},
        }

        self.sessions[session_id] = session_data
        return session_data

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session data if it exists and hasn't expired.

        Accessing a live session refreshes its ``last_accessed`` time; an
        expired session is deleted as a side effect.

        Args:
            session_id: Session identifier

        Returns:
            Session data or None if expired/not found
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None

        if self._is_expired(session):
            self.delete_session(session_id)
            return None

        # Sliding expiration: every successful access extends the session.
        session["last_accessed"] = datetime.now()
        return session

    def update_session(self, session_id: str, updates: Dict[str, Any]) -> bool:
        """
        Update session data.

        ``"api_keys"`` and ``"config"`` entries are merged into the existing
        dicts; any other key replaces the stored value.

        Args:
            session_id: Session identifier
            updates: Dictionary of updates to apply

        Returns:
            True if successful, False if session not found/expired
        """
        session = self.get_session(session_id)
        if session is None:
            return False

        for key, value in updates.items():
            if key in ("api_keys", "config"):
                # dict.update copies entries, so the caller's mapping is not
                # aliased into the session.
                session[key].update(value)
            else:
                session[key] = value

        session["last_accessed"] = datetime.now()
        return True

    def delete_session(self, session_id: str) -> bool:
        """
        Delete session and clear its stored API keys.

        Args:
            session_id: Session identifier

        Returns:
            True if deleted, False if not found
        """
        if session_id not in self.sessions:
            return False

        session = self.sessions.pop(session_id)
        # Clear keys in place so any outstanding reference to this session
        # dict no longer exposes them.
        if "api_keys" in session:
            session["api_keys"].clear()
        return True

    def cleanup_expired_sessions(self) -> int:
        """
        Remove all expired sessions.

        Returns:
            Number of sessions cleaned up
        """
        # Collect first: deleting while iterating the dict would raise.
        expired_sessions = [
            session_id
            for session_id, session in self.sessions.items()
            if self._is_expired(session)
        ]

        for session_id in expired_sessions:
            self.delete_session(session_id)

        return len(expired_sessions)

    def _is_expired(self, session: Dict[str, Any]) -> bool:
        """
        Check if session has expired.

        Args:
            session: Session data dictionary

        Returns:
            True if expired (or missing ``last_accessed``), False otherwise
        """
        last_accessed = session.get("last_accessed")
        if last_accessed is None:
            return True

        # NOTE(review): naive local wall-clock time; clock adjustments (e.g.
        # DST changes) shift the effective expiry.
        expiry_time = last_accessed + timedelta(minutes=self.timeout_minutes)
        return datetime.now() > expiry_time

    def get_active_session_count(self) -> int:
        """
        Get count of active (non-expired) sessions.

        Does not remove expired sessions; see ``cleanup_expired_sessions``.

        Returns:
            Number of active sessions
        """
        return sum(1 for s in self.sessions.values() if not self._is_expired(s))


# Global session manager instance (created lazily by get_session_manager)
_session_manager: Optional[SessionManager] = None


def get_session_manager(timeout_minutes: int = 60) -> SessionManager:
    """
    Return the process-wide SessionManager, creating it on first use.

    Args:
        timeout_minutes: Session timeout in minutes. Only the call that
            creates the manager uses this value; later calls return the
            existing instance unchanged and ignore it.

    Returns:
        SessionManager instance
    """
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = SessionManager(timeout_minutes)
    return _session_manager


# Convenience functions for common operations


def create_user_session(
    user_id: str, api_keys: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """Create (or replace) the session for ``user_id`` on the global manager."""
    return get_session_manager().create_session(user_id, api_keys)


def get_user_session(user_id: str) -> Optional[Dict[str, Any]]:
    """Return the live session for ``user_id``, or None if missing/expired."""
    return get_session_manager().get_session(user_id)


def update_user_session(user_id: str, updates: Dict[str, Any]) -> bool:
    """Apply ``updates`` to the session for ``user_id``; False if not found."""
    return get_session_manager().update_session(user_id, updates)


def delete_user_session(user_id: str) -> bool:
    """Delete the session for ``user_id``; False if there was none."""
    return get_session_manager().delete_session(user_id)


def cleanup_sessions() -> int:
    """Drop expired sessions from the global manager; return how many."""
    return get_session_manager().cleanup_expired_sessions()


# Configuration validation functions


def validate_indicator_parameters(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Check technical-indicator parameters against their allowed integer ranges.

    Only keys present in ``params`` are checked, in a fixed order; the first
    failure is reported. When both MACD periods are given, the slow period
    must exceed the fast one.

    Args:
        params: Indicator parameters dictionary

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    # (key, inclusive minimum, inclusive maximum, error message), in check order.
    rules = (
        ("rsi_period", 2, 100, "RSI period must be an integer between 2 and 100"),
        ("macd_fast", 2, 50, "MACD fast period must be an integer between 2 and 50"),
        ("macd_slow", 2, 100, "MACD slow period must be an integer between 2 and 100"),
        ("macd_signal", 2, 50, "MACD signal period must be an integer between 2 and 50"),
        ("stoch_k_period", 2, 50, "Stochastic K period must be an integer between 2 and 50"),
        ("stoch_d_period", 2, 20, "Stochastic D period must be an integer between 2 and 20"),
    )

    for key, low, high, message in rules:
        if key not in params:
            continue
        value = params[key]
        if not isinstance(value, int) or not low <= value <= high:
            return False, message
        # macd_fast (if present) was already validated earlier in the loop.
        if key == "macd_slow" and "macd_fast" in params and value <= params["macd_fast"]:
            return False, "MACD slow period must be greater than fast period"

    return True, None


def validate_model_name(provider: str, model: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``model`` is on the allow-list for LLM ``provider``.

    Args:
        provider: LLM provider name (openai, anthropic, qwen)
        model: Model name

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    catalog = {
        "openai": [
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4-turbo-preview",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
        ],
        "anthropic": [
            "claude-3-opus-20240229",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
            "claude-2.1",
            "claude-2.0",
        ],
        "qwen": ["qwen-turbo", "qwen-plus", "qwen-max"],
    }

    if provider not in catalog:
        known_providers = ", ".join(catalog.keys())
        return (
            False,
            f"Unknown provider: {provider}. Valid providers: {known_providers}",
        )

    allowed = catalog[provider]
    if model in allowed:
        return True, None

    known_models = ", ".join(allowed)
    return (
        False,
        f"Invalid model '{model}' for provider '{provider}'. Valid models: {known_models}",
    )


def validate_data_provider(provider: str) -> tuple[bool, Optional[str]]:
    """
    Check that ``provider`` is a supported market-data provider.

    Args:
        provider: Data provider name

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    supported = ("yfinance", "alpha_vantage")

    if provider in supported:
        return True, None

    return (
        False,
        f"Invalid data provider: {provider}. Valid providers: {', '.join(supported)}",
    )


def validate_hf_token(token: str) -> tuple[bool, Optional[str]]:
    """
    Validate HuggingFace API token format.

    This is a local format check only; the token is not verified against
    HuggingFace.

    Args:
        token: HuggingFace API token

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    if not token:
        return False, "HuggingFace token cannot be empty"

    # HF tokens typically start with "hf_" and are alphanumeric
    if not token.startswith("hf_"):
        return (
            False,
            "HuggingFace token should start with 'hf_'. Get your token from https://huggingface.co/settings/tokens",
        )

    # Check minimum length (HF tokens are typically 30-40 characters)
    if len(token) < 20:
        return False, "HuggingFace token appears too short. Please check your token."

    # Allowed alphabet is ASCII letters, digits and underscore. str.isalnum()
    # alone is Unicode-aware and would accept e.g. "é" or Arabic-Indic digits,
    # so require isascii() as well.
    if not all((c.isascii() and c.isalnum()) or c == "_" for c in token):
        return False, "HuggingFace token contains invalid characters"

    return True, None


def validate_api_keys(api_keys: Dict[str, str]) -> tuple[bool, Optional[str]]:
    """
    Validate the format of API keys for known providers.

    Only the providers checked below are validated; other entries in
    ``api_keys`` are accepted as-is. Checks are prefix/format heuristics and do
    not contact the providers. Non-string values for a checked provider are
    reported as invalid rather than raising AttributeError.

    Args:
        api_keys: Dictionary of API keys by provider. Recognized keys are
            "huggingface" / "hf_token", "openai" and "anthropic".

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    # The HuggingFace token may be stored under either name; a non-empty
    # "huggingface" entry takes precedence over "hf_token".
    if "huggingface" in api_keys or "hf_token" in api_keys:
        token = api_keys.get("huggingface") or api_keys.get("hf_token")
        # validate_hf_token calls str methods on truthy input; reject other
        # truthy types here instead of letting them raise.
        if token and not isinstance(token, str):
            return False, "Invalid HuggingFace token: HuggingFace token must be a string"
        is_valid, error = validate_hf_token(token)
        if not is_valid:
            return False, f"Invalid HuggingFace token: {error}"

    # OpenAI keys typically start with "sk-"
    if "openai" in api_keys:
        openai_key = api_keys["openai"]
        if not isinstance(openai_key, str) or not openai_key.startswith("sk-"):
            return False, "OpenAI API key should start with 'sk-'"

    # Anthropic keys typically start with "sk-ant-"
    if "anthropic" in api_keys:
        anthropic_key = api_keys["anthropic"]
        if not isinstance(anthropic_key, str) or not anthropic_key.startswith("sk-ant-"):
            return False, "Anthropic API key should start with 'sk-ant-'"

    return True, None


def validate_configuration(config: Dict[str, Any]) -> tuple[bool, Optional[str]]:
    """
    Validate a complete configuration object.

    Sections are checked in this order: indicator parameters, LLM provider,
    API keys, data providers. Absent sections are skipped and the first
    failure is returned with a section-specific prefix.

    Args:
        config: Configuration dictionary

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid
    """
    if "indicator_parameters" in config:
        ok, reason = validate_indicator_parameters(config["indicator_parameters"])
        if not ok:
            return False, f"Invalid indicator parameters: {reason}"

    if "llm_provider" in config:
        llm_provider = config["llm_provider"]
        if llm_provider not in ("openai", "anthropic", "huggingface", "qwen"):
            return False, f"Invalid LLM provider: {llm_provider}"

    if "api_keys" in config:
        ok, reason = validate_api_keys(config["api_keys"])
        if not ok:
            return False, f"Invalid API keys: {reason}"

    if "data_providers" in config:
        data_providers = config["data_providers"]
        # (config slot, error prefix) pairs, checked in this order.
        for slot, prefix in (
            ("ohlc_primary", "Invalid OHLC primary provider"),
            ("fundamentals_primary", "Invalid fundamentals provider"),
        ):
            if slot not in data_providers:
                continue
            ok, reason = validate_data_provider(data_providers[slot])
            if not ok:
                return False, f"{prefix}: {reason}"

    return True, None