# Paper_Trading/backend/scripts/sync_data.py
# Author: superxuu
# Commit de2aac1 — feat: unify all system time records to Beijing Time (UTC+8)
"""
数据同步脚本 - 云端兼容版
支持本地母本模式 (DuckDB) 和 云端无盘模式 (Parquet View)
"""
import os
import sys
import logging
import time
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Dict, Any
from pathlib import Path
import pandas as pd
import akshare as ak
import yfinance as yf
from huggingface_hub import hf_hub_download, upload_file
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.database import get_db
from app.database_user import get_beijing_time
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 配置
YEARS_OF_DATA = 10
MAX_WORKERS = 5 # 降低并发数,减少超时
SYNC_LIMIT = -1
USE_YAHOO_FALLBACK = 0
def _classify_a_share_board(code: str) -> str:
    """Map an A-share stock code prefix to its board name."""
    if code.startswith(('60', '00')):
        return '主板'
    if code.startswith('30'):
        return '创业板'
    if code.startswith('68'):
        return '科创板'
    if code.startswith(('8', '4', '920')):
        return '北交所'
    return '其他'


def get_stock_list() -> pd.DataFrame:
    """Fetch the all-market target list (A-shares, ETF, LOF, REITs, convertible bonds).

    Returns:
        DataFrame with columns [code, name, market, list_date]. Each
        category after A-shares is best-effort: a failure only drops that
        category. If every remote fetch fails, falls back to the last
        synced ``stock_list`` table in the local DB.
    """
    logger.info("Fetching all-market target list...")
    all_lists: List[pd.DataFrame] = []
    try:
        # A-share spot list; classify board by code prefix.
        df_a = ak.stock_zh_a_spot_em()[['代码', '名称']]
        df_a.columns = ['code', 'name']
        df_a['market'] = df_a['code'].apply(_classify_a_share_board)
        all_lists.append(df_a)
        # ETF / LOF / REITs share the same fetch-and-tag shape.
        for market, fetcher in (('ETF', ak.fund_etf_spot_em),
                                ('LOF', ak.fund_lof_spot_em),
                                ('REITs', ak.reits_realtime_em)):
            try:
                df_cat = fetcher()[['代码', '名称']]
                df_cat.columns = ['code', 'name']
                df_cat['market'] = market
                all_lists.append(df_cat)
            except Exception as e:
                logger.warning(f"{market} list fetch failed: {e}")
        # Convertible bonds: column names differ between akshare versions.
        try:
            df_cb = ak.bond_zh_hs_cov_spot()
            c_code = '代码' if '代码' in df_cb.columns else 'symbol'
            c_name = '名称' if '名称' in df_cb.columns else 'name'
            df_cb = df_cb[[c_code, c_name]]
            df_cb.columns = ['code', 'name']
            df_cb['market'] = '可转债'
            all_lists.append(df_cb)
        except Exception as e:
            # Previously a silent bare `except: pass`; log so upstream
            # akshare API changes remain visible.
            logger.warning(f"Convertible bond list fetch failed: {e}")
    except Exception as e:
        logger.error(f"List fetching error: {e}")
    if not all_lists:
        # Total remote failure: reuse the previously synced list.
        db = get_db()
        return db.conn.execute("SELECT code, name, market FROM stock_list").df()
    df = pd.concat(all_lists).drop_duplicates(subset=['code'])
    df['list_date'] = None
    return df
def get_target_daily(code: str, start_date: str, market: str) -> Optional[pd.DataFrame]:
    """Fetch daily bars for a single target.

    Dispatches to the akshare endpoint matching *market* and normalizes the
    result to a fixed column layout. Retries up to 3 times on exceptions.

    Args:
        code: security code (may carry an exchange prefix for bonds).
        start_date: inclusive start, 'YYYY-MM-DD'.
        market: one of 'INDEX', 'ETF', 'LOF', '可转债', 'REITs'; anything
            else is treated as an A-share.

    Returns:
        DataFrame with columns [code, trade_date, open, high, low, close,
        volume, amount, pct_chg, turnover_rate], or None when every retry
        fails or the endpoint yields no rows.
    """
    max_retries = 3  # increased retry count for flaky network/API calls
    for attempt in range(max_retries):
        try:
            end_date = get_beijing_time().strftime('%Y%m%d')
            fetch_start = start_date.replace('-', '')  # endpoints want YYYYMMDD
            df = None
            if market == 'INDEX':
                df = ak.stock_zh_index_daily_em(symbol=f"sh{code}" if code.startswith('000') else f"sz{code}")
            elif market == 'ETF':
                df = ak.fund_etf_hist_em(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            elif market == 'LOF':
                df = ak.fund_lof_hist_em(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            elif market == '可转债':
                # Convertible-bond endpoint usually takes a 6-digit code;
                # also accept prefixed forms like sh110xxx / sz12xxxx / bj81xxxx.
                cov_symbol = code[-6:] if len(code) > 6 else code
                try:
                    df = ak.bond_zh_hs_cov_daily(symbol=cov_symbol)
                except Exception:
                    df = ak.bond_zh_hs_cov_daily(symbol=code)
            elif market == 'REITs':
                # REITs endpoint only supports `symbol` and returns full
                # history; the date filter below trims it.
                df = ak.reits_hist_em(symbol=code)
            else:
                df = ak.stock_zh_a_hist(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            if df is not None and not df.empty:
                # Normalize column names across Chinese/English endpoint variants.
                rename_map = {
                    '日期': 'trade_date', 'date': 'trade_date', 'Date': 'trade_date',
                    '开盘': 'open', '今开': 'open', 'Open': 'open',
                    '最高': 'high', 'High': 'high',
                    '最低': 'low', 'Low': 'low',
                    '收盘': 'close', '最新价': 'close', 'Close': 'close',
                    '成交量': 'volume', 'Volume': 'volume',
                    '成交额': 'amount', 'Amount': 'amount',
                    '涨跌幅': 'pct_chg',
                    '换手率': 'turnover_rate', '换手': 'turnover_rate'
                }
                df = df.rename(columns=rename_map)
                # If trade_date is still absent, the date may live in the index.
                if 'trade_date' not in df.columns:
                    df = df.reset_index().rename(columns={'index': 'trade_date', 'date': 'trade_date'})
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                df = df[df['trade_date'] >= pd.to_datetime(start_date)]
                # Backfill columns some endpoints omit.
                if 'amount' not in df.columns: df['amount'] = 0
                if 'pct_chg' not in df.columns: df['pct_chg'] = df['close'].pct_change() * 100
                if 'turnover_rate' not in df.columns: df['turnover_rate'] = 0
                df['code'] = code
                return df[['code', 'trade_date', 'open', 'high', 'low', 'close', 'volume', 'amount', 'pct_chg', 'turnover_rate']]
        except Exception as e:
            if attempt == max_retries - 1:
                logger.warning(f"Failed to fetch {code} ({market}): {str(e)}")
            time.sleep(1)
    # All retries exhausted, or every fetch came back empty.
    return None
def get_last_trading_day() -> str:
    """Return the most recent trading day as 'YYYY-MM-DD'.

    Reads the latest row of the CSI 300 index history when reachable;
    otherwise estimates by stepping back to the latest Beijing-time weekday.
    """
    try:
        idx = ak.stock_zh_index_daily_em(symbol="sh000300")
        if idx is not None and not idx.empty:
            # The endpoint has returned either an English or a Chinese
            # date column depending on version — probe both.
            for candidate in ('date', '日期'):
                if candidate in idx.columns:
                    return pd.to_datetime(idx[candidate].iloc[-1]).strftime('%Y-%m-%d')
    except Exception as e:
        logger.warning(f"Failed to get last trading day from index data: {e}")
    # Fallback: walk back past weekend days (weekday 5=Sat, 6=Sun).
    day = get_beijing_time()
    while day.weekday() >= 5:
        day -= timedelta(days=1)
    return day.strftime('%Y-%m-%d')
def get_index_daily(code: str) -> Optional[pd.DataFrame]:
    """Fetch daily index bars (used for the CSI 300 benchmark by default).

    Returns a DataFrame normalized to the standard daily-bar column layout,
    or None on any failure.
    """
    try:
        exchange_prefix = "sh" if code.startswith('000') else "sz"
        raw = ak.stock_zh_index_daily_em(symbol=f"{exchange_prefix}{code}")
        if raw is None or raw.empty:
            return None
        # Normalize both English and Chinese column-name variants.
        raw = raw.rename(columns={
            'date': 'trade_date', '日期': 'trade_date',
            'open': 'open', '开盘': 'open',
            'high': 'high', '最高': 'high',
            'low': 'low', '最低': 'low',
            'close': 'close', '收盘': 'close',
            'volume': 'volume', '成交量': 'volume',
            'amount': 'amount', '成交额': 'amount',
            'pct_chg': 'pct_chg', '涨跌幅': 'pct_chg',
        })
        if 'trade_date' not in raw.columns:
            return None
        raw['trade_date'] = pd.to_datetime(raw['trade_date'])
        # Backfill columns the endpoint may omit.
        if 'amount' not in raw.columns:
            raw['amount'] = 0
        if 'pct_chg' not in raw.columns:
            raw['pct_chg'] = raw['close'].pct_change() * 100
        if 'volume' not in raw.columns:
            raw['volume'] = 0
        raw['turnover_rate'] = 0  # turnover is not meaningful for an index
        raw['code'] = code
        ordered_cols = ['code', 'trade_date', 'open', 'high', 'low', 'close',
                        'volume', 'amount', 'pct_chg', 'turnover_rate']
        return raw[ordered_cols]
    except Exception as e:
        logger.warning(f"Failed to fetch index {code}: {e}")
        return None
def sync_stock_daily(targets: List[Dict[str, str]], last_trade_day: str) -> int:
    """Incrementally sync daily bars — compatible with the cloud view mode.

    Args:
        targets: dicts with at least 'code' and 'market' keys.
        last_trade_day: 'YYYY-MM-DD' cutoff; codes already at this date skip.

    Returns:
        Number of targets that produced new data.
    """
    db = get_db()
    # 1. Current state: latest stored date per code (works for table or view).
    existing_latest = db.conn.execute("SELECT code, CAST(MAX(trade_date) AS VARCHAR) FROM stock_daily GROUP BY code").fetchall()
    latest_map = {row[0]: row[1] for row in existing_latest}
    pending = []
    for t in targets:
        code = t['code']
        if code in latest_map:
            if latest_map[code] >= last_trade_day:
                continue  # already up to date
            start_dt = (pd.to_datetime(latest_map[code]) + timedelta(days=1)).strftime('%Y-%m-%d')
        else:
            # New target: backfill the full configured history window.
            start_dt = (get_beijing_time() - timedelta(days=YEARS_OF_DATA * 365)).strftime('%Y-%m-%d')
        t['start_dt'] = start_dt
        pending.append(t)
    if not pending:
        return 0
    logger.info(f"Syncing {len(pending)} targets...")
    all_new_data = []
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(get_target_daily, t['code'], t['start_dt'], t['market']): t['code'] for t in pending}
        for i, future in enumerate(as_completed(futures), 1):
            res = future.result()
            if res is not None:
                all_new_data.append(res)
            if i % 500 == 0:
                logger.info(f"Progress: {i}/{len(pending)}")
    if all_new_data:
        inc_df = pd.concat(all_new_data, ignore_index=True)
        # Identify the (year, month) partitions touched by the increment —
        # only those Parquet files need rewriting.
        changed = inc_df.assign(yr=inc_df['trade_date'].dt.year, mo=inc_df['trade_date'].dt.month)[['yr', 'mo']].drop_duplicates().values
        for yr, mo in changed:
            yr, mo = int(yr), int(mo)
            filename = f"{yr}-{mo:02d}.parquet"
            # BUGFIX: the path f-strings used a literal "(unknown)" placeholder
            # instead of interpolating `filename`.
            local_path = Path(f"backend/data/parquet/{filename}")
            local_path.parent.mkdir(parents=True, exist_ok=True)
            # Cloud (diskless) incremental core: without a local master DB,
            # pull the old month file from the Hub and merge the increment.
            # Unset DUCKDB_PATH is treated as "no local master" instead of
            # crashing in os.path.getsize(None).
            db_path = os.getenv("DUCKDB_PATH", "")
            if not db_path or not Path(db_path).exists() or os.path.getsize(db_path) < 1024:
                try:
                    repo_id = os.getenv("DATASET_REPO_ID")
                    if repo_id:
                        old_file = hf_hub_download(repo_id=repo_id, filename=f"data/parquet/{filename}", repo_type="dataset")
                        old_df = pd.read_parquet(old_file)
                        month_inc = inc_df[(inc_df['trade_date'].dt.year == yr) & (inc_df['trade_date'].dt.month == mo)]
                        final_month_df = pd.concat([old_df, month_inc]).drop_duplicates(subset=['code', 'trade_date'])
                        final_month_df.to_parquet(local_path)
                        logger.info(f"Merged cloud data for {filename}")
                        continue
                except Exception as e:
                    # Was a silent bare except; fall through to the DB export path.
                    logger.warning(f"Cloud merge failed for {filename}: {e}")
            # Local master mode: insert increment into DuckDB, then export the
            # month partition. DuckDB reads `inc_df` via replacement scan.
            db.conn.execute("CREATE TEMP TABLE temp_inc AS SELECT * FROM inc_df")
            db.conn.execute("INSERT OR IGNORE INTO stock_daily SELECT * FROM temp_inc")
            db.conn.execute(f"COPY (SELECT * FROM stock_daily WHERE date_part('year', trade_date) = {yr} AND date_part('month', trade_date) = {mo}) TO '{local_path}' (FORMAT PARQUET)")
            db.conn.execute("DROP TABLE temp_inc")
    return len(all_new_data)
def main():
    """Full sync pipeline: list -> daily bars -> index -> upload -> hot reload."""
    logger.info("Sync started...")
    db = get_db()
    db.init_db()
    # 1. Target-list sync.
    target_list = get_stock_list()
    # Local master mode: refresh the stock_list table (DuckDB reads the
    # `target_list` DataFrame via replacement scan). Cloud mode skips the DB
    # write and relies on the Parquet export below.
    if Path(os.getenv("DUCKDB_PATH", "")).exists():
        db.conn.execute("DELETE FROM stock_list")
        db.conn.execute("INSERT INTO stock_list SELECT code, name, market, list_date FROM target_list")
    list_parquet = Path("backend/data/stock_list.parquet")
    list_parquet.parent.mkdir(parents=True, exist_ok=True)
    target_list.to_parquet(list_parquet)
    # 2. Daily-bar sync.
    last_day = get_last_trading_day()
    sync_stock_daily(target_list.to_dict('records'), last_day)
    # 3. Benchmark index sync (CSI 300).
    idx_df = get_index_daily('000300')
    if idx_df is not None:
        idx_path = Path("backend/data/parquet/index_000300.parquet")
        # Robustness: the parquet dir only exists if daily sync wrote data.
        idx_path.parent.mkdir(parents=True, exist_ok=True)
        idx_df.to_parquet(idx_path)
    # 4. Upload to the dataset repo.
    db.upload_db()
    # 5. Hot reload: remount the fresh Parquet files and clear in-memory caches.
    try:
        logger.info("Reloading database views to reflect new data...")
        db.init_db()  # re-download and mount the latest Parquet files
        from app.core import clear_eligible_stocks_cache
        clear_eligible_stocks_cache()  # drop the in-memory eligibility cache
        logger.info("Hot reload completed!")
    except Exception as e:
        logger.error(f"Hot reload failed: {e}")
    logger.info("Sync finished!")
# Script entry point: run the full sync pipeline when executed directly.
if __name__ == "__main__":
    main()