Spaces:
Running
Running
"""
Data sync script - cloud-compatible edition.

Supports a local master-copy mode (DuckDB) and a cloud diskless mode
(Parquet views).
"""
import os
import sys
import logging
import time
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Dict, Any
from pathlib import Path
import pandas as pd
import akshare as ak
import yfinance as yf
from huggingface_hub import hf_hub_download, upload_file

# Add the parent directory to sys.path so `app.*` imports resolve when this
# file is executed directly as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.database import get_db
from app.database_user import get_beijing_time

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
YEARS_OF_DATA = 10      # backfill depth (years) for newly seen symbols
MAX_WORKERS = 5         # low concurrency to reduce upstream timeouts
SYNC_LIMIT = -1         # -1 means no cap on the number of targets synced
USE_YAHOO_FALLBACK = 0  # 0 = disabled; yfinance import kept for optional fallback
def _classify_a_board(code: str) -> str:
    """Map an A-share ticker code prefix to its board (market segment) label."""
    if code.startswith(('60', '00')):
        return '主板'    # main board (Shanghai / Shenzhen)
    if code.startswith('30'):
        return '创业板'  # ChiNext
    if code.startswith('68'):
        return '科创板'  # STAR market
    if code.startswith(('8', '4', '920')):
        return '北交所'  # Beijing Stock Exchange
    return '其他'


def get_stock_list() -> pd.DataFrame:
    """Fetch the all-market target list (A shares, ETF, LOF, REITs, convertible bonds).

    Returns:
        DataFrame with columns [code, name, market, list_date]. If every
        remote fetch fails, falls back to the persisted `stock_list` table.
    """
    logger.info("Fetching all-market target list...")
    all_lists = []
    try:
        # A-share spot snapshot provides the full code/name universe.
        df_a = ak.stock_zh_a_spot_em()[['代码', '名称']]
        df_a.columns = ['code', 'name']
        df_a['market'] = df_a['code'].apply(_classify_a_board)
        all_lists.append(df_a)
        # ETF / LOF / REITs / convertible bonds: each list is best-effort —
        # one category failing must not abort the others.
        try:
            df_etf = ak.fund_etf_spot_em()[['代码', '名称']]
            df_etf.columns = ['code', 'name']
            df_etf['market'] = 'ETF'
            all_lists.append(df_etf)
        except Exception as e:
            logger.warning(f"ETF list fetch failed: {e}")
        try:
            df_lof = ak.fund_lof_spot_em()[['代码', '名称']]
            df_lof.columns = ['code', 'name']
            df_lof['market'] = 'LOF'
            all_lists.append(df_lof)
        except Exception as e:
            logger.warning(f"LOF list fetch failed: {e}")
        try:
            df_reits = ak.reits_realtime_em()[['代码', '名称']]
            df_reits.columns = ['code', 'name']
            df_reits['market'] = 'REITs'
            all_lists.append(df_reits)
        except Exception as e:
            logger.warning(f"REITs list fetch failed: {e}")
        try:
            df_cb = ak.bond_zh_hs_cov_spot()
            # Column names vary between akshare versions.
            c_code = '代码' if '代码' in df_cb.columns else 'symbol'
            c_name = '名称' if '名称' in df_cb.columns else 'name'
            df_cb = df_cb[[c_code, c_name]]
            df_cb.columns = ['code', 'name']
            df_cb['market'] = '可转债'
            all_lists.append(df_cb)
        except Exception as e:
            # Was a silent bare `except: pass`; log so failures stay visible,
            # consistent with the other best-effort fetches above.
            logger.warning(f"Convertible bond list fetch failed: {e}")
    except Exception as e:
        logger.error(f"List fetching error: {e}")
    if not all_lists:
        # Every remote source failed — fall back to the last persisted list.
        db = get_db()
        return db.conn.execute("SELECT code, name, market FROM stock_list").df()
    df = pd.concat(all_lists).drop_duplicates(subset=['code'])
    df['list_date'] = None
    return df
def get_target_daily(code: str, start_date: str, market: str) -> Optional[pd.DataFrame]:
    """Fetch daily bars for a single target from *start_date* through today.

    Dispatches to the akshare endpoint matching *market* and normalizes the
    result to a fixed column set. Returns None on failure or empty data.
    """
    max_retries = 3  # retry transient upstream failures
    for attempt in range(max_retries):
        try:
            end_date = get_beijing_time().strftime('%Y%m%d')
            fetch_start = start_date.replace('-', '')
            df = None
            if market == 'INDEX':
                # NOTE(review): assumes '000'-prefixed index codes are Shanghai — confirm
                df = ak.stock_zh_index_daily_em(symbol=f"sh{code}" if code.startswith('000') else f"sz{code}")
            elif market == 'ETF':
                df = ak.fund_etf_hist_em(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            elif market == 'LOF':
                df = ak.fund_lof_hist_em(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            elif market == '可转债':
                # Convertible-bond endpoint usually takes the 6-digit code;
                # tolerate sh110xxx / sz12xxxx / bj81xxxx prefixed forms.
                cov_symbol = code[-6:] if len(code) > 6 else code
                try:
                    df = ak.bond_zh_hs_cov_daily(symbol=cov_symbol)
                except Exception:
                    df = ak.bond_zh_hs_cov_daily(symbol=code)
            elif market == 'REITs':
                # REITs endpoint only accepts `symbol` and returns the full
                # history; rows are filtered by date below.
                df = ak.reits_hist_em(symbol=code)
            else:
                df = ak.stock_zh_a_hist(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            if df is not None and not df.empty:
                # Normalize column names across the heterogeneous endpoints.
                rename_map = {
                    '日期': 'trade_date', 'date': 'trade_date', 'Date': 'trade_date',
                    '开盘': 'open', '今开': 'open', 'Open': 'open',
                    '最高': 'high', 'High': 'high',
                    '最低': 'low', 'Low': 'low',
                    '收盘': 'close', '最新价': 'close', 'Close': 'close',
                    '成交量': 'volume', 'Volume': 'volume',
                    '成交额': 'amount', 'Amount': 'amount',
                    '涨跌幅': 'pct_chg',
                    '换手率': 'turnover_rate', '换手': 'turnover_rate'
                }
                df = df.rename(columns=rename_map)
                # If trade_date is still missing, promote the index to a column.
                if 'trade_date' not in df.columns:
                    df = df.reset_index().rename(columns={'index': 'trade_date', 'date': 'trade_date'})
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                df = df[df['trade_date'] >= pd.to_datetime(start_date)]
                # Backfill optional columns some endpoints omit.
                if 'amount' not in df.columns: df['amount'] = 0
                if 'pct_chg' not in df.columns: df['pct_chg'] = df['close'].pct_change() * 100
                if 'turnover_rate' not in df.columns: df['turnover_rate'] = 0
                df['code'] = code
                return df[['code', 'trade_date', 'open', 'high', 'low', 'close', 'volume', 'amount', 'pct_chg', 'turnover_rate']]
        except Exception as e:
            if attempt == max_retries - 1:
                logger.warning(f"Failed to fetch {code} ({market}): {str(e)}")
            time.sleep(1)  # brief back-off before the next retry
    return None
def get_last_trading_day() -> str:
    """Return the most recent trading day (incremental-sync cutoff), as YYYY-MM-DD."""
    try:
        # The CSI 300 index history is cheap to fetch; its last row is the
        # latest completed session.
        hist = ak.stock_zh_index_daily_em(symbol="sh000300")
        if hist is not None and not hist.empty:
            for candidate in ('date', '日期'):
                if candidate in hist.columns:
                    return pd.to_datetime(hist[candidate].iloc[-1]).strftime('%Y-%m-%d')
    except Exception as e:
        logger.warning(f"Failed to get last trading day from index data: {e}")
    # Fallback: estimate by skipping the weekend (holidays not accounted for).
    day = get_beijing_time()
    while day.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
        day -= timedelta(days=1)
    return day.strftime('%Y-%m-%d')
def get_index_daily(code: str) -> Optional[pd.DataFrame]:
    """Fetch daily index bars (used by default for the CSI 300).

    Returns a DataFrame with the standard bar columns, or None on failure.
    """
    try:
        prefix = "sh" if code.startswith('000') else "sz"
        raw = ak.stock_zh_index_daily_em(symbol=f"{prefix}{code}")
        if raw is None or raw.empty:
            return None
        # Normalize both English and Chinese column names in one pass.
        raw = raw.rename(columns={
            'date': 'trade_date', '日期': 'trade_date',
            'open': 'open', '开盘': 'open',
            'high': 'high', '最高': 'high',
            'low': 'low', '最低': 'low',
            'close': 'close', '收盘': 'close',
            'volume': 'volume', '成交量': 'volume',
            'amount': 'amount', '成交额': 'amount',
            'pct_chg': 'pct_chg', '涨跌幅': 'pct_chg',
        })
        if 'trade_date' not in raw.columns:
            return None
        raw['trade_date'] = pd.to_datetime(raw['trade_date'])
        # Backfill columns the index feed may not provide.
        if 'amount' not in raw.columns:
            raw['amount'] = 0
        if 'pct_chg' not in raw.columns:
            raw['pct_chg'] = raw['close'].pct_change() * 100
        if 'volume' not in raw.columns:
            raw['volume'] = 0
        raw['turnover_rate'] = 0
        raw['code'] = code
        wanted = ['code', 'trade_date', 'open', 'high', 'low', 'close',
                  'volume', 'amount', 'pct_chg', 'turnover_rate']
        return raw[wanted]
    except Exception as e:
        logger.warning(f"Failed to fetch index {code}: {e}")
        return None
def sync_stock_daily(targets: List[Dict[str, str]], last_trade_day: str) -> int:
    """Incrementally sync daily bars for *targets* up to *last_trade_day*.

    Compatible with both the local DuckDB master copy and the cloud
    Parquet-view setup.

    Args:
        targets: dicts with at least 'code' and 'market' keys.
        last_trade_day: 'YYYY-MM-DD' cutoff; codes already at it are skipped.

    Returns:
        Number of targets that produced new data frames.
    """
    db = get_db()
    # 1. Current high-water mark per code (works whether stock_daily is a
    #    table or a view over parquet).
    existing_latest = db.conn.execute(
        "SELECT code, CAST(MAX(trade_date) AS VARCHAR) FROM stock_daily GROUP BY code"
    ).fetchall()
    latest_map = {row[0]: row[1] for row in existing_latest}
    pending = []
    for t in targets:
        code = t['code']
        if code in latest_map:
            if latest_map[code] >= last_trade_day:
                continue  # already up to date
            start_dt = (pd.to_datetime(latest_map[code]) + timedelta(days=1)).strftime('%Y-%m-%d')
        else:
            # New symbol: backfill the configured number of years.
            start_dt = (get_beijing_time() - timedelta(days=YEARS_OF_DATA * 365)).strftime('%Y-%m-%d')
        t['start_dt'] = start_dt
        pending.append(t)
    if not pending:
        return 0
    logger.info(f"Syncing {len(pending)} targets...")
    all_new_data = []
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {
            executor.submit(get_target_daily, t['code'], t['start_dt'], t['market']): t['code']
            for t in pending
        }
        for i, future in enumerate(as_completed(futures), 1):
            res = future.result()
            if res is not None:
                all_new_data.append(res)
            if i % 500 == 0:
                logger.info(f"Progress: {i}/{len(pending)}")
    if all_new_data:
        inc_df = pd.concat(all_new_data, ignore_index=True)
        # Identify the (year, month) partitions touched by the new rows.
        changed = inc_df.assign(
            yr=inc_df['trade_date'].dt.year, mo=inc_df['trade_date'].dt.month
        )[['yr', 'mo']].drop_duplicates().values
        # BUGFIX: read the env var once with a default; the old code called
        # os.path.getsize(os.getenv("DUCKDB_PATH")) which raises TypeError
        # when the variable is unset, and Path("") resolves to "." which
        # exists, defeating the guard.
        db_path = os.getenv("DUCKDB_PATH", "")
        for yr, mo in changed:
            yr, mo = int(yr), int(mo)
            # BUGFIX: `filename` was computed but never interpolated — the
            # paths below contained a literal placeholder instead of it.
            filename = f"{yr}-{mo:02d}.parquet"
            local_path = Path(f"backend/data/parquet/{filename}")
            local_path.parent.mkdir(parents=True, exist_ok=True)
            # Cloud incremental core: with no usable local master copy, pull
            # the existing month file from the Hub and merge into it.
            if not db_path or not Path(db_path).exists() or os.path.getsize(db_path) < 1024:
                try:
                    repo_id = os.getenv("DATASET_REPO_ID")
                    if repo_id:
                        old_file = hf_hub_download(
                            repo_id=repo_id,
                            filename=f"data/parquet/{filename}",
                            repo_type="dataset",
                        )
                        old_df = pd.read_parquet(old_file)
                        # Merge only this month's slice of the increment.
                        month_inc = inc_df[
                            (inc_df['trade_date'].dt.year == yr)
                            & (inc_df['trade_date'].dt.month == mo)
                        ]
                        final_month_df = pd.concat([old_df, month_inc]).drop_duplicates(
                            subset=['code', 'trade_date']
                        )
                        final_month_df.to_parquet(local_path)
                        logger.info(f"Merged cloud data for {filename}")
                        continue
                except Exception as e:
                    # Best-effort (was a silent bare except): log and fall
                    # through to the master-copy export path below.
                    logger.warning(f"Cloud merge failed for {filename}: {e}")
            # Local master-copy mode: insert into DuckDB (dedup via
            # INSERT OR IGNORE) and re-export the whole month partition.
            db.conn.execute("CREATE TEMP TABLE temp_inc AS SELECT * FROM inc_df")
            db.conn.execute("INSERT OR IGNORE INTO stock_daily SELECT * FROM temp_inc")
            db.conn.execute(
                f"COPY (SELECT * FROM stock_daily "
                f"WHERE date_part('year', trade_date) = {yr} "
                f"AND date_part('month', trade_date) = {mo}) "
                f"TO '{local_path}' (FORMAT PARQUET)"
            )
            db.conn.execute("DROP TABLE temp_inc")
    return len(all_new_data)
def main():
    """Run one full sync cycle: list -> daily bars -> index -> upload -> hot reload."""
    logger.info("Sync started...")
    db = get_db(); db.init_db()
    # 1. Target-list sync
    target_list = get_stock_list()
    # Local master-copy mode: refresh the stock_list table in place; in cloud
    # mode the parquet exported below is what gets uploaded instead.
    if Path(os.getenv("DUCKDB_PATH", "")).exists():
        db.conn.execute("DELETE FROM stock_list")
        # DuckDB resolves `target_list` from the local DataFrame via
        # replacement scan.
        db.conn.execute("INSERT INTO stock_list SELECT code, name, market, list_date FROM target_list")
    list_parquet = Path("backend/data/stock_list.parquet")
    list_parquet.parent.mkdir(parents=True, exist_ok=True)
    target_list.to_parquet(list_parquet)
    # 2. Daily bar sync
    last_day = get_last_trading_day()
    sync_stock_daily(target_list.to_dict('records'), last_day)
    # 3. Index sync (CSI 300 benchmark)
    idx_df = get_index_daily('000300')
    if idx_df is not None:
        idx_path = Path("backend/data/parquet/index_000300.parquet")
        idx_df.to_parquet(idx_path)
    # 4. Upload
    db.upload_db()
    # 5. Hot reload: refresh database views and in-memory caches
    try:
        logger.info("Reloading database views to reflect new data...")
        db.init_db()  # re-download and re-mount the latest Parquet files
        from app.core import clear_eligible_stocks_cache
        clear_eligible_stocks_cache()  # drop the in-memory cache
        logger.info("Hot reload completed!")
    except Exception as e:
        logger.error(f"Hot reload failed: {e}")
    logger.info("Sync finished!")


if __name__ == "__main__":
    main()