File size: 5,072 Bytes
ab96cfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f6faa9
ab96cfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Configuration management for the HF EDA MCP Server.

This module provides configuration classes and utilities for managing
server settings, authentication, caching, and performance parameters.
"""

import os
import logging
import sys
from typing import Optional, Dict, Any
from dataclasses import dataclass, field


# Lower-cased spellings accepted as "on" for boolean environment flags.
_TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})


def _read_env(name: str) -> Optional[str]:
    """Return env var *name* with surrounding whitespace stripped.

    Unset and blank values both yield ``None`` so callers fall back to their
    defaults instead of tripping over an empty string (e.g. ``HF_EDA_PORT=``
    left empty in a deployment file).
    """
    raw = os.getenv(name)
    if raw is None:
        return None
    raw = raw.strip()
    return raw or None


def _read_env_int(name: str, default: int) -> int:
    """Parse env var *name* as an int, returning *default* when unset or blank.

    Raises:
        ValueError: if the variable is set to a non-integer value; the message
            names the offending variable.
    """
    raw = _read_env(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


def _read_env_bool(name: str, default: bool) -> bool:
    """Parse env var *name* as a flag, returning *default* when unset or blank.

    ``1``/``true``/``yes``/``on`` (case-insensitive) mean True; any other
    non-blank value means False.
    """
    raw = _read_env(name)
    if raw is None:
        return default
    return raw.lower() in _TRUE_STRINGS


@dataclass
class ServerConfig:
    """Configuration class for the HF EDA MCP Server.

    The field defaults below apply unless overridden, e.g. via :meth:`from_env`.
    """

    # Server settings
    port: int = 7860
    host: str = "0.0.0.0"
    mcp_server: bool = True
    share: bool = False

    # Authentication settings
    hf_token: Optional[str] = None

    # Logging settings
    log_level: str = "INFO"
    log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Cache settings
    cache_dir: Optional[str] = None
    max_cache_size: int = 1000  # MB

    # Performance settings
    max_sample_size: int = 50000
    max_concurrent_requests: int = 10
    request_timeout: int = 300  # seconds

    # Additional Gradio settings
    gradio_settings: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_env(cls) -> "ServerConfig":
        """Create a configuration from ``HF_EDA_*`` and ``HF_TOKEN`` env vars.

        Variables that are unset or blank keep the dataclass defaults. Values
        are stripped of surrounding whitespace, and ``HF_EDA_LOG_LEVEL`` is
        upper-cased.

        Returns:
            A new ``ServerConfig`` instance.

        Raises:
            ValueError: if a numeric variable holds a non-integer value.
        """
        config = cls()

        # Server settings
        config.port = _read_env_int("HF_EDA_PORT", config.port)
        config.host = _read_env("HF_EDA_HOST") or config.host
        config.mcp_server = _read_env_bool("HF_EDA_MCP_ENABLED", config.mcp_server)
        config.share = _read_env_bool("HF_EDA_SHARE", config.share)

        # Authentication (stripped: secrets read from files often end in "\n")
        config.hf_token = _read_env("HF_TOKEN")

        # Logging
        config.log_level = (_read_env("HF_EDA_LOG_LEVEL") or config.log_level).upper()

        # Cache settings
        config.cache_dir = _read_env("HF_EDA_CACHE_DIR")
        config.max_cache_size = _read_env_int(
            "HF_EDA_MAX_CACHE_SIZE", config.max_cache_size
        )

        # Performance settings
        config.max_sample_size = _read_env_int(
            "HF_EDA_MAX_SAMPLE_SIZE", config.max_sample_size
        )
        config.max_concurrent_requests = _read_env_int(
            "HF_EDA_MAX_CONCURRENT", config.max_concurrent_requests
        )
        config.request_timeout = _read_env_int(
            "HF_EDA_REQUEST_TIMEOUT", config.request_timeout
        )

        return config


def setup_logging(config: ServerConfig) -> logging.Logger:
    """Configure root logging for the server and return this module's logger.

    Calls :func:`logging.basicConfig` with a single stdout ``StreamHandler``
    (``basicConfig`` does nothing if the root logger already has handlers).
    It then lowers noisy third-party libraries to WARNING.

    ``config.log_level`` is matched case-insensitively against the standard
    level names. An unrecognised value falls back to INFO and is reported as
    a warning, instead of raising ``AttributeError``.

    Args:
        config: Server configuration supplying ``log_level`` and ``log_format``.

    Returns:
        The logger for this module.
    """
    level_name = str(config.log_level).strip().upper()
    # getLevelName maps a registered level *name* back to its int value;
    # unknown names come back as the string "Level <name>".
    level = logging.getLevelName(level_name)
    level_is_known = isinstance(level, int)
    if not level_is_known:
        level = logging.INFO

    # Configure root logger
    logging.basicConfig(
        level=level,
        format=config.log_format,
        handlers=[
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Create logger for this module
    logger = logging.getLogger(__name__)

    # Keep chatty external libraries quiet unless they have something important.
    for library in ("gradio", "httpx", "urllib3"):
        logging.getLogger(library).setLevel(logging.WARNING)

    if not level_is_known:
        logger.warning(
            "Unknown log level %r; falling back to INFO", config.log_level
        )

    return logger


def validate_config(config: ServerConfig) -> None:
    """Validate server configuration and log warnings for potential issues.

    Only the cache-directory check can fail; every other check just logs.

    Args:
        config: The configuration to inspect.

    Raises:
        OSError: if ``config.cache_dir`` is set but cannot be created.
        PermissionError: if ``config.cache_dir`` exists but is not writable.
    """
    logger = logging.getLogger(__name__)

    # Validate port range
    if not (1024 <= config.port <= 65535):
        logger.warning(
            "Port %s may require elevated privileges or be invalid", config.port
        )

    # Check cache directory
    if config.cache_dir:
        try:
            os.makedirs(config.cache_dir, exist_ok=True)
        except OSError as e:
            logger.error(
                "Failed to create/access cache directory %s: %s",
                config.cache_dir,
                e,
            )
            raise
        # Checked outside the try so a non-writable directory is logged once,
        # rather than again by the creation-failure handler above.
        if not os.access(config.cache_dir, os.W_OK):
            logger.error("Cache directory %s is not writable", config.cache_dir)
            raise PermissionError(
                f"Cannot write to cache directory: {config.cache_dir}"
            )

    # Validate performance settings
    if config.max_sample_size > 100000:
        logger.warning(
            "Large max_sample_size (%s) may cause memory issues",
            config.max_sample_size,
        )

    if config.request_timeout < 30:
        logger.warning(
            "Short request timeout (%ss) may cause failures for large datasets",
            config.request_timeout,
        )

    # Check authentication
    if not config.hf_token:
        logger.warning(
            "No HuggingFace token configured - only public datasets will be accessible"
        )
        logger.info("Set HF_TOKEN environment variable to access private datasets")
    else:
        logger.info("HuggingFace token configured - private datasets accessible")


# Shared, lazily-populated configuration read by get_config and replaced by set_config.
_global_config: Optional[ServerConfig] = None


def get_config() -> ServerConfig:
    """Return the shared configuration, building it from the environment if needed.

    If no configuration is currently stored, one is created with
    ``ServerConfig.from_env()`` and stored. Later calls return the stored
    object.
    """
    global _global_config
    if _global_config is not None:
        return _global_config
    _global_config = ServerConfig.from_env()
    return _global_config


def set_config(config: ServerConfig) -> None:
    """Store *config* as the shared configuration that ``get_config`` returns."""
    global _global_config
    _global_config = config