"""
数据同步脚本 - 云端兼容版
支持本地母本模式 (DuckDB) 和 云端无盘模式 (Parquet View)
"""
import os
import sys
import logging
import time
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Dict, Any
from pathlib import Path
import pandas as pd
import akshare as ak
import yfinance as yf
from huggingface_hub import hf_hub_download, upload_file
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.database import get_db
from app.database_user import get_beijing_time
# Root logger setup: timestamped INFO-level messages for the whole script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Configuration
YEARS_OF_DATA = 10  # history window (years) backfilled for brand-new codes
MAX_WORKERS = 5  # lowered concurrency to reduce timeouts
SYNC_LIMIT = -1  # appears unused in this file — TODO confirm before removing
USE_YAHOO_FALLBACK = 0  # appears unused here (yfinance imported above) — TODO confirm
def get_stock_list() -> pd.DataFrame:
    """Fetch the full-market target list (A-shares, ETF, LOF, REITs, convertible bonds).

    Returns:
        DataFrame with columns [code, name, market, list_date], de-duplicated
        by code. If every fetch fails, falls back to the existing `stock_list`
        table in the local database.
    """
    logger.info("Fetching all-market target list...")
    all_lists = []
    try:
        # A-shares: classify the board from the code prefix.
        df_a = ak.stock_zh_a_spot_em()[['代码', '名称']]
        df_a.columns = ['code', 'name']
        df_a['market'] = df_a['code'].apply(
            lambda x: '主板' if x.startswith(('60', '00'))
            else ('创业板' if x.startswith('30')
                  else ('科创板' if x.startswith('68')
                        else ('北交所' if x.startswith(('8', '4', '920')) else '其他'))))
        all_lists.append(df_a)
        # ETF / LOF / REITs — each segment is best-effort; a failure only
        # loses that segment, not the whole list.
        try:
            df_etf = ak.fund_etf_spot_em()[['代码', '名称']]
            df_etf.columns = ['code', 'name']
            df_etf['market'] = 'ETF'
            all_lists.append(df_etf)
        except Exception as e:
            logger.warning(f"ETF list fetch failed: {e}")
        try:
            df_lof = ak.fund_lof_spot_em()[['代码', '名称']]
            df_lof.columns = ['code', 'name']
            df_lof['market'] = 'LOF'
            all_lists.append(df_lof)
        except Exception as e:
            logger.warning(f"LOF list fetch failed: {e}")
        try:
            df_reits = ak.reits_realtime_em()[['代码', '名称']]
            df_reits.columns = ['code', 'name']
            df_reits['market'] = 'REITs'
            all_lists.append(df_reits)
        except Exception as e:
            logger.warning(f"REITs list fetch failed: {e}")
        try:
            # Convertible bonds: column names vary between akshare versions.
            df_cb = ak.bond_zh_hs_cov_spot()
            c_code = '代码' if '代码' in df_cb.columns else 'symbol'
            c_name = '名称' if '名称' in df_cb.columns else 'name'
            df_cb = df_cb[[c_code, c_name]]
            df_cb.columns = ['code', 'name']
            df_cb['market'] = '可转债'
            all_lists.append(df_cb)
        except Exception as e:
            # Was a silent bare `except: pass`; log so failures are visible.
            logger.warning(f"Convertible bond list fetch failed: {e}")
    except Exception as e:
        logger.error(f"List fetching error: {e}")
    if not all_lists:
        # Last resort: reuse whatever list the database already holds.
        db = get_db()
        return db.conn.execute("SELECT code, name, market FROM stock_list").df()
    df = pd.concat(all_lists).drop_duplicates(subset=['code'])
    df['list_date'] = None
    return df
def get_target_daily(code: str, start_date: str, market: str) -> Optional[pd.DataFrame]:
    """Fetch daily bars for a single target and normalize to the canonical schema.

    Args:
        code: security code (6-digit; convertible bonds may carry an exchange prefix).
        start_date: inclusive start date, 'YYYY-MM-DD'.
        market: market tag selecting the akshare endpoint
            ('INDEX', 'ETF', 'LOF', '可转债', 'REITs', anything else → A-share).

    Returns:
        DataFrame with columns [code, trade_date, open, high, low, close,
        volume, amount, pct_chg, turnover_rate], or None if all retries fail
        or no data is returned.
    """
    max_retries = 3
    for attempt in range(max_retries):
        try:
            end_date = get_beijing_time().strftime('%Y%m%d')
            fetch_start = start_date.replace('-', '')
            df = None
            if market == 'INDEX':
                df = ak.stock_zh_index_daily_em(symbol=f"sh{code}" if code.startswith('000') else f"sz{code}")
            elif market == 'ETF':
                df = ak.fund_etf_hist_em(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            elif market == 'LOF':
                df = ak.fund_lof_hist_em(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            elif market == '可转债':
                # Convertible-bond endpoint usually wants the bare 6-digit code;
                # fall back to the raw code (sh110xxx / sz12xxxx / bj81xxxx) if needed.
                cov_symbol = code[-6:] if len(code) > 6 else code
                try:
                    df = ak.bond_zh_hs_cov_daily(symbol=cov_symbol)
                except Exception:
                    df = ak.bond_zh_hs_cov_daily(symbol=code)
            elif market == 'REITs':
                # REITs endpoint only accepts `symbol`; it returns full history,
                # which is filtered by start_date below.
                df = ak.reits_hist_em(symbol=code)
            else:
                df = ak.stock_zh_a_hist(symbol=code, period="daily", start_date=fetch_start, end_date=end_date, adjust="hfq")
            if df is not None and not df.empty:
                # Normalize heterogeneous column names to one schema.
                rename_map = {
                    '日期': 'trade_date', 'date': 'trade_date', 'Date': 'trade_date',
                    '开盘': 'open', '今开': 'open', 'Open': 'open',
                    '最高': 'high', 'High': 'high',
                    '最低': 'low', 'Low': 'low',
                    '收盘': 'close', '最新价': 'close', 'Close': 'close',
                    '成交量': 'volume', 'Volume': 'volume',
                    '成交额': 'amount', 'Amount': 'amount',
                    '涨跌幅': 'pct_chg',
                    '换手率': 'turnover_rate', '换手': 'turnover_rate'
                }
                df = df.rename(columns=rename_map)
                # Some endpoints return the date as the index instead of a column.
                if 'trade_date' not in df.columns:
                    df = df.reset_index().rename(columns={'index': 'trade_date', 'date': 'trade_date'})
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                df = df[df['trade_date'] >= pd.to_datetime(start_date)]
                # Backfill columns that some endpoints omit.
                if 'amount' not in df.columns:
                    df['amount'] = 0
                if 'pct_chg' not in df.columns:
                    df['pct_chg'] = df['close'].pct_change() * 100
                if 'turnover_rate' not in df.columns:
                    df['turnover_rate'] = 0
                df['code'] = code
                return df[['code', 'trade_date', 'open', 'high', 'low', 'close', 'volume', 'amount', 'pct_chg', 'turnover_rate']]
        except Exception as e:
            if attempt == max_retries - 1:
                logger.warning(f"Failed to fetch {code} ({market}): {str(e)}")
            else:
                # Fix: original slept even after the final failed attempt;
                # only back off when another retry will actually happen.
                time.sleep(1)
    return None
def get_last_trading_day() -> str:
    """Return the most recent trading day as 'YYYY-MM-DD' (incremental cutoff)."""
    try:
        # The CSI 300 index history's last row is the latest completed session.
        idx = ak.stock_zh_index_daily_em(symbol="sh000300")
        if idx is not None and not idx.empty:
            col = next((c for c in ('date', '日期') if c in idx.columns), None)
            if col is not None:
                return pd.to_datetime(idx[col].iloc[-1]).strftime('%Y-%m-%d')
    except Exception as e:
        logger.warning(f"Failed to get last trading day from index data: {e}")
    # Fallback: step back from today to the nearest weekday (5=Sat, 6=Sun).
    day = get_beijing_time()
    while day.weekday() >= 5:
        day -= timedelta(days=1)
    return day.strftime('%Y-%m-%d')
def get_index_daily(code: str) -> Optional[pd.DataFrame]:
    """Fetch index daily bars (used for CSI 300 by default) in the canonical schema."""
    try:
        prefix = "sh" if code.startswith('000') else "sz"
        raw = ak.stock_zh_index_daily_em(symbol=f"{prefix}{code}")
        if raw is None or raw.empty:
            return None
        # Map both English and Chinese column names onto the canonical schema.
        column_map = {
            'date': 'trade_date', '日期': 'trade_date',
            'open': 'open', '开盘': 'open',
            'high': 'high', '最高': 'high',
            'low': 'low', '最低': 'low',
            'close': 'close', '收盘': 'close',
            'volume': 'volume', '成交量': 'volume',
            'amount': 'amount', '成交额': 'amount',
            'pct_chg': 'pct_chg', '涨跌幅': 'pct_chg'
        }
        out = raw.rename(columns=column_map)
        if 'trade_date' not in out.columns:
            return None
        out['trade_date'] = pd.to_datetime(out['trade_date'])
        # Backfill optional fields the endpoint may omit.
        if 'amount' not in out.columns:
            out['amount'] = 0
        if 'pct_chg' not in out.columns:
            out['pct_chg'] = out['close'].pct_change() * 100
        if 'volume' not in out.columns:
            out['volume'] = 0
        out['turnover_rate'] = 0  # index has no turnover rate; keep schema uniform
        out['code'] = code
        return out[['code', 'trade_date', 'open', 'high', 'low', 'close', 'volume', 'amount', 'pct_chg', 'turnover_rate']]
    except Exception as e:
        logger.warning(f"Failed to fetch index {code}: {e}")
        return None
def sync_stock_daily(targets: List[Dict[str, str]], last_trade_day: str) -> int:
    """Incrementally sync daily bars; compatible with a cloud Parquet view.

    Args:
        targets: dicts carrying at least 'code' and 'market'.
        last_trade_day: 'YYYY-MM-DD'; codes already at/past this date are skipped.

    Returns:
        Number of targets for which new data was fetched.
    """
    db = get_db()
    # 1. Current high-water mark per code (works whether stock_daily is a table or a view).
    existing_latest = db.conn.execute(
        "SELECT code, CAST(MAX(trade_date) AS VARCHAR) FROM stock_daily GROUP BY code"
    ).fetchall()
    latest_map = {row[0]: row[1] for row in existing_latest}
    pending = []
    for t in targets:
        code = t['code']
        if code in latest_map:
            if latest_map[code] >= last_trade_day:
                continue  # already up to date
            start_dt = (pd.to_datetime(latest_map[code]) + timedelta(days=1)).strftime('%Y-%m-%d')
        else:
            # New code: backfill the full configured history window.
            start_dt = (get_beijing_time() - timedelta(days=YEARS_OF_DATA * 365)).strftime('%Y-%m-%d')
        t['start_dt'] = start_dt
        pending.append(t)
    if not pending:
        return 0
    logger.info(f"Syncing {len(pending)} targets...")
    all_new_data = []
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(get_target_daily, t['code'], t['start_dt'], t['market']): t['code'] for t in pending}
        for i, future in enumerate(as_completed(futures), 1):
            res = future.result()
            if res is not None:
                all_new_data.append(res)
            if i % 500 == 0:
                logger.info(f"Progress: {i}/{len(pending)}")
    if all_new_data:
        inc_df = pd.concat(all_new_data, ignore_index=True)
        # Identify the (year, month) partitions touched by the increment.
        changed = inc_df.assign(yr=inc_df['trade_date'].dt.year, mo=inc_df['trade_date'].dt.month)[['yr', 'mo']].drop_duplicates().values
        # A usable local master DB must be set, exist, and be non-trivial in size.
        # Fix: the original evaluated os.path.getsize(os.getenv("DUCKDB_PATH")),
        # which raises TypeError when DUCKDB_PATH is unset (and Path("") resolves
        # to '.', which exists, so the first test never short-circuited it).
        db_path = os.getenv("DUCKDB_PATH") or ""
        has_local_master = bool(db_path) and Path(db_path).exists() and os.path.getsize(db_path) >= 1024
        for yr, mo in changed:
            yr, mo = int(yr), int(mo)
            filename = f"{yr}-{mo:02d}.parquet"
            # Fix: the month filename was mangled to a literal "(unknown)" in the
            # path templates, so every month overwrote the same file.
            local_path = Path(f"backend/data/parquet/{filename}")
            local_path.parent.mkdir(parents=True, exist_ok=True)
            # Cloud incremental mode: pull the existing month file from the hub,
            # merge the increment, and write the merged partition locally.
            if not has_local_master:
                try:
                    repo_id = os.getenv("DATASET_REPO_ID")
                    if repo_id:
                        old_file = hf_hub_download(repo_id=repo_id, filename=f"data/parquet/{filename}", repo_type="dataset")
                        old_df = pd.read_parquet(old_file)
                        month_inc = inc_df[(inc_df['trade_date'].dt.year == yr) & (inc_df['trade_date'].dt.month == mo)]
                        final_month_df = pd.concat([old_df, month_inc]).drop_duplicates(subset=['code', 'trade_date'])
                        final_month_df.to_parquet(local_path)
                        logger.info(f"Merged cloud data for {filename}")
                        continue
                except Exception as e:
                    # Best-effort (was a silent bare except): log, then fall
                    # through to the local-master export path below.
                    logger.warning(f"Cloud merge failed for {filename}: {e}")
            # Local master mode: insert the increment into DuckDB and re-export
            # the whole month partition from the master table.
            db.conn.execute("CREATE TEMP TABLE temp_inc AS SELECT * FROM inc_df")
            db.conn.execute("INSERT OR IGNORE INTO stock_daily SELECT * FROM temp_inc")
            db.conn.execute(f"COPY (SELECT * FROM stock_daily WHERE date_part('year', trade_date) = {yr} AND date_part('month', trade_date) = {mo}) TO '{local_path}' (FORMAT PARQUET)")
            db.conn.execute("DROP TABLE temp_inc")
    return len(all_new_data)
def main():
    """Full sync pipeline: target list → daily bars → index → upload → hot reload."""
    logger.info("Sync started...")
    db = get_db()
    db.init_db()
    # 1. Target-list sync.
    target_list = get_stock_list()
    # Local master mode: refresh the stock_list table in DuckDB.
    # Fix: the original tested Path(os.getenv("DUCKDB_PATH", "")).exists(),
    # but Path("") resolves to '.', which exists — so cloud mode (env unset)
    # wrongly ran DELETE/INSERT against the view. Require a non-empty path.
    db_path = os.getenv("DUCKDB_PATH") or ""
    if db_path and Path(db_path).exists():
        db.conn.execute("DELETE FROM stock_list")
        db.conn.execute("INSERT INTO stock_list SELECT code, name, market, list_date FROM target_list")
    # Always export the list as Parquet for cloud upload.
    list_parquet = Path("backend/data/stock_list.parquet")
    list_parquet.parent.mkdir(parents=True, exist_ok=True)
    target_list.to_parquet(list_parquet)
    # 2. Daily-bar sync.
    last_day = get_last_trading_day()
    sync_stock_daily(target_list.to_dict('records'), last_day)
    # 3. Index sync (CSI 300).
    idx_df = get_index_daily('000300')
    if idx_df is not None:
        idx_path = Path("backend/data/parquet/index_000300.parquet")
        idx_df.to_parquet(idx_path)
    # 4. Upload.
    db.upload_db()
    # 5. Hot reload: refresh database views and in-memory caches.
    try:
        logger.info("Reloading database views to reflect new data...")
        db.init_db()  # re-download and re-attach the latest Parquet files
        from app.core import clear_eligible_stocks_cache
        clear_eligible_stocks_cache()  # drop the in-memory eligibility cache
        logger.info("Hot reload completed!")
    except Exception as e:
        logger.error(f"Hot reload failed: {e}")
    logger.info("Sync finished!")

if __name__ == "__main__":
    main()
|