Spaces:
Runtime error
Runtime error
| from typing import List | |
| import baostock as bs | |
| import numpy as np | |
| import pandas as pd | |
| import pytz | |
| import yfinance as yf | |
| """Reference: https://github.com/AI4Finance-LLC/FinRL""" | |
| try: | |
| import exchange_calendars as tc | |
| except: | |
| print( | |
| "Cannot import exchange_calendars.", | |
| "If you are using python>=3.7, please install it.", | |
| ) | |
| import trading_calendars as tc | |
| print("Use trading_calendars instead for yahoofinance processor..") | |
| # from basic_processor import _Base | |
| from meta.data_processors._base import _Base | |
| from meta.data_processors._base import calc_time_zone | |
| from meta.config import ( | |
| TIME_ZONE_SHANGHAI, | |
| TIME_ZONE_USEASTERN, | |
| TIME_ZONE_PARIS, | |
| TIME_ZONE_BERLIN, | |
| TIME_ZONE_JAKARTA, | |
| TIME_ZONE_SELFDEFINED, | |
| USE_TIME_ZONE_SELFDEFINED, | |
| BINANCE_BASE_URL, | |
| ) | |
| class Baostock(_Base): | |
| def __init__( | |
| self, | |
| data_source: str, | |
| start_date: str, | |
| end_date: str, | |
| time_interval: str, | |
| **kwargs, | |
| ): | |
| super().__init__(data_source, start_date, end_date, time_interval, **kwargs) | |
| # 日k线、周k线、月k线,以及5分钟、15分钟、30分钟和60分钟k线数据 | |
| # ["5m", "15m", "30m", "60m", "1d", "1w", "1M"] | |
| def download_data( | |
| self, ticker_list: List[str], save_path: str = "./data/dataset.csv" | |
| ): | |
| lg = bs.login() | |
| print("baostock login respond error_code:" + lg.error_code) | |
| print("baostock login respond error_msg:" + lg.error_msg) | |
| self.time_zone = calc_time_zone( | |
| ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED | |
| ) | |
| self.dataframe = pd.DataFrame() | |
| for ticker in ticker_list: | |
| nonstandrad_ticker = self.transfer_standard_ticker_to_nonstandard(ticker) | |
| # All supported: "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST" | |
| rs = bs.query_history_k_data_plus( | |
| nonstandrad_ticker, | |
| "date,code,open,high,low,close,volume", | |
| start_date=self.start_date, | |
| end_date=self.end_date, | |
| frequency=self.time_interval, | |
| adjustflag="3", | |
| ) | |
| print("baostock download_data respond error_code:" + rs.error_code) | |
| print("baostock download_data respond error_msg:" + rs.error_msg) | |
| data_list = [] | |
| while (rs.error_code == "0") & rs.next(): | |
| data_list.append(rs.get_row_data()) | |
| df = pd.DataFrame(data_list, columns=rs.fields) | |
| df.loc[:, "code"] = [ticker] * df.shape[0] | |
| self.dataframe = pd.concat([self.dataframe, df]) | |
| self.dataframe = self.dataframe.sort_values(by=["date", "code"]).reset_index( | |
| drop=True | |
| ) | |
| bs.logout() | |
| self.dataframe.open = self.dataframe.open.astype(float) | |
| self.dataframe.high = self.dataframe.high.astype(float) | |
| self.dataframe.low = self.dataframe.low.astype(float) | |
| self.dataframe.close = self.dataframe.close.astype(float) | |
| self.save_data(save_path) | |
| print( | |
| f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}" | |
| ) | |
| def get_trading_days(self, start, end): | |
| lg = bs.login() | |
| print("baostock login respond error_code:" + lg.error_code) | |
| print("baostock login respond error_msg:" + lg.error_msg) | |
| result = bs.query_trade_dates(start_date=start, end_date=end) | |
| bs.logout() | |
| return result | |
| # "600000.XSHG" -> "sh.600000" | |
| # "000612.XSHE" -> "sz.000612" | |
| def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str: | |
| n, alpha = ticker.split(".") | |
| assert alpha in ["XSHG", "XSHE"], "Wrong alpha" | |
| if alpha == "XSHG": | |
| nonstandard_ticker = "sh." + n | |
| elif alpha == "XSHE": | |
| nonstandard_ticker = "sz." + n | |
| return nonstandard_ticker | |