Spaces:
Running
Running
| """ | |
| DuckDB 数据库管理模块 | |
| 实现数据库初始化、Hugging Face Dataset 同步、单例连接管理 | |
| """ | |
| import os | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional | |
| from functools import lru_cache | |
| from dotenv import load_dotenv | |
| # 加载 .env 文件 | |
| load_dotenv() | |
| import duckdb | |
| from huggingface_hub import hf_hub_download, upload_file, login | |
| logger = logging.getLogger(__name__) | |
| # 环境变量配置 | |
| DUCKDB_PATH = os.getenv("DUCKDB_PATH", "/app/data/stock_data.duckdb") | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "") | |
| HF_HOME = os.getenv("HF_HOME", "/tmp/.huggingface") | |
class DatabaseManager:
    """DuckDB database manager (singleton).

    Lazily opens one shared connection to the DuckDB file at ``DUCKDB_PATH``
    and, when ``HF_TOKEN`` / ``DATASET_REPO_ID`` are configured, synchronizes
    data with a Hugging Face Dataset repository as Parquet partitions.
    """

    # Singleton state is held at class level on purpose.
    _instance: Optional['DatabaseManager'] = None
    _connection: Optional[duckdb.DuckDBPyConnection] = None

    def __new__(cls) -> 'DatabaseManager':
        # Classic singleton: every construction returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def conn(self) -> duckdb.DuckDBPyConnection:
        """Return the shared database connection, opening it on first use.

        Declared as a property because the rest of this module reads it as
        ``self.conn`` (attribute access, no call parentheses); without the
        decorator those call sites would receive a bound method and crash
        on ``.execute``.
        """
        if self._connection is None:
            # Make sure the data directory exists before connecting.
            os.makedirs(os.path.dirname(DUCKDB_PATH), exist_ok=True)
            self._connection = duckdb.connect(DUCKDB_PATH)
            # In a HF environment, enable the remote-access extension (best effort).
            if HF_TOKEN and DATASET_REPO_ID:
                try:
                    self._connection.execute("INSTALL httpfs; LOAD httpfs;")
                    # HF default region (plain string; was a placeholder-less f-string).
                    self._connection.execute("SET s3_region='us-east-1';")
                    # This logic can be extended further if fully switching to Parquet.
                except Exception as e:
                    logger.warning(f"Failed to load httpfs extension: {e}")
            logger.info(f"Database connection established: {DUCKDB_PATH}")
        return self._connection

    def close(self) -> None:
        """Close the database connection if it is open."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.info("Database connection closed")

    def init_db(self) -> None:
        """
        Initialize the database.

        - Local environment: create/verify local DuckDB tables.
        - HF environment: download remote Parquet files and expose them as
          a ``stock_list`` table and a ``stock_daily`` view.
        """
        conn = self.conn
        # Detect whether we are running inside a Hugging Face environment.
        is_hf = os.getenv("SPACES_REPO_NAME") is not None or (HF_TOKEN and DATASET_REPO_ID)
        if is_hf:
            logger.info("HF environment detected. Downloading remote Parquet files...")
            try:
                from huggingface_hub import list_repo_files, hf_hub_download

                # 1. Fetch the repository file listing dynamically.
                all_files = list_repo_files(repo_id=DATASET_REPO_ID, repo_type="dataset")

                # 2. Stock list: download and load into a table.
                list_file = "data/stock_list.parquet"
                if list_file in all_files:
                    local_list_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename=list_file, repo_type="dataset")
                    conn.execute(f"CREATE OR REPLACE TABLE stock_list AS SELECT * FROM read_parquet('{local_list_path}')")
                    logger.info("Local stock_list table created from remote parquet")

                # 3. Daily bars: download each partition (download mode instead of
                #    mount mode). With many files this step takes a while, but it
                #    is still much faster than pulling a whole .duckdb file.
                parquet_files = [f for f in all_files if f.startswith("data/parquet/") and f.endswith(".parquet")]
                if parquet_files:
                    local_paths = []
                    for f in parquet_files:
                        # The HF SDK caches downloads; existing files are not re-fetched.
                        path = hf_hub_download(repo_id=DATASET_REPO_ID, filename=f, repo_type="dataset")
                        local_paths.append(f"'{path}'")
                    files_sql = ", ".join(local_paths)
                    conn.execute("DROP VIEW IF EXISTS stock_daily")
                    conn.execute("DROP TABLE IF EXISTS stock_daily")
                    conn.execute(f"CREATE OR REPLACE VIEW stock_daily AS SELECT * FROM read_parquet([{files_sql}])")
                    logger.info(f"Remote stock_daily view created with {len(parquet_files)} partitions")
                else:
                    logger.warning("No parquet data files found in dataset")
                self._create_tables()
                # Sanity check: the stock list must be queryable.
                count = conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()[0]
                logger.info(f"Verification successful: {count} stocks found")
            except Exception as e:
                logger.error(f"Failed to load remote Parquet: {e}")
                # Fall back to empty local tables so the app can still start.
                self._create_tables()
        else:
            self._create_tables()
            logger.info("Local database initialized")

    def _download_from_hf(self) -> None:
        """Download the full database file from the HF Dataset (legacy path)."""
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping download")
            return
        try:
            # Authenticate with Hugging Face.
            login(token=HF_TOKEN)
            # Download the database file into the target directory.
            downloaded_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename="stock_data.duckdb",
                repo_type="dataset",
                local_dir=os.path.dirname(DUCKDB_PATH),
            )
            # Move into place if the download landed elsewhere. os.replace
            # overwrites atomically; os.rename would fail on Windows when the
            # destination already exists.
            if downloaded_path != DUCKDB_PATH:
                os.replace(downloaded_path, DUCKDB_PATH)
            logger.info(f"Database downloaded from HF Dataset: {DATASET_REPO_ID}")
        except Exception as e:
            logger.error(f"Failed to download database from HF: {e}")

    def upload_db(self) -> None:
        """Upload Parquet partitions to the HF Dataset (no full .duckdb upload)."""
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return
        try:
            # Close the shared connection first so the file is consistent.
            self.close()
            login(token=HF_TOKEN)
            # 1. Upload the stock list (small; full overwrite is fine).
            if Path(DUCKDB_PATH).exists():
                # Export the list as parquet so the cloud side can read it directly.
                conn = duckdb.connect(DUCKDB_PATH)
                list_path = os.path.join(os.path.dirname(DUCKDB_PATH), "stock_list.parquet")
                conn.execute(f"COPY stock_list TO '{list_path}' (FORMAT PARQUET)")
                conn.close()
                upload_file(
                    path_or_fileobj=list_path,
                    path_in_repo="data/stock_list.parquet",
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                )
            # 2. Upload all daily-bar Parquet files (the incremental core).
            parquet_dir = Path(os.path.dirname(DUCKDB_PATH)) / "parquet"
            if parquet_dir.exists():
                for p_file in parquet_dir.glob("*.parquet"):
                    upload_file(
                        path_or_fileobj=str(p_file),
                        path_in_repo=f"data/parquet/{p_file.name}",
                        repo_id=DATASET_REPO_ID,
                        repo_type="dataset",
                    )
            logger.info(f"Parquet files uploaded to HF Dataset: {DATASET_REPO_ID}")
        except Exception as e:
            logger.error(f"Failed to upload to HF: {e}")
        finally:
            # Touch the property so the shared connection is reopened.
            _ = self.conn

    def _create_tables(self) -> None:
        """Create the database table structure (idempotent)."""
        conn = self.conn
        # Daily OHLCV quote table.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_daily (
                code VARCHAR,
                trade_date DATE,
                open DOUBLE,
                high DOUBLE,
                low DOUBLE,
                close DOUBLE,
                volume BIGINT,
                amount DOUBLE,
                pct_chg DOUBLE,
                turnover_rate DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Basic stock-information table.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_list (
                code VARCHAR PRIMARY KEY,
                name VARCHAR,
                market VARCHAR,
                list_date DATE
            )
        """)
        # Index for (code, trade_date) lookups.
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_code_date
            ON stock_daily (code, trade_date)
        """)
        logger.info("Database tables created/verified")
# Module-level singleton instance shared by the whole application.
db_manager = DatabaseManager()


def get_db() -> DatabaseManager:
    """Return the shared :class:`DatabaseManager` singleton."""
    return db_manager