""" DuckDB 数据库管理模块 - Sync Space 专用版 实现数据库初始化、Hugging Face Dataset 同步、单例连接管理 """ import os import logging from pathlib import Path from typing import Optional, List from dotenv import load_dotenv # 加载 .env 文件 load_dotenv() import duckdb from huggingface_hub import hf_hub_download, upload_file, login, upload_folder logger = logging.getLogger(__name__) # 环境变量配置 - Sync Space 专用路径 DUCKDB_PATH = os.getenv("DUCKDB_PATH", "/tmp/data/stock_data.duckdb") HF_TOKEN = os.getenv("HF_TOKEN") DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "") HF_HOME = os.getenv("HF_HOME", "/tmp/.huggingface") class DatabaseManager: """DuckDB 数据库管理器 - 单例模式""" _instance: Optional['DatabaseManager'] = None _connection: Optional[duckdb.DuckDBPyConnection] = None def __new__(cls) -> 'DatabaseManager': if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @property def conn(self) -> duckdb.DuckDBPyConnection: """获取数据库连接""" if self._connection is None: # 确保目录存在 os.makedirs(os.path.dirname(DUCKDB_PATH), exist_ok=True) self._connection = duckdb.connect(DUCKDB_PATH) logger.info(f"Database connection established: {DUCKDB_PATH}") return self._connection def close(self) -> None: """关闭数据库连接""" if self._connection is not None: self._connection.close() self._connection = None logger.info("Database connection closed") def _download_with_retry(self, repo_id: str, filename: str, max_retries: int = 3) -> Optional[str]: """带重试的文件下载""" import time from huggingface_hub import hf_hub_download for attempt in range(max_retries): try: return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset") except Exception as e: if attempt < max_retries - 1: delay = 2 ** attempt # 指数退避: 2s, 4s, 8s logger.warning(f"Download failed (attempt {attempt + 1}), retrying in {delay}s...") time.sleep(delay) else: logger.error(f"Download failed after {max_retries} attempts: {e}") return None def _list_files_with_retry(self, repo_id: str, max_retries: int = 3) -> Optional[List]: """带重试的文件列表获取""" 
        import time
        from huggingface_hub import list_repo_files

        for attempt in range(max_retries):
            try:
                return list(list_repo_files(repo_id=repo_id, repo_type="dataset"))
            except Exception as e:
                if attempt < max_retries - 1:
                    delay = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"List files failed (attempt {attempt + 1}), retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    logger.error(f"List files failed after {max_retries} attempts: {e}")
                    return None
        return None

    def _fallback_download_from_status(self) -> bool:
        """Fallback: download the essential data by way of sync_status.json."""
        import json
        import shutil
        from datetime import datetime

        if not HF_TOKEN or not DATASET_REPO_ID:
            return False

        logger.info("Attempting fallback download via sync_status.json...")

        # Try to download sync_status.json
        status_path = self._download_with_retry(DATASET_REPO_ID, "data/sync_status.json", max_retries=2)
        if not status_path:
            logger.warning("Failed to download sync_status.json, cannot use fallback mode")
            return False

        try:
            with open(status_path, 'r') as f:
                status = json.load(f)

            # Download the stock list (required)
            if 'stock_list' in status:
                list_file = Path(os.path.dirname(DUCKDB_PATH)) / "stock_list.parquet"
                local_path = self._download_with_retry(DATASET_REPO_ID, "data/stock_list.parquet", max_retries=2)
                if local_path:
                    shutil.copy(local_path, list_file)
                    self.conn.execute(f"CREATE OR REPLACE TABLE stock_list AS SELECT * FROM read_parquet('{list_file}')")
                    logger.info("Stock list downloaded via fallback")

            # Download the last 3 months of daily K-line data
            daily_status = status.get('daily', {})
            last_trade_date = daily_status.get('last_trade_date', '')

            if last_trade_date:
                # Work out the 3 most recent months, ending at the last trade date
                last_date = datetime.strptime(last_trade_date, '%Y-%m-%d')
                months_to_download = []
                for i in range(3):
                    year = last_date.year
                    month = last_date.month - i
                    if month <= 0:
                        # Wrap around into the previous year
                        year -= 1
                        month += 12
                    months_to_download.append(f"{year}-{month:02d}")

                parquet_dir = Path(os.path.dirname(DUCKDB_PATH)) / "parquet"
                parquet_dir.mkdir(parents=True, exist_ok=True)

                downloaded = 0
                for month_str in months_to_download:
                    filename = f"data/parquet/{month_str}.parquet"
                    local_path = self._download_with_retry(DATASET_REPO_ID, filename, max_retries=1)
                    if local_path:
                        dest_path = parquet_dir / f"{month_str}.parquet"
                        shutil.copy(local_path, dest_path)
                        downloaded += 1

                if downloaded > 0:
                    self._refresh_views()
                    logger.info(f"Fallback download complete: {downloaded} months of daily data")
                    return True

        except Exception as e:
            logger.error(f"Fallback download failed: {e}")

        return False

    def init_db(self, force_download: bool = False) -> None:
        """
        Initialize the database: Sync Space smart download mode (with retries and fallback).

        Args:
            force_download: Force a download from the HF Dataset (default False)
        """
        conn = self.conn

        # 1. Check whether the local tables already hold data
        if not force_download:
            try:
                count = conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()[0]
                if count > 0:
                    logger.info(f"Local database tables exist ({count} stocks).")
                    # Even when the tables exist, make sure the views are created
                    # (in case local parquet files are present)
                    self._refresh_views()
                    return
            except Exception:
                # stock_list does not exist yet; fall through to the next step
                pass

        # 2. Try to restore from local Parquet files (the Space has not restarted)
        list_file = Path(os.path.dirname(DUCKDB_PATH)) / "stock_list.parquet"

        if not force_download and list_file.exists():
            try:
                conn.execute(f"CREATE OR REPLACE TABLE stock_list AS SELECT * FROM read_parquet('{list_file}')")
                self._refresh_views()
                logger.info("Database restored from local parquet files.")
                return
            except Exception as e:
                logger.warning(f"Failed to restore from local parquet: {e}")

        # 3. Download data from the HF Dataset (with retries)
        if HF_TOKEN and DATASET_REPO_ID:
            logger.info("Downloading remote Parquet files from HF Dataset...")
            try:
                import shutil

                # First try to fetch the file list (with retries)
                all_files = self._list_files_with_retry(DATASET_REPO_ID, max_retries=3)

                if all_files is None:
                    # Listing failed, so try fallback mode
                    logger.warning("Failed to list files, attempting fallback mode...")
                    if self._fallback_download_from_status():
                        logger.info("Database initialized via fallback mode")
                        return
                    else:
                        logger.warning("Fallback mode failed, creating empty database")
                        self._create_tables()
                        return

                # Normal flow: download the stock list
                if "data/stock_list.parquet" in all_files:
                    local_list_path = self._download_with_retry(DATASET_REPO_ID, "data/stock_list.parquet")
                    if local_list_path:
                        shutil.copy(local_list_path, list_file)
                        conn.execute(f"CREATE OR REPLACE TABLE stock_list AS SELECT * FROM read_parquet('{list_file}')")

                # Download the daily-quote partitions (last 3 months only)
                parquet_files = sorted([f for f in all_files if f.startswith("data/parquet/") and f.endswith(".parquet")])
                if parquet_files:
                    # File names sort chronologically, so the last 3 are the most recent
                    recent_files = parquet_files[-3:]
                    logger.info(f"Downloading {len(recent_files)} recent parquet files (last 3 months)")

                    for f in recent_files:
                        remote_path = self._download_with_retry(DATASET_REPO_ID, f)
                        if remote_path:
                            dest_path = Path(os.path.dirname(DUCKDB_PATH)) / f.replace("data/", "")
                            dest_path.parent.mkdir(parents=True, exist_ok=True)
                            shutil.copy(remote_path, dest_path)

                    self._refresh_views()
                    logger.info("Remote data downloaded and views created.")
                else:
                    self._create_tables()
            except Exception as e:
                logger.error(f"Failed to load remote Parquet: {e}")
                self._create_tables()
        else:
            self._create_tables()
            logger.info("Local database initialized")

    def _refresh_views(self) -> None:
        """Refresh the database views."""
        conn = self.conn
        parquet_dir = Path(os.path.dirname(DUCKDB_PATH)) / "parquet"
        if parquet_dir.exists():
            p_files = list(parquet_dir.glob("*.parquet"))
            if p_files:
                files_sql = ", ".join([f"'{str(f)}'" for f in p_files])
                conn.execute("DROP VIEW IF EXISTS stock_daily")
                conn.execute(f"CREATE OR REPLACE VIEW stock_daily AS SELECT * FROM read_parquet([{files_sql}])")
                logger.info(f"Database views refreshed with {len(p_files)} partitions")

    def upload_db(self) -> None:
        """Upload the Parquet partitions to the Hugging Face Dataset."""
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return

        data_dir = Path(os.path.dirname(DUCKDB_PATH))

        def _upload(local: Path, path_in_repo: str) -> None:
            # Upload one local file to the dataset repo
            upload_file(
                path_or_fileobj=str(local),
                path_in_repo=path_in_repo,
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
            )

        try:
            # Close the shared connection first
            self.close()

            login(token=HF_TOKEN)

            # 1. Export and upload the stock list
            if Path(DUCKDB_PATH).exists():
                list_path = data_dir / "stock_list.parquet"
                conn = duckdb.connect(DUCKDB_PATH)
                try:
                    conn.execute(f"COPY stock_list TO '{list_path}' (FORMAT PARQUET)")
                finally:
                    conn.close()
                _upload(list_path, "data/stock_list.parquet")

            # 2. Upload all daily-quote Parquet files
            parquet_dir = data_dir / "parquet"
            if parquet_dir.exists():
                for p_file in parquet_dir.glob("*.parquet"):
                    _upload(p_file, f"data/parquet/{p_file.name}")

            # 3-5. Upload the data partitioned by month:
            # fund flow, valuation indicators, margin trading
            monthly_dirs = [
                ("fund_flow", "Fund flow"),
                ("valuation", "Valuation"),
                ("margin", "Margin"),
            ]
            for dir_name, label in monthly_dirs:
                indicator_dir = data_dir / dir_name
                if indicator_dir.exists():
                    for part_file in indicator_dir.glob("*.parquet"):
                        _upload(part_file, f"data/{dir_name}/{part_file.name}")
                    logger.info(f"{label} data uploaded")

            # 6-10. Upload the single-file data: financial indicators, shareholder
            # counts, dividends, top-10 shareholders, restricted-share unlocks
            single_files = [
                ("financial_indicator.parquet", "Financial indicator"),
                ("holder_num.parquet", "Holder number"),
                ("dividend.parquet", "Dividend"),
                ("top_holders.parquet", "Top holders"),
                ("restricted_unlock.parquet", "Restricted unlock"),
            ]
            for file_name, label in single_files:
                file_path = data_dir / file_name
                if file_path.exists():
                    _upload(file_path, f"data/{file_name}")
                    logger.info(f"{label} data uploaded")

            logger.info(f"Parquet files uploaded to HF Dataset: {DATASET_REPO_ID}")
        except Exception as e:
            logger.error(f"Failed to upload to HF: {e}")
        finally:
            # Reopen the shared connection that was closed above
            _ = self.conn

    def upload_indicator(self, indicator_name: str, local_path: Path, remote_path: str, max_retries: int = 3) -> bool:
        """
        Upload the data for one indicator to the HF Dataset
        (batched: one commit per indicator type, with retries).

        Args:
            indicator_name: Indicator name (used in log messages)
            local_path: Local file or directory path
            remote_path: Remote path prefix (e.g. "data/fund_flow")
            max_retries: Maximum number of attempts (default 3)

        Returns:
            bool: Whether the upload succeeded
        """
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return False

        import time

        for attempt in range(max_retries):
            try:
                login(token=HF_TOKEN)

                if local_path.is_file():
                    # Single-file upload
                    upload_file(
                        path_or_fileobj=str(local_path),
                        path_in_repo=f"{remote_path}/{local_path.name}",
                        repo_id=DATASET_REPO_ID,
                        repo_type="dataset",
                        commit_message=f"Update {indicator_name}: {local_path.name}"
                    )
                    logger.info(f"{indicator_name} uploaded: {local_path.name}")
                elif local_path.is_dir():
                    # Directory upload (all files in one commit)
                    # Collect the parquet files to see whether there is anything to upload
                    files_to_upload = list(local_path.glob("*.parquet"))

                    if not files_to_upload:
                        logger.info(f"{indicator_name}: no files to upload")
                        return True

                    # upload_folder sends the whole directory in a single commit
                    upload_folder(
                        folder_path=str(local_path),
                        path_in_repo=remote_path,
                        repo_id=DATASET_REPO_ID,
                        repo_type="dataset",
                        commit_message=f"Update {indicator_name}: {len(files_to_upload)} files",
                        ignore_patterns=["*.tmp", "*.lock"]  # Skip temporary files
                    )
                    logger.info(f"{indicator_name} uploaded: {len(files_to_upload)} files (batch)")
                else:
                    # Neither a file nor a directory: nothing was uploaded
                    logger.warning(f"{indicator_name}: local path not found: {local_path}")
                    return False

                return True

            except Exception as e:
                if attempt < max_retries - 1:
                    delay = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"Upload failed (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    logger.error(f"Failed to upload {indicator_name} after {max_retries} attempts: {e}")
                    return False

        return False

    def upload_indicator_smart(self, indicator_name: str, local_dir: Path, remote_path: str,
                               changed_files: List[str], batch_threshold: int = 10,
                               max_retries: int = 3) -> bool:
        """
        Smart upload of indicator data (with retries):
        - Many changed files (>= batch_threshold): copy them into a temporary
          directory and upload them in one commit
        - Few changed files (< batch_threshold): upload them one by one
          (one commit per file, but only a few files)

        Args:
            indicator_name: Indicator name (used in log messages)
            local_dir: Local directory path
            remote_path: Remote path prefix
            changed_files: Names of the files that changed
            batch_threshold: Batch upload threshold (default 10)
            max_retries: Maximum number of attempts (default 3)

        Returns:
            bool: Whether the upload succeeded
        """
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return False

        if not changed_files:
            logger.info(f"{indicator_name}: no changes to upload")
            return True

        import shutil
        import tempfile
        import time

        for attempt in range(max_retries):
            try:
                login(token=HF_TOKEN)

                if len(changed_files) >= batch_threshold:
                    # Many changes: stage only the changed files in a temporary
                    # directory and upload it in one commit
                    logger.info(f"{indicator_name}: {len(changed_files)} files changed, using batch upload")

                    with tempfile.TemporaryDirectory() as tmpdir:
                        tmp_path = Path(tmpdir)

                        # Copy only the changed files into the temporary directory
                        for filename in changed_files:
                            local_file = local_dir / filename
                            if local_file.exists():
                                shutil.copy(str(local_file), str(tmp_path / filename))

                        # Upload the temporary directory (changed files only)
                        upload_folder(
                            folder_path=str(tmp_path),
                            path_in_repo=remote_path,
                            repo_id=DATASET_REPO_ID,
                            repo_type="dataset",
                            commit_message=f"Update {indicator_name}: {len(changed_files)} files"
                        )
                    logger.info(f"{indicator_name} uploaded: {len(changed_files)} files (batch)")
                else:
                    # Few changes: upload one by one (one commit per file, but few files)
                    logger.info(f"{indicator_name}: {len(changed_files)} files changed, uploading individually")

                    for filename in changed_files:
                        local_file = local_dir / filename
                        if local_file.exists():
                            upload_file(
                                path_or_fileobj=str(local_file),
                                path_in_repo=f"{remote_path}/{filename}",
                                repo_id=DATASET_REPO_ID,
                                repo_type="dataset",
                            )
                    logger.info(f"{indicator_name} uploaded: {len(changed_files)} files (individual)")

                return True

            except Exception as e:
                if attempt < max_retries - 1:
                    delay = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"Upload failed (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    logger.error(f"Failed to upload {indicator_name} after {max_retries} attempts: {e}")
                    return False

        return False

    def _create_tables(self) -> None:
        """Create the database table schema."""
        conn = self.conn

        # Daily quotes table (original structure kept unchanged)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_daily (
                code VARCHAR,
                trade_date DATE,
                open DOUBLE,
                high DOUBLE,
                low DOUBLE,
                close DOUBLE,
                volume BIGINT,
                amount DOUBLE,
                pct_chg DOUBLE,
                turnover_rate DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)

        # Stock basic information table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_list (
                code VARCHAR PRIMARY KEY,
                name VARCHAR,
                market VARCHAR,
                list_date DATE
            )
        """)

        # Fund flow table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_fund_flow (
                code VARCHAR,
                trade_date DATE,
                close DOUBLE,
                pct_chg DOUBLE,
                main_net_inflow DOUBLE,
                main_net_inflow_pct DOUBLE,
                huge_net_inflow DOUBLE,
                huge_net_inflow_pct DOUBLE,
                large_net_inflow DOUBLE,
                large_net_inflow_pct DOUBLE,
                medium_net_inflow DOUBLE,
                medium_net_inflow_pct DOUBLE,
                small_net_inflow DOUBLE,
                small_net_inflow_pct DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)

        # Valuation indicators table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_valuation (
                code VARCHAR,
                trade_date DATE,
                pe_ttm DOUBLE,
                pe_static DOUBLE,
                pb DOUBLE,
                ps_ttm DOUBLE,
                dv_ratio DOUBLE,
                total_mv DOUBLE,
                circ_mv DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)

        # Margin trading table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_margin (
                code VARCHAR,
                trade_date DATE,
                rzye DOUBLE,
                rzmre DOUBLE,
                rzche DOUBLE,
                rqye DOUBLE,
                rqmcl DOUBLE,
                rzrqye DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)

        # Financial indicators table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_financial_indicator (
                code VARCHAR,
                trade_date DATE,
                roe DOUBLE,
                roa DOUBLE,
                gross_margin DOUBLE,
                net_margin DOUBLE,
                debt_ratio DOUBLE,
                current_ratio DOUBLE,
                quick_ratio DOUBLE,
                inventory_turnover DOUBLE,
                receivable_turnover DOUBLE,
                total_asset_turnover DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)

        # Shareholder count table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_holder_num (
                code VARCHAR,
                trade_date DATE,
                holder_num BIGINT,
                avg_share DOUBLE,
                avg_value DOUBLE,
                total_share DOUBLE,
                total_value DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)

        # Historical dividends table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_dividend (
                code VARCHAR,
                trade_date DATE,
                dividend_type VARCHAR,
                dividend_amount DOUBLE,
                record_date DATE,
                ex_date DATE,
                pay_date DATE,
                PRIMARY KEY (code, trade_date, dividend_type)
            )
        """)

        # Top-10 shareholders table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_top_holders (
                code VARCHAR,
                trade_date DATE,
                holder_name VARCHAR,
                holder_type VARCHAR,
                hold_num DOUBLE,
                hold_ratio DOUBLE,
                hold_change DOUBLE,
                hold_change_ratio DOUBLE,
                PRIMARY KEY (code, trade_date, holder_name)
            )
        """)

        # Restricted-share unlock table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_restricted_unlock (
                code VARCHAR,
                trade_date DATE,
                unlock_date DATE,
                unlock_num DOUBLE,
                unlock_value DOUBLE,
                unlock_ratio DOUBLE,
                lock_type VARCHAR,
                PRIMARY KEY (code, unlock_date)
            )
        """)

        # Create indexes
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_code_date
            ON stock_daily (code, trade_date)
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_fund_flow_code_date
            ON stock_fund_flow (code, trade_date)
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_valuation_code_date
            ON stock_valuation (code, trade_date)
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_margin_code_date
            ON stock_margin (code, trade_date)
        """)

        logger.info("Database tables created/verified")


# Global singleton instance
db_manager = DatabaseManager()


def get_db() -> DatabaseManager:
    """Return the database manager instance."""
    return db_manager