# Source: Paper_Trading/backend/app/database.py
# Commit b39d11d by superxuu — "修复同步稳定性与远程视图冲突"
# (fix sync stability and remote-view conflicts)
"""
DuckDB 数据库管理模块
实现数据库初始化、Hugging Face Dataset 同步、单例连接管理
"""
import os
import logging
from pathlib import Path
from typing import Optional
from functools import lru_cache
from dotenv import load_dotenv
# 加载 .env 文件
load_dotenv()
import duckdb
from huggingface_hub import hf_hub_download, upload_file, login
logger = logging.getLogger(__name__)
# 环境变量配置
DUCKDB_PATH = os.getenv("DUCKDB_PATH", "/app/data/stock_data.duckdb")
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "")
HF_HOME = os.getenv("HF_HOME", "/tmp/.huggingface")
class DatabaseManager:
    """DuckDB database manager — process-wide singleton.

    Responsibilities:
      * lazily open a single DuckDB connection (``conn`` property)
      * initialize the schema, either as empty local tables or from Parquet
        partitions hosted on a Hugging Face Dataset (``init_db``)
      * sync data to/from the Hugging Face Hub (``upload_db`` /
        ``_download_from_hf``)
    """

    _instance: Optional['DatabaseManager'] = None
    _connection: Optional[duckdb.DuckDBPyConnection] = None

    def __new__(cls) -> 'DatabaseManager':
        # Classic singleton: every DatabaseManager() yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def conn(self) -> duckdb.DuckDBPyConnection:
        """Return the shared DuckDB connection, opening it on first access."""
        if self._connection is None:
            # Make sure the parent directory exists before connecting.
            os.makedirs(os.path.dirname(DUCKDB_PATH), exist_ok=True)
            self._connection = duckdb.connect(DUCKDB_PATH)
            # When HF credentials are configured, enable the httpfs extension
            # so remote (HTTP/S3) reads become possible.
            if HF_TOKEN and DATASET_REPO_ID:
                try:
                    self._connection.execute("INSTALL httpfs; LOAD httpfs;")
                    # HF default region (plain string — no interpolation needed)
                    self._connection.execute("SET s3_region='us-east-1';")
                    # This logic can be extended further if/when the storage
                    # format moves fully to Parquet.
                except Exception as e:
                    # Best effort: remote reads are optional; local use still works.
                    logger.warning(f"Failed to load httpfs extension: {e}")
            logger.info(f"Database connection established: {DUCKDB_PATH}")
        return self._connection

    def close(self) -> None:
        """Close the database connection (a later ``conn`` access reopens it)."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.info("Database connection closed")

    def init_db(self) -> None:
        """Initialize the database.

        - Local environment: create empty tables in the local DuckDB file.
        - HF environment: download remote Parquet partitions and expose them
          as a local table (``stock_list``) plus a view (``stock_daily``),
          falling back to empty local tables on any error.
        """
        conn = self.conn
        # Heuristic for "running on Hugging Face": either the Spaces env var
        # is present, or both HF credentials are configured.
        is_hf = os.getenv("SPACES_REPO_NAME") is not None or (HF_TOKEN and DATASET_REPO_ID)
        if is_hf:
            logger.info("HF environment detected. Downloading remote Parquet files...")
            try:
                from huggingface_hub import list_repo_files, hf_hub_download

                # 1. Discover the dataset's file list dynamically.
                all_files = list_repo_files(repo_id=DATASET_REPO_ID, repo_type="dataset")

                # 2. Stock list: download and materialize as a local table.
                list_file = "data/stock_list.parquet"
                if list_file in all_files:
                    local_list_path = hf_hub_download(repo_id=DATASET_REPO_ID, filename=list_file, repo_type="dataset")
                    conn.execute(f"CREATE OR REPLACE TABLE stock_list AS SELECT * FROM read_parquet('{local_list_path}')")
                    logger.info("Local stock_list table created from remote parquet")

                # 3. Daily bars: download each partition and expose them through
                #    one view. The HF SDK caches downloads, so already-fetched
                #    files are not re-downloaded; pulling partitions is much
                #    cheaper than pulling a whole .duckdb file.
                parquet_files = [f for f in all_files if f.startswith("data/parquet/") and f.endswith(".parquet")]
                if parquet_files:
                    local_paths = []
                    for f in parquet_files:
                        path = hf_hub_download(repo_id=DATASET_REPO_ID, filename=f, repo_type="dataset")
                        local_paths.append(f"'{path}'")
                    files_sql = ", ".join(local_paths)
                    # Drop any stale object of either kind before recreating.
                    conn.execute("DROP VIEW IF EXISTS stock_daily")
                    conn.execute("DROP TABLE IF EXISTS stock_daily")
                    conn.execute(f"CREATE OR REPLACE VIEW stock_daily AS SELECT * FROM read_parquet([{files_sql}])")
                    logger.info(f"Remote stock_daily view created with {len(parquet_files)} partitions")
                else:
                    logger.warning("No parquet data files found in dataset")

                # Ensure any object not provided remotely still exists locally.
                # _create_tables() skips stock_daily when it is already a view.
                self._create_tables()

                # Sanity check: the stock list must be queryable.
                count = conn.execute("SELECT COUNT(*) FROM stock_list").fetchone()[0]
                logger.info(f"Verification successful: {count} stocks found")
            except Exception as e:
                logger.error(f"Failed to load remote Parquet: {e}")
                self._create_tables()
        else:
            self._create_tables()
            logger.info("Local database initialized")

    def _download_from_hf(self) -> None:
        """Download the full database file from the Hugging Face Dataset."""
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping download")
            return
        try:
            # Authenticate against the Hub.
            login(token=HF_TOKEN)
            # Fetch the database file into the target directory.
            downloaded_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename="stock_data.duckdb",
                repo_type="dataset",
                local_dir=os.path.dirname(DUCKDB_PATH),
            )
            # Move into place if the downloaded name differs from DUCKDB_PATH.
            # os.replace (unlike os.rename) overwrites an existing destination
            # consistently on every platform.
            if downloaded_path != DUCKDB_PATH:
                os.replace(downloaded_path, DUCKDB_PATH)
            logger.info(f"Database downloaded from HF Dataset: {DATASET_REPO_ID}")
        except Exception as e:
            logger.error(f"Failed to download database from HF: {e}")

    def upload_db(self) -> None:
        """Upload Parquet partitions to the HF Dataset (no full .duckdb upload)."""
        if not HF_TOKEN or not DATASET_REPO_ID:
            logger.warning("HF_TOKEN or DATASET_REPO_ID not set, skipping upload")
            return
        try:
            # Close the shared connection first so the file is consistent.
            self.close()
            login(token=HF_TOKEN)
            # 1. Upload the stock list (small — full overwrite is cheap).
            if Path(DUCKDB_PATH).exists():
                # Export the list as Parquet for easy remote reads.
                conn = duckdb.connect(DUCKDB_PATH)
                list_path = os.path.join(os.path.dirname(DUCKDB_PATH), "stock_list.parquet")
                conn.execute(f"COPY stock_list TO '{list_path}' (FORMAT PARQUET)")
                conn.close()
                upload_file(
                    path_or_fileobj=list_path,
                    path_in_repo="data/stock_list.parquet",
                    repo_id=DATASET_REPO_ID,
                    repo_type="dataset",
                )
            # 2. Upload every daily-bar Parquet partition (the incremental core).
            parquet_dir = Path(os.path.dirname(DUCKDB_PATH)) / "parquet"
            if parquet_dir.exists():
                for p_file in parquet_dir.glob("*.parquet"):
                    upload_file(
                        path_or_fileobj=str(p_file),
                        path_in_repo=f"data/parquet/{p_file.name}",
                        repo_id=DATASET_REPO_ID,
                        repo_type="dataset",
                    )
            logger.info(f"Parquet files uploaded to HF Dataset: {DATASET_REPO_ID}")
        except Exception as e:
            logger.error(f"Failed to upload to HF: {e}")
        finally:
            # Reopen the shared connection for subsequent callers.
            _ = self.conn

    def _create_tables(self) -> None:
        """Create the base schema (idempotent).

        In the HF environment ``stock_daily`` may already exist as a VIEW over
        remote Parquet partitions; running CREATE TABLE / CREATE INDEX against
        that name would fail, so the table and its index are only created when
        no such view is present.
        """
        conn = self.conn
        # Is stock_daily already backed by a (remote Parquet) view?
        daily_is_view = conn.execute(
            "SELECT COUNT(*) FROM information_schema.tables "
            "WHERE table_name = 'stock_daily' AND table_type = 'VIEW'"
        ).fetchone()[0] > 0
        if not daily_is_view:
            # Daily OHLCV quotes table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS stock_daily (
                    code VARCHAR,
                    trade_date DATE,
                    open DOUBLE,
                    high DOUBLE,
                    low DOUBLE,
                    close DOUBLE,
                    volume BIGINT,
                    amount DOUBLE,
                    pct_chg DOUBLE,
                    turnover_rate DOUBLE,
                    PRIMARY KEY (code, trade_date)
                )
            """)
            # Lookup index for (code, trade_date) range scans — only valid
            # when stock_daily is a real table.
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_code_date
                ON stock_daily (code, trade_date)
            """)
        # Basic stock metadata table
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_list (
                code VARCHAR PRIMARY KEY,
                name VARCHAR,
                market VARCHAR,
                list_date DATE
            )
        """)
        logger.info("Database tables created/verified")
# Module-level singleton shared by the whole application
db_manager = DatabaseManager()


def get_db() -> DatabaseManager:
    """Return the process-wide :class:`DatabaseManager` instance."""
    return db_manager