farquasar committed on
Commit
d8fd664
·
verified ·
1 Parent(s): 0aa887c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -23
app.py CHANGED
@@ -17,6 +17,28 @@ def fetch_binance_data(symbol, timeframe, limit=2000):
17
  df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
18
  return df
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Rolling Window Normalizer
21
  class RollingWindowNormalizer:
22
  def __init__(self, window=24):
@@ -87,18 +109,35 @@ def create_features_and_labels_with_advanced_features(btc, eth):
87
  features = np.vstack((btc_features, eth_features))
88
  return features, labels
89
 
90
- def get_data_predict(btc_ori, eth_ori, symbol='BCH/USDT', timeframe='4h', epsilon=2, normalized=False, limit=50):
91
- btc_data_ = fetch_binance_data('BTC/USDT', timeframe, limit=limit)
92
- eth_data_ = fetch_binance_data(symbol, timeframe, limit=limit)
 
 
 
 
 
 
 
 
 
 
 
93
  btc_data_ = remove_outliers(btc_data_, epsilon)
94
- eth_data_ = remove_outliers(eth_data_, epsilon)
 
95
  if normalized:
96
- btc_data_all = pd.concat([btc_ori, btc_data_]).drop_duplicates(subset='timestamp').reset_index(drop=True)
97
- eth_data_all = pd.concat([eth_ori, eth_data_]).drop_duplicates(subset='timestamp').reset_index(drop=True)
98
- btc_data_, _ = normalize(btc_data_all)
99
- eth_data_, _ = normalize(eth_data_all)
 
100
  label = btc_data_.copy()[['timestamp','close']].shift(-1)
101
- return btc_data_, eth_data_, label
 
 
 
 
102
 
103
  def predictions(model, X1, X2, name, n_steps):
104
  features_, labels_ = create_features_and_labels_with_advanced_features(X1, X2)
@@ -133,23 +172,37 @@ with open('model_n4h_cat.pkl','rb') as f:
133
  model_n4h_cat = pickle.load(f)
134
 
135
  def predict_and_plot(timeframe, limit, epsilon, n_steps, ma):
136
- btc_ori = yf.download('BTC-USD', period=f'{limit}d', interval=timeframe)
137
- eth_ori = yf.download('BCH-USD', period=f'{limit}d', interval=timeframe)
138
- btc_data, eth_data, label = get_data_predict(btc_ori, eth_ori, symbol='ETH/USDT', timeframe=timeframe, epsilon=epsilon, normalized=True, limit=limit)
 
 
 
 
 
 
 
 
 
 
139
  model = model_n1d_cat if timeframe=='1d' else model_n4h_cat
140
- preds = predictions(model, btc_data, eth_data, name=timeframe, n_steps=n_steps)
141
- fig = plot(preds, label=btc_data, timeframe=timeframe, ma=ma, n_steps=n_steps)
142
  return fig
143
 
144
- interface = gr.Interface(fn=predict_and_plot,
145
- inputs=[gr.Dropdown(['1d','4h'], label='Timeframe', value='1d'),
146
- gr.Slider(50,500,step=50,value=100,label='Data Limit'),
147
- gr.Slider(0.1,5.0,step=0.1,value=2.0,label='Epsilon'),
148
- gr.Slider(50,500,step=50,value=200,label='N_steps'),
149
- gr.Slider(1,20,step=1,value=5,label='Moving Average Window (ma)')],
150
- outputs=gr.Plot(),
151
- title='BTC Price Movement Prediction',
152
- description='Predict BTC price movements using pre-trained LightGBM models.')
 
 
 
 
153
 
154
  if __name__=='__main__':
155
  interface.launch()
 
17
  df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
18
  return df
19
 
20
def fetch_yfinance_data(pair: str, period: str, interval: str) -> pd.DataFrame:
    """Download candles from Yahoo Finance and rename them to lower-case columns.

    Args:
        pair: Exchange-style symbol, e.g. "BCH/USDT" or "BTC/USDT".
        period: Look-back window handed to ``yf.download``, e.g. "100d".
        interval: Candle size handed to ``yf.download``, e.g. "1d", "60m",
            "90m", "1h".

    Returns:
        The downloaded frame with its index turned into a datetime
        ``timestamp`` column and ``Open``/``High``/``Low``/``Close``/``Volume``
        renamed to ``open``/``high``/``low``/``close``/``volume``. Any other
        column yfinance returns is passed through unchanged.
    """
    # Yahoo uses e.g. "BCH-USD" for BCH/USDT
    ticker = pair.replace("/USDT", "-USD")
    df = yf.download(ticker, period=period, interval=interval)

    # yfinance can hand back two-level (field, ticker) columns even for a
    # single ticker. With those, the flat rename below does not yield a single
    # 'timestamp' column, so keep only the field level.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    df = df.reset_index().rename(columns={
        # The index is named 'Datetime' or 'Date' depending on the interval.
        'Datetime': 'timestamp',
        'Date': 'timestamp',
        'Open': 'open',
        'High': 'high',
        'Low': 'low',
        'Close': 'close',
        'Volume': 'volume',
    })
    # ensure we have a timestamp column in datetime
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    return df
41
+
42
  # Rolling Window Normalizer
43
  class RollingWindowNormalizer:
44
  def __init__(self, window=24):
 
109
  features = np.vstack((btc_features, eth_features))
110
  return features, labels
111
 
112
def get_data_predict(
    btc_ori: pd.DataFrame,
    bch_ori: pd.DataFrame,
    symbol: str = 'BCH/USDT',
    timeframe: str = '4h',
    epsilon: float = 2,
    normalized: bool = False,
    limit: int = 50
):
    """Fetch fresh BTC and ``symbol`` candles and prepare them for prediction.

    Args:
        btc_ori: Previously fetched BTC frame; only used when ``normalized``
            is true, where it is merged with the fresh download.
        bch_ori: Same as ``btc_ori`` but for ``symbol``.
        symbol: Pair to fetch alongside BTC/USDT.
        timeframe: Candle interval forwarded to ``fetch_yfinance_data``.
        epsilon: Threshold forwarded to ``remove_outliers``.
        normalized: When true, merge with the ``*_ori`` frames, de-duplicate
            on ``timestamp`` and run ``normalize`` on the merged frames.
        limit: Number of days to look back.

    Returns:
        ``(btc_frame, symbol_frame, label)`` where ``label`` is the BTC
        ``timestamp``/``close`` columns shifted up by one row, so each row
        carries the next row's values and the last row is empty.
    """
    period = f'{limit}d'  # last N days
    # fetch entirely from yfinance
    btc_data_ = fetch_yfinance_data('BTC/USDT', period, timeframe)
    bch_data_ = fetch_yfinance_data(symbol, period, timeframe)

    btc_data_ = remove_outliers(btc_data_, epsilon)
    bch_data_ = remove_outliers(bch_data_, epsilon)

    if normalized:
        # merge with ori if you still want to include historical yf data
        # NOTE(review): drop_duplicates keeps the first occurrence, i.e. the
        # *_ori row, for timestamps present in both frames - confirm that this
        # is the intended precedence over the outlier-filtered rows.
        btc_all = pd.concat([btc_ori, btc_data_]).drop_duplicates('timestamp').reset_index(drop=True)
        bch_all = pd.concat([bch_ori, bch_data_]).drop_duplicates('timestamp').reset_index(drop=True)
        btc_data_, _ = normalize(btc_all)
        bch_data_, _ = normalize(bch_all)

    # Single exit: the label is built for both the normalized and the raw
    # path, so callers never receive None in its place.
    label = btc_data_[['timestamp', 'close']].shift(-1)
    return btc_data_, bch_data_, label
139
+
140
+
141
 
142
  def predictions(model, X1, X2, name, n_steps):
143
  features_, labels_ = create_features_and_labels_with_advanced_features(X1, X2)
 
172
  model_n4h_cat = pickle.load(f)
173
 
174
def predict_and_plot(timeframe, limit, epsilon, n_steps, ma):
    """Gradio callback: fetch data, run the matching model and return the figure.

    Args:
        timeframe: Candle interval; '1d' selects ``model_n1d_cat``, anything
            else selects ``model_n4h_cat``.
        limit: Look-back in days, used both for the initial fetch and inside
            ``get_data_predict``.
        epsilon: Outlier threshold forwarded to ``get_data_predict``.
        n_steps: Forwarded to ``predictions`` and ``plot``.
        ma: Moving-average window forwarded to ``plot``.

    Returns:
        Whatever ``plot`` returns for the computed predictions.
    """
    lookback = f'{limit}d'

    # original “ori” series now also from yfinance
    btc_history = fetch_yfinance_data('BTC/USDT', lookback, timeframe)
    bch_history = fetch_yfinance_data('BCH/USDT', lookback, timeframe)

    btc_frame, bch_frame, next_close = get_data_predict(
        btc_history,
        bch_history,
        symbol='BCH/USDT',
        timeframe=timeframe,
        epsilon=epsilon,
        normalized=True,
        limit=limit,
    )

    # One pre-loaded model per supported timeframe.
    if timeframe == '1d':
        chosen_model = model_n1d_cat
    else:
        chosen_model = model_n4h_cat

    forecast = predictions(
        chosen_model, btc_frame, bch_frame, name=timeframe, n_steps=n_steps
    )
    return plot(
        forecast, label=next_close, timeframe=timeframe, ma=ma, n_steps=n_steps
    )
192
 
193
# Gradio UI. The input widgets are listed in the positional order of
# predict_and_plot(timeframe, limit, epsilon, n_steps, ma).
interface = gr.Interface(
    fn=predict_and_plot,
    inputs=[
        # timeframe: daily or 4-hour candles
        gr.Dropdown(['1d', '4h'], label='Timeframe', value='1d'),
        # limit: look-back window in days
        gr.Slider(50, 500, step=50, value=100, label='Data Limit (days)'),
        # epsilon: outlier threshold
        gr.Slider(0.1, 5.0, step=0.1, value=2.0, label='Epsilon'),
        # n_steps
        gr.Slider(50, 500, step=50, value=200, label='N_steps'),
        # ma: moving-average window
        gr.Slider(1, 20, step=1, value=5, label='Moving Average Window (ma)'),
    ],
    outputs=gr.Plot(),
    title='BTC/BCH Price Movement Prediction',
    description='Fetches everything via yfinance; uses BCH/USDT ↔️ BCH-USD under the hood.',
)

# Start the web UI only when this file is run as a script.
if __name__ == '__main__':
    interface.launch()