Spaces:
Runtime error
Runtime error
| import os | |
| import pickle | |
| from typing import List | |
| import numpy as np | |
| import pandas as pd | |
class DataProcessor:
    """Facade that delegates market-data download and preprocessing to a
    backend processor chosen by ``data_source``.

    The backend class is imported lazily from
    ``meta.data_processors.<data_source>``, so only the selected backend's
    module has to be importable. After each pipeline step (download, clean,
    indicators, turbulence, VIX, fillna) ``self.dataframe`` is updated to the
    backend's ``dataframe``.
    """

    # Supported ``data_source`` values -> backend class name inside
    # ``meta.data_processors.<data_source>``.
    _PROCESSORS = {
        "akshare": "Akshare",
        "alpaca": "Alpaca",
        "alphavantage": "Alphavantage",
        "baostock": "Baostock",
        "binance": "Binance",
        "ccxt": "Ccxt",
        "iexcloud": "Iexcloud",
        "joinquant": "Joinquant",
        "quandl": "Quandl",
        "quantconnect": "Quantconnect",
        "ricequant": "Ricequant",
        "tushare": "Tushare",
        "wrds": "Wrds",
        "yahoofinance": "Yahoofinance",
    }

    # Directory (relative to the working directory) holding cached pickles.
    _CACHE_DIR = "./cache"

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        """Select, import and instantiate the backend processor.

        :param data_source: backend name; must be a key of ``_PROCESSORS``.
        :param start_date: start of the requested range, forwarded to the backend.
        :param end_date: end of the requested range, forwarded to the backend.
        :param time_interval: bar interval (e.g. "1d"), forwarded to the backend.
        :param kwargs: backend-specific options (e.g. account credentials),
            forwarded unchanged to the backend constructor.
        :raises ValueError: if ``data_source`` is not supported, or if the
            backend constructor raises (the original error is chained as
            ``__cause__``).
        :raises ImportError: if the backend module or class cannot be imported.
        """
        import importlib

        self.data_source = data_source
        self.start_date = start_date
        self.end_date = end_date
        self.time_interval = time_interval
        self.dataframe = pd.DataFrame()

        class_name = self._PROCESSORS.get(self.data_source)
        if class_name is None:
            # Previously this only printed and then surfaced as a misleading
            # "account info" error; fail with an accurate message instead.
            raise ValueError(f"{self.data_source} is NOT supported yet.")

        # Lazy import: other backends' dependencies need not be installed.
        module = importlib.import_module(f"meta.data_processors.{self.data_source}")
        try:
            processor_class = getattr(module, class_name)
        except AttributeError as error:
            # Keep the ImportError a ``from module import Class`` would raise.
            raise ImportError(
                f"cannot import name {class_name!r} from {module.__name__!r}"
            ) from error

        try:
            self.processor = processor_class(
                data_source, start_date, end_date, time_interval, **kwargs
            )
        except Exception as error:
            # The message assumes constructor failures are usually bad
            # credentials; the real cause stays available via __cause__.
            raise ValueError(
                f"Please input correct account info for {self.data_source}!"
            ) from error
        print(f"{self.data_source} successfully connected")

    def download_data(self, ticker_list: List[str]):
        """Download data for ``ticker_list`` via the backend and mirror its dataframe."""
        self.processor.download_data(ticker_list=ticker_list)
        self.dataframe = self.processor.dataframe

    def clean_data(self):
        """Push ``self.dataframe`` to the backend, clean it there, and mirror the result."""
        self.processor.dataframe = self.dataframe
        self.processor.clean_data()
        self.dataframe = self.processor.dataframe

    def add_technical_indicator(
        self, tech_indicator_list: List[str], select_stockstats_talib: int = 0
    ):
        """Add technical-indicator columns via the backend.

        :param tech_indicator_list: indicator names; also remembered on
            ``self.tech_indicator_list`` for later use by ``df_to_array``.
        :param select_stockstats_talib: forwarded to the backend (presumably
            selects the indicator library; confirm against the backend).
        """
        self.tech_indicator_list = tech_indicator_list
        self.processor.add_technical_indicator(
            tech_indicator_list, select_stockstats_talib
        )
        self.dataframe = self.processor.dataframe

    def add_turbulence(self):
        """Delegate to the backend's ``add_turbulence`` and mirror its dataframe."""
        self.processor.add_turbulence()
        self.dataframe = self.processor.dataframe

    def add_vix(self):
        """Delegate to the backend's ``add_vix`` and mirror its dataframe."""
        self.processor.add_vix()
        self.dataframe = self.processor.dataframe

    def df_to_array(self, if_vix: bool) -> tuple:
        """Convert the processed data into arrays via the backend.

        Uses the indicator list stored by ``add_technical_indicator``, so that
        method (or ``run``) must have been called first.

        :param if_vix: forwarded to the backend's ``df_to_array``.
        :return: ``(price_array, tech_array, turbulence_array)`` as returned by
            the backend, with NaNs in ``tech_array`` replaced by 0 in place.
        """
        price_array, tech_array, turbulence_array = self.processor.df_to_array(
            self.tech_indicator_list, if_vix
        )
        # Fill NaN with 0 for technical indicators (e.g. warm-up rows).
        tech_array[np.isnan(tech_array)] = 0
        return price_array, tech_array, turbulence_array

    def data_split(self, df, start, end, target_date_col="time"):
        """Select the rows of ``df`` whose date lies in ``[start, end)``.

        :param df: pandas DataFrame with ``target_date_col`` and ``"tic"`` columns.
        :param start: inclusive lower bound on ``target_date_col``.
        :param end: exclusive upper bound on ``target_date_col``.
        :param target_date_col: name of the date column (default ``"time"``).
        :return: the filtered rows sorted by date then ticker, indexed by a
            0-based position of each distinct date (rows sharing a date share
            an index value).
        """
        data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
        data = data.sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data

    def fillna(self):
        """Push ``self.dataframe`` to the backend, fill NaNs there, and mirror the result."""
        self.processor.dataframe = self.dataframe
        self.processor.fillna()
        self.dataframe = self.processor.dataframe

    def _cache_path(self, ticker_list: List[str]) -> str:
        """Return the pickle path used to cache the cleaned dataframe.

        The file name joins the tickers, data source, start/end dates and
        interval with underscores. Path separators inside any part (e.g. the
        "/" in "BTC/USDT") become "-" so the file lands directly in the cache
        directory instead of a nonexistent subdirectory.
        """
        parts = list(ticker_list) + [
            self.data_source,
            self.start_date,
            self.end_date,
            self.time_interval,
        ]
        safe_parts = [str(part).replace("/", "-").replace("\\", "-") for part in parts]
        return os.path.join(self._CACHE_DIR, "_".join(safe_parts) + ".pickle")

    def run(
        self,
        ticker_list: List[str],
        technical_indicator_list: List[str],
        if_vix: bool,
        cache: bool = False,
        select_stockstats_talib: int = 0,
    ):
        """Run the whole pipeline and return the arrays.

        Downloads and cleans the data (or loads the cleaned dataframe from the
        pickle cache when ``cache`` is true and a cache file exists), adds
        technical indicators, optionally adds VIX, then converts to arrays.

        :param ticker_list: tickers to download.
        :param technical_indicator_list: indicator names to add.
        :param if_vix: whether to add VIX; also forwarded to ``df_to_array``.
        :param cache: read/write the cleaned dataframe under ``./cache``.
        :param select_stockstats_talib: forwarded to ``add_technical_indicator``.
        :return: ``(price_array, tech_array, turbulence_array)``.
        :raises ValueError: if a "1s" interval is requested for a source other
            than "binance".
        """
        if self.time_interval == "1s" and self.data_source != "binance":
            raise ValueError(
                "Currently 1s interval data is only supported with 'binance' as data source"
            )
        cache_path = self._cache_path(ticker_list)
        if cache and os.path.isfile(cache_path):
            print(f"Using cached file {cache_path}")
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at cache files you created and trust.
            with open(cache_path, "rb") as handle:
                self.dataframe = pickle.load(handle)
            # Keep both copies in sync so later steps see the cached data.
            self.processor.dataframe = self.dataframe
        else:
            self.download_data(ticker_list)
            self.clean_data()
            if cache:
                os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                with open(cache_path, "wb") as handle:
                    pickle.dump(
                        self.dataframe,
                        handle,
                        protocol=pickle.HIGHEST_PROTOCOL,
                    )
        self.add_technical_indicator(technical_indicator_list, select_stockstats_talib)
        if if_vix:
            self.add_vix()
        # df_to_array already replaces NaNs in the technical-indicator array.
        return self.df_to_array(if_vix)
def test_joinquant():
    """Manual smoke test for the JoinQuant backend.

    Runs the step-by-step pipeline (download, clean, turbulence, indicators,
    VIX) and then the one-shot ``run`` with caching enabled. The credentials
    below are "xxx" placeholders, so this presumably needs real JoinQuant
    account details (and network access) to succeed.
    """
    # Earlier alternative start date: "2019-09-01".
    # Supported time intervals: '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'
    indicators = [
        "macd",
        "boll_ub",
        "boll_lb",
        "rsi_30",
        "dx_30",
        "close_30_sma",
        "close_60_sma",
    ]
    credentials = {"username": "xxx", "password": "xxx"}
    processor = DataProcessor(
        data_source="joinquant",
        start_date="2020-09-01",
        end_date="2021-09-11",
        time_interval="1d",
        **credentials,
    )
    tickers = ["000612.XSHE", "601808.XSHG"]

    # Step-by-step pipeline.
    processor.download_data(ticker_list=tickers)
    processor.clean_data()
    processor.add_turbulence()
    processor.add_technical_indicator(indicators)
    processor.add_vix()

    # One-shot pipeline; the returned arrays are not inspected further.
    price_array, tech_array, turbulence_array = processor.run(
        tickers, indicators, if_vix=False, cache=True
    )
| # if __name__ == "__main__": | |
| # # test_joinquant() | |
| # # test_binance() | |
| # # test_yahoofinance() | |
| # test_baostock() | |
| # # test_quandl() | |