# sync_stock/app/database.py
# Committed by: superxu520
# Last commit: e6d501f "perf:reduce-concurrent-threads-and-optimize-download"
"""
DuckDB 数据库管理模块 - Sync Space 专用版
实现数据库初始化、Hugging Face Dataset 同步、单例连接管理
"""
import os
import logging
from pathlib import Path
from typing import Optional, List
from dotenv import load_dotenv
# 加载 .env 文件
load_dotenv()
import duckdb
from huggingface_hub import hf_hub_download, upload_file, login, upload_folder
# Module-wide logger.
logger = logging.getLogger(__name__)

# Environment-driven settings (Sync Space specific defaults).
# Local DuckDB file; its parent directory also holds the parquet exports
# (stock_list.parquet, parquet/, fund_flow/, ...).
DUCKDB_PATH: str = os.environ.get("DUCKDB_PATH", "/tmp/data/stock_data.duckdb")
# Hugging Face credentials and target dataset; Hub sync is skipped when
# either of these is missing.
HF_TOKEN: Optional[str] = os.environ.get("HF_TOKEN")
DATASET_REPO_ID: str = os.environ.get("DATASET_REPO_ID", "")
# NOTE(review): not referenced elsewhere in this module as far as visible, and
# the default is not written back to os.environ, so presumably it does not
# change huggingface_hub's own cache location — confirm intended use.
HF_HOME: str = os.environ.get("HF_HOME", "/tmp/.huggingface")
class DatabaseManager:
    """DuckDB database manager (singleton).

    Holds the single DuckDB connection for the process, builds the local
    database from parquet files (already on disk, or downloaded from the
    Hugging Face dataset ``DATASET_REPO_ID``) and uploads local parquet
    exports back to that dataset.
    """

    _instance: Optional['DatabaseManager'] = None
    # Lazily opened by the ``conn`` property; reset to None by ``close()``.
    _connection: Optional["duckdb.DuckDBPyConnection"] = None

    # Monthly-partitioned indicator directories uploaded by upload_db():
    # (sub-directory of the data dir, log message emitted when it exists).
    _PARTITIONED_UPLOADS = (
        ("fund_flow", "Fund flow data uploaded"),
        ("valuation", "Valuation data uploaded"),
        ("margin", "Margin data uploaded"),
    )
    # Single-file indicator exports uploaded by upload_db():
    # (file name in the data dir, log message emitted when it exists).
    _SINGLE_FILE_UPLOADS = (
        ("financial_indicator.parquet", "Financial indicator data uploaded"),
        ("holder_num.parquet", "Holder number data uploaded"),
        ("dividend.parquet", "Dividend data uploaded"),
        ("top_holders.parquet", "Top holders data uploaded"),
        ("restricted_unlock.parquet", "Restricted unlock data uploaded"),
    )

    def __new__(cls) -> 'DatabaseManager':
        # Singleton: every DatabaseManager() call returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _data_dir() -> Path:
        """Return the directory containing the DuckDB file and the parquet exports."""
        return Path(DUCKDB_PATH).parent

    @staticmethod
    def _sql_literal(value: object) -> str:
        """Render *value* (typically a file path) as a single-quoted SQL string literal.

        Embedded single quotes are doubled, so a path containing a quote cannot
        break or alter the generated DuckDB statement.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _recent_months(last_trade_date: str, count: int = 3) -> List[str]:
        """Return ``count`` "YYYY-MM" keys ending at the month of *last_trade_date*.

        Args:
            last_trade_date: date formatted as "%Y-%m-%d".
            count: number of months to return (newest first).

        Raises:
            ValueError: if *last_trade_date* does not match "%Y-%m-%d".
        """
        from datetime import datetime

        last = datetime.strptime(last_trade_date, '%Y-%m-%d')
        # Count months from year 0 so any offset wraps across year boundaries.
        base = last.year * 12 + (last.month - 1)
        months = []
        for offset in range(count):
            year, month0 = divmod(base - offset, 12)
            months.append(f"{year}-{month0 + 1:02d}")
        return months

    @staticmethod
    def _call_with_retry(action, label: str, max_retries: int):
        """Call the zero-argument *action* up to *max_retries* times.

        Sleeps 1s, 2s, 4s, ... (exponential backoff) between failed attempts.
        Every failure is logged with *label* as prefix.

        Returns:
            The result of the first successful call, or None if every attempt
            raised (or *max_retries* < 1).
        """
        import time

        for attempt in range(max_retries):
            try:
                return action()
            except Exception as e:
                if attempt < max_retries - 1:
                    delay = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"{label} failed (attempt {attempt + 1}), retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    logger.error(f"{label} failed after {max_retries} attempts: {e}")
        return None

    def _load_stock_list(self, list_file: Path) -> None:
        """(Re)create the ``stock_list`` table from the local parquet file *list_file*."""
        self.conn.execute(
            "CREATE OR REPLACE TABLE stock_list AS "
            f"SELECT * FROM read_parquet({self._sql_literal(list_file)})"
        )

    @staticmethod
    def _upload_one(local_file, path_in_repo: str) -> None:
        """Upload one local file to *path_in_repo* of the dataset repo (one commit)."""
        upload_file(
            path_or_fileobj=str(local_file),
            path_in_repo=path_in_repo,
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
        )

    def _upload_dir(self, local_dir: Path, remote_dir: str) -> bool:
        """Upload every ``*.parquet`` in *local_dir* to *remote_dir*, one commit per file.

        Returns:
            False when *local_dir* does not exist, True otherwise (even if it
            contained no parquet files).
        """
        if not local_dir.exists():
            return False
        for p_file in sorted(local_dir.glob("*.parquet")):
            self._upload_one(p_file, f"{remote_dir}/{p_file.name}")
        return True

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------
    @property
    def conn(self) -> "duckdb.DuckDBPyConnection":
        """Return the shared DuckDB connection, opening it on first use."""
        if self._connection is None:
            # Path.parent is "." for a bare file name; os.makedirs("") would raise.
            self._data_dir().mkdir(parents=True, exist_ok=True)
            self._connection = duckdb.connect(DUCKDB_PATH)
            logger.info(f"Database connection established: {DUCKDB_PATH}")
        return self._connection

    def close(self) -> None:
        """Close the shared connection (a no-op when it is not open)."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.info("Database connection closed")

    # ------------------------------------------------------------------
    # Download / initialisation
    # ------------------------------------------------------------------
    def _download_with_retry(self, repo_id: str, filename: str, max_retries: int = 3) -> Optional[str]:
        """Download *filename* from dataset *repo_id* with retry.

        Returns:
            The local (cached) path of the file, or None if every attempt failed.
        """
        return self._call_with_retry(
            lambda: hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"),
            "Download",
            max_retries,
        )

    def _list_files_with_retry(self, repo_id: str, max_retries: int = 3) -> Optional[List]:
        """List the files of dataset *repo_id* with retry.

        Returns:
            The list of repo-relative paths, or None if every attempt failed.
        """
        from huggingface_hub import list_repo_files

        return self._call_with_retry(
            lambda: list(list_repo_files(repo_id=repo_id, repo_type="dataset")),
            "List files",
            max_retries,
        )

    def _fallback_download_from_status(self) -> bool:
        """Restore key data using ``data/sync_status.json`` when listing the repo fails.

        The status file says whether a stock list exists and gives the last
        synced trade date. The three monthly partitions ending at that date
        are then downloaded by name.

        Returns:
            True if the stock list or at least one daily partition was
            restored; False otherwise (including on any error).
        """
        import json
        import shutil

        if not HF_TOKEN or not DATASET_REPO_ID:
            return False

        logger.info("Attempting fallback download via sync_status.json...")
        status_path = self._download_with_retry(DATASET_REPO_ID, "data/sync_status.json", max_retries=2)
        if not status_path:
            logger.warning("Failed to download sync_status.json, cannot use fallback mode")
            return False

        try:
            with open(status_path, 'r', encoding='utf-8') as f:
                status = json.load(f)

            data_dir = self._data_dir()
            stock_list_loaded = False
            downloaded = 0

            # Stock list (required)
            if 'stock_list' in status:
                list_file = data_dir / "stock_list.parquet"
                local_path = self._download_with_retry(DATASET_REPO_ID, "data/stock_list.parquet", max_retries=2)
                if local_path:
                    shutil.copy(local_path, list_file)
                    self._load_stock_list(list_file)
                    stock_list_loaded = True
                    logger.info("Stock list downloaded via fallback")

            # Daily K-line partitions for the last three months
            daily_status = status.get('daily') or {}
            last_trade_date = daily_status.get('last_trade_date', '')
            if last_trade_date:
                parquet_dir = data_dir / "parquet"
                parquet_dir.mkdir(parents=True, exist_ok=True)
                for month_str in self._recent_months(last_trade_date, 3):
                    local_path = self._download_with_retry(
                        DATASET_REPO_ID, f"data/parquet/{month_str}.parquet", max_retries=1
                    )
                    if local_path:
                        shutil.copy(local_path, parquet_dir / f"{month_str}.parquet")
                        downloaded += 1
                if downloaded > 0:
                    self._refresh_views()
                logger.info(f"Fallback download complete: {downloaded} months of daily data")

            # Only report success if something was actually restored, so the
            # caller falls back to creating an empty schema otherwise.
            return stock_list_loaded or downloaded > 0
        except Exception as e:
            logger.error(f"Fallback download failed: {e}")
            return False

    def init_db(self, force_download: bool = False) -> None:
        """
        Initialise the database (Sync Space smart-download mode, with retry and fallback).

        Resolution order:
          1. Reuse a populated local ``stock_list`` table (skipped when forced).
          2. Rebuild from parquet files left on local disk (skipped when forced).
          3. If HF credentials are set, download the stock list and the three
             most recent daily partitions from the HF dataset. When the file
             listing fails, fall back to ``sync_status.json``.
          4. Otherwise, or on failure, create the empty schema.

        Args:
            force_download: always pull from the HF dataset (default False).
        """
        import shutil

        conn = self.conn
        data_dir = self._data_dir()
        list_file = data_dir / "stock_list.parquet"

        # 1. Are the local tables already populated?
        if not force_download:
            try:
                count = conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()[0]
                if count > 0:
                    logger.info(f"Local database tables exist ({count} stocks).")
                    # Even when the table exists, (re)create the view over local parquet files.
                    self._refresh_views()
                    return
            except Exception as e:
                # Most commonly stock_list does not exist yet; continue restoring.
                logger.debug(f"Local stock_list not usable: {e}")

        # 2. Restore from parquet files still on disk (the Space was not restarted).
        if not force_download and list_file.exists():
            try:
                self._load_stock_list(list_file)
                self._refresh_views()
                logger.info("Database restored from local parquet files.")
                return
            except Exception as e:
                logger.warning(f"Failed to restore from local parquet: {e}")

        # 3. Without HF credentials there is nothing to download.
        if not (HF_TOKEN and DATASET_REPO_ID):
            self._create_tables()
            logger.info("Local database initialized")
            return

        logger.info("Downloading remote Parquet files from HF Dataset...")
        try:
            all_files = self._list_files_with_retry(DATASET_REPO_ID, max_retries=3)
            if all_files is None:
                # Listing failed: try the sync_status.json driven fallback.
                logger.warning("Failed to list files, attempting fallback mode...")
                if self._fallback_download_from_status():
                    logger.info("Database initialized via fallback mode")
                else:
                    logger.warning("Fallback mode failed, creating empty database")
                    self._create_tables()
                return

            # Normal path: stock list
            if "data/stock_list.parquet" in all_files:
                local_list_path = self._download_with_retry(DATASET_REPO_ID, "data/stock_list.parquet")
                if local_list_path:
                    shutil.copy(local_list_path, list_file)
                    self._load_stock_list(list_file)
                else:
                    logger.warning("Stock list could not be downloaded; stock_list table not loaded")

            # Daily partitions: only the 3 most recent (names sort chronologically as YYYY-MM).
            prefix = "data/"
            parquet_files = sorted(
                f for f in all_files if f.startswith("data/parquet/") and f.endswith(".parquet")
            )
            if parquet_files:
                recent_files = parquet_files[-3:]
                logger.info(f"Downloading {len(recent_files)} recent parquet files (last 3 months)")
                for remote_file in recent_files:
                    cached_path = self._download_with_retry(DATASET_REPO_ID, remote_file)
                    if cached_path:
                        # Strip only the leading "data/" (mirrors the repo layout under the data dir).
                        dest_path = data_dir / remote_file[len(prefix):]
                        dest_path.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy(cached_path, dest_path)
                self._refresh_views()
                logger.info("Remote data downloaded and views created.")
            else:
                self._create_tables()
        except Exception as e:
            logger.error(f"Failed to load remote Parquet: {e}")
            self._create_tables()

    def _refresh_views(self) -> None:
        """Point the ``stock_daily`` view at every parquet file in ``<data dir>/parquet``.

        Does nothing when the directory is missing or has no parquet files
        (an existing view is then left as it is).
        """
        conn = self.conn
        parquet_dir = self._data_dir() / "parquet"
        if not parquet_dir.exists():
            return
        p_files = sorted(parquet_dir.glob("*.parquet"))
        if not p_files:
            return
        files_sql = ", ".join(self._sql_literal(f) for f in p_files)
        # NOTE(review): _create_tables() creates a *table* named stock_daily; if
        # one exists in the database file these view statements presumably fail — confirm.
        conn.execute("DROP VIEW IF EXISTS stock_daily")
        conn.execute(f"CREATE OR REPLACE VIEW stock_daily AS SELECT * FROM read_parquet([{files_sql}])")
        logger.info(f"Database views refreshed with {len(p_files)} partitions")

    # ------------------------------------------------------------------
    # Upload
    # ------------------------------------------------------------------
    def upload_db(self) -> None:
        """Upload the stock list and all local parquet exports to the HF dataset.

        Closes the shared connection first, exports ``stock_list`` through a
        short-lived connection, and always reopens the shared connection at
        the end. Each file is uploaded with its own ``upload_file`` call.
        Errors are logged, not raised.
        """
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return

        data_dir = self._data_dir()
        try:
            # Close the shared connection first
            self.close()
            login(token=HF_TOKEN)

            # 1. Stock list: export the table to parquet, then upload it.
            if Path(DUCKDB_PATH).exists():
                list_path = data_dir / "stock_list.parquet"
                export_conn = duckdb.connect(DUCKDB_PATH)
                try:
                    export_conn.execute(f"COPY stock_list TO {self._sql_literal(list_path)} (FORMAT PARQUET)")
                finally:
                    # Never leak the export connection, even if COPY fails.
                    export_conn.close()
                self._upload_one(list_path, "data/stock_list.parquet")

            # 2. Daily quote partitions
            self._upload_dir(data_dir / "parquet", "data/parquet")

            # 3-5. Monthly indicator partitions (fund flow, valuation, margin)
            for subdir, message in self._PARTITIONED_UPLOADS:
                if self._upload_dir(data_dir / subdir, f"data/{subdir}"):
                    logger.info(message)

            # 6-10. Single-file indicator exports
            for file_name, message in self._SINGLE_FILE_UPLOADS:
                file_path = data_dir / file_name
                if file_path.exists():
                    self._upload_one(file_path, f"data/{file_name}")
                    logger.info(message)

            logger.info(f"Parquet files uploaded to HF Dataset: {DATASET_REPO_ID}")
        except Exception as e:
            logger.error(f"Failed to upload to HF: {e}")
        finally:
            # Reopen the shared connection for subsequent callers.
            _ = self.conn

    def upload_indicator(self, indicator_name: str, local_path: Path, remote_path: str,
                         max_retries: int = 3) -> bool:
        """
        Upload one indicator's data to the HF dataset (one commit per call, with retry).

        Args:
            indicator_name: indicator name, used in logs and commit messages.
            local_path: a single file, or a directory uploaded in one commit.
            remote_path: remote path prefix (e.g. "data/fund_flow").
            max_retries: maximum number of attempts (default 3).

        Returns:
            bool: True on success, including when there is nothing to upload;
            False when credentials are missing or every attempt failed.
        """
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return False
        if not local_path.exists():
            # Nothing to upload; say so instead of silently reporting success.
            logger.warning(f"{indicator_name}: {local_path} does not exist, nothing to upload")
            return True

        import time

        for attempt in range(max_retries):
            try:
                login(token=HF_TOKEN)
                if local_path.is_file():
                    # Single-file upload
                    upload_file(
                        path_or_fileobj=str(local_path),
                        path_in_repo=f"{remote_path}/{local_path.name}",
                        repo_id=DATASET_REPO_ID,
                        repo_type="dataset",
                        commit_message=f"Update {indicator_name}: {local_path.name}"
                    )
                    logger.info(f"{indicator_name} uploaded: {local_path.name}")
                elif local_path.is_dir():
                    # Directory batch upload (all files in one commit)
                    files_to_upload = list(local_path.glob("*.parquet"))
                    if not files_to_upload:
                        logger.info(f"{indicator_name}: no files to upload")
                        return True
                    # NOTE(review): upload_folder sends every file in the folder except
                    # *.tmp / *.lock, not only the *.parquet files counted above — confirm intended.
                    upload_folder(
                        folder_path=str(local_path),
                        path_in_repo=remote_path,
                        repo_id=DATASET_REPO_ID,
                        repo_type="dataset",
                        commit_message=f"Update {indicator_name}: {len(files_to_upload)} files",
                        ignore_patterns=["*.tmp", "*.lock"]  # skip temporary files
                    )
                    logger.info(f"{indicator_name} uploaded: {len(files_to_upload)} files (batch)")
                return True
            except Exception as e:
                if attempt < max_retries - 1:
                    delay = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"Upload failed (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    logger.error(f"Failed to upload {indicator_name} after {max_retries} attempts: {e}")
                    return False
        return False

    def upload_indicator_smart(self, indicator_name: str, local_dir: Path, remote_path: str,
                               changed_files: List[str], batch_threshold: int = 10,
                               max_retries: int = 3) -> bool:
        """
        Upload only the changed files of an indicator (with retry).

        - Many changed files (>= batch_threshold): copy them into a temporary
          directory and upload it with ``upload_folder`` (a single commit).
        - Few changed files (< batch_threshold): upload each file with
          ``upload_file`` (one commit per file).

        Changed files that no longer exist locally are skipped with a warning.

        Args:
            indicator_name: indicator name, used in logs and commit messages.
            local_dir: local directory containing the files.
            remote_path: remote path prefix.
            changed_files: names of changed files, relative to *local_dir*.
            batch_threshold: batch upload threshold (default 10).
            max_retries: maximum number of attempts (default 3).

        Returns:
            bool: True on success (including nothing to upload), False otherwise.
        """
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return False
        if not changed_files:
            logger.info(f"{indicator_name}: no changes to upload")
            return True

        # Resolve once, so logs and commit messages count what is really uploaded.
        existing = [name for name in changed_files if (local_dir / name).exists()]
        missing = len(changed_files) - len(existing)
        if missing:
            logger.warning(f"{indicator_name}: {missing} changed file(s) not found locally, skipped")
        if not existing:
            logger.info(f"{indicator_name}: no changes to upload")
            return True

        import time

        for attempt in range(max_retries):
            try:
                login(token=HF_TOKEN)
                if len(existing) >= batch_threshold:
                    # Many changes: stage only the changed files, upload in one commit.
                    logger.info(f"{indicator_name}: {len(existing)} files changed, using batch upload")
                    import tempfile
                    import shutil
                    with tempfile.TemporaryDirectory() as tmpdir:
                        tmp_path = Path(tmpdir)
                        for filename in existing:
                            dest = tmp_path / filename
                            # Names may contain sub-directories.
                            dest.parent.mkdir(parents=True, exist_ok=True)
                            shutil.copy(str(local_dir / filename), str(dest))
                        upload_folder(
                            folder_path=str(tmp_path),
                            path_in_repo=remote_path,
                            repo_id=DATASET_REPO_ID,
                            repo_type="dataset",
                            commit_message=f"Update {indicator_name}: {len(existing)} files"
                        )
                    logger.info(f"{indicator_name} uploaded: {len(existing)} files (batch)")
                else:
                    # Few changes: one commit per file.
                    logger.info(f"{indicator_name}: {len(existing)} files changed, uploading individually")
                    for filename in existing:
                        upload_file(
                            path_or_fileobj=str(local_dir / filename),
                            path_in_repo=f"{remote_path}/{filename}",
                            repo_id=DATASET_REPO_ID,
                            repo_type="dataset",
                        )
                    logger.info(f"{indicator_name} uploaded: {len(existing)} files (individual)")
                return True
            except Exception as e:
                if attempt < max_retries - 1:
                    delay = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
                    logger.warning(f"Upload failed (attempt {attempt + 1}/{max_retries}): {e}, retrying in {delay}s...")
                    time.sleep(delay)
                else:
                    logger.error(f"Failed to upload {indicator_name} after {max_retries} attempts: {e}")
                    return False
        return False

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------
    def _create_tables(self) -> None:
        """Create the database schema; every statement uses IF NOT EXISTS, so re-running is safe."""
        conn = self.conn
        # Daily quotes table (original structure kept unchanged)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_daily (
                code VARCHAR,
                trade_date DATE,
                open DOUBLE,
                high DOUBLE,
                low DOUBLE,
                close DOUBLE,
                volume BIGINT,
                amount DOUBLE,
                pct_chg DOUBLE,
                turnover_rate DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Basic stock information
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_list (
                code VARCHAR PRIMARY KEY,
                name VARCHAR,
                market VARCHAR,
                list_date DATE
            )
        """)
        # Fund flow
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_fund_flow (
                code VARCHAR,
                trade_date DATE,
                close DOUBLE,
                pct_chg DOUBLE,
                main_net_inflow DOUBLE,
                main_net_inflow_pct DOUBLE,
                huge_net_inflow DOUBLE,
                huge_net_inflow_pct DOUBLE,
                large_net_inflow DOUBLE,
                large_net_inflow_pct DOUBLE,
                medium_net_inflow DOUBLE,
                medium_net_inflow_pct DOUBLE,
                small_net_inflow DOUBLE,
                small_net_inflow_pct DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Valuation indicators
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_valuation (
                code VARCHAR,
                trade_date DATE,
                pe_ttm DOUBLE,
                pe_static DOUBLE,
                pb DOUBLE,
                ps_ttm DOUBLE,
                dv_ratio DOUBLE,
                total_mv DOUBLE,
                circ_mv DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Margin trading
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_margin (
                code VARCHAR,
                trade_date DATE,
                rzye DOUBLE,
                rzmre DOUBLE,
                rzche DOUBLE,
                rqye DOUBLE,
                rqmcl DOUBLE,
                rzrqye DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Financial indicators
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_financial_indicator (
                code VARCHAR,
                trade_date DATE,
                roe DOUBLE,
                roa DOUBLE,
                gross_margin DOUBLE,
                net_margin DOUBLE,
                debt_ratio DOUBLE,
                current_ratio DOUBLE,
                quick_ratio DOUBLE,
                inventory_turnover DOUBLE,
                receivable_turnover DOUBLE,
                total_asset_turnover DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Shareholder counts
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_holder_num (
                code VARCHAR,
                trade_date DATE,
                holder_num BIGINT,
                avg_share DOUBLE,
                avg_value DOUBLE,
                total_share DOUBLE,
                total_value DOUBLE,
                PRIMARY KEY (code, trade_date)
            )
        """)
        # Historical dividends
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_dividend (
                code VARCHAR,
                trade_date DATE,
                dividend_type VARCHAR,
                dividend_amount DOUBLE,
                record_date DATE,
                ex_date DATE,
                pay_date DATE,
                PRIMARY KEY (code, trade_date, dividend_type)
            )
        """)
        # Top-10 shareholders
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_top_holders (
                code VARCHAR,
                trade_date DATE,
                holder_name VARCHAR,
                holder_type VARCHAR,
                hold_num DOUBLE,
                hold_ratio DOUBLE,
                hold_change DOUBLE,
                hold_change_ratio DOUBLE,
                PRIMARY KEY (code, trade_date, holder_name)
            )
        """)
        # Restricted-share unlocks
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_restricted_unlock (
                code VARCHAR,
                trade_date DATE,
                unlock_date DATE,
                unlock_num DOUBLE,
                unlock_value DOUBLE,
                unlock_ratio DOUBLE,
                lock_type VARCHAR,
                PRIMARY KEY (code, unlock_date)
            )
        """)
        # Indexes
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_code_date
            ON stock_daily (code, trade_date)
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_fund_flow_code_date
            ON stock_fund_flow (code, trade_date)
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_valuation_code_date
            ON stock_valuation (code, trade_date)
        """)
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_margin_code_date
            ON stock_margin (code, trade_date)
        """)
        logger.info("Database tables created/verified")
# Process-wide singleton shared by every caller of get_db().
db_manager: DatabaseManager = DatabaseManager()


def get_db() -> DatabaseManager:
    """Return the module-level DatabaseManager singleton."""
    return db_manager