""" DuckDB 数据库管理模块 实现数据库初始化、Hugging Face Dataset 同步、单例连接管理 """ import os import logging from pathlib import Path from typing import Optional from functools import lru_cache from dotenv import load_dotenv # 加载 .env 文件 load_dotenv() import duckdb from huggingface_hub import hf_hub_download, upload_file, login logger = logging.getLogger(__name__) # 环境变量配置 DUCKDB_PATH = os.getenv("DUCKDB_PATH", "/app/data/stock_data.duckdb") HF_TOKEN = os.getenv("HF_TOKEN") DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "") HF_HOME = os.getenv("HF_HOME", "/tmp/.huggingface") class DatabaseManager: """DuckDB 数据库管理器 - 单例模式""" _instance: Optional['DatabaseManager'] = None _connection: Optional[duckdb.DuckDBPyConnection] = None def __new__(cls) -> 'DatabaseManager': if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @property def conn(self) -> duckdb.DuckDBPyConnection: """获取数据库连接""" if self._connection is None: # 确保目录存在 os.makedirs(os.path.dirname(DUCKDB_PATH), exist_ok=True) self._connection = duckdb.connect(DUCKDB_PATH) # 如果在 HF 环境,配置远程访问插件 if HF_TOKEN and DATASET_REPO_ID: try: self._connection.execute("INSTALL httpfs; LOAD httpfs;") self._connection.execute(f"SET s3_region='us-east-1';") # HF 默认区域 # 这里的逻辑可以根据是否切换到 Parquet 进一步扩展 except Exception as e: logger.warning(f"Failed to load httpfs extension: {e}") logger.info(f"Database connection established: {DUCKDB_PATH}") return self._connection def close(self) -> None: """关闭数据库连接""" if self._connection is not None: self._connection.close() self._connection = None logger.info("Database connection closed") def init_db(self) -> None: """ 初始化数据库 - 在本地环境:使用本地 DuckDB - 在 HF 环境:挂载云端 Parquet 文件 """ conn = self.conn # 检查是否在 HF 环境 is_hf = os.getenv("SPACES_REPO_NAME") is not None or (HF_TOKEN and DATASET_REPO_ID) if is_hf: logger.info("HF environment detected. Downloading remote Parquet files...") try: from huggingface_hub import list_repo_files, hf_hub_download # 1. 动态获取文件列表 all_files = list_repo_files(repo_id=DATASET_REPO_ID, repo_type="dataset") # 2. 股票列表:下载并加载 list_file = "data/stock_list.parquet" if list_file in all_files: local_list_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename=list_file, repo_type="dataset") conn.execute(f"CREATE OR REPLACE TABLE stock_list AS SELECT * FROM read_parquet('{local_list_path}')") logger.info("Local stock_list table created from remote parquet") # 3. 日线数据:下载并加载 (下载模式替代挂载模式) # 注意:如果文件很多,这步可能会花点时间,但比下载整个 .duckdb 快得多 parquet_files = [f for f in all_files if f.startswith("data/parquet/") and f.endswith(".parquet")] if parquet_files: local_paths = [] for f in parquet_files: # HF SDK 会自动处理缓存,已下载的文件不会重复下载 path = hf_hub_download(repo_id=DATASET_REPO_ID, filename=f, repo_type="dataset") local_paths.append(f"'{path}'") files_sql = ", ".join(local_paths) conn.execute("DROP VIEW IF EXISTS stock_daily") conn.execute("DROP TABLE IF EXISTS stock_daily") conn.execute(f"CREATE OR REPLACE VIEW stock_daily AS SELECT * FROM read_parquet([{files_sql}])") logger.info(f"Remote stock_daily view created with {len(parquet_files)} partitions") else: logger.warning("No parquet data files found in dataset") self._create_tables() # 验证 count = conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()[0] logger.info(f"Verification successful: {count} stocks found") except Exception as e: logger.error(f"Failed to load remote Parquet: {e}") self._create_tables() else: self._create_tables() logger.info("Local database initialized") def _download_from_hf(self) -> None: """从 Hugging Face Dataset 下载数据库文件""" if not HF_TOKEN or not DATASET_REPO_ID: logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping download") return try: # 登录 Hugging Face login(token=HF_TOKEN) # 下载数据库文件 downloaded_path = hf_hub_download( repo_id=DATASET_REPO_ID, filename="stock_data.duckdb", repo_type="dataset", local_dir=os.path.dirname(DUCKDB_PATH), ) # 如果下载路径不同,重命名 if downloaded_path != DUCKDB_PATH: os.rename(downloaded_path, DUCKDB_PATH) logger.info(f"Database downloaded from HF Dataset: {DATASET_REPO_ID}") except Exception as e: logger.error(f"Failed to download database from HF: {e}") def upload_db(self) -> None: """上传 Parquet 分区到 Hugging Face Dataset (不再上传全量 .duckdb)""" if not HF_TOKEN or not DATASET_REPO_ID: logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload") return try: # 先关闭连接 self.close() login(token=HF_TOKEN) # 1. 上传股票列表 (这个很小,可以全量覆盖) if Path(DUCKDB_PATH).exists(): # 导出列表为 parquet 方便云端读取 conn = duckdb.connect(DUCKDB_PATH) list_path = os.path.join(os.path.dirname(DUCKDB_PATH), "stock_list.parquet") conn.execute(f"COPY stock_list TO '{list_path}' (FORMAT PARQUET)") conn.close() upload_file( path_or_fileobj=list_path, path_in_repo="data/stock_list.parquet", repo_id=DATASET_REPO_ID, repo_type="dataset", ) # 2. 上传所有 Parquet 行情文件 (增量核心) parquet_dir = Path(os.path.dirname(DUCKDB_PATH)) / "parquet" if parquet_dir.exists(): for p_file in parquet_dir.glob("*.parquet"): upload_file( path_or_fileobj=str(p_file), path_in_repo=f"data/parquet/{p_file.name}", repo_id=DATASET_REPO_ID, repo_type="dataset", ) logger.info(f"Parquet files uploaded to HF Dataset: {DATASET_REPO_ID}") except Exception as e: logger.error(f"Failed to upload to HF: {e}") finally: _ = self.conn def _create_tables(self) -> None: """创建数据库表结构""" conn = self.conn # 日线行情表 conn.execute(""" CREATE TABLE IF NOT EXISTS stock_daily ( code VARCHAR, trade_date DATE, open DOUBLE, high DOUBLE, low DOUBLE, close DOUBLE, volume BIGINT, amount DOUBLE, pct_chg DOUBLE, turnover_rate DOUBLE, PRIMARY KEY (code, trade_date) ) """) # 股票基础信息表 conn.execute(""" CREATE TABLE IF NOT EXISTS stock_list ( code VARCHAR PRIMARY KEY, name VARCHAR, market VARCHAR, list_date DATE ) """) # 创建索引 conn.execute(""" CREATE INDEX IF NOT EXISTS idx_code_date ON stock_daily (code, trade_date) """) logger.info("Database tables created/verified") # 全局单例实例 db_manager = DatabaseManager() def get_db() -> DatabaseManager: """获取数据库管理器实例""" return db_manager