Spaces:
Paused
Paused
File size: 3,936 Bytes
2fc58fd abea06b c83db7b c037f49 c83db7b c037f49 c83db7b c037f49 abea06b c83db7b c037f49 a2f53f8 29b6798 c037f49 c83db7b c037f49 c83db7b c037f49 c83db7b c037f49 c83db7b c037f49 c83db7b c037f49 c83db7b c037f49 c83db7b c037f49 abea06b c037f49 abea06b c037f49 abea06b c037f49 abea06b c037f49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from flask import Flask, request, jsonify
import pandas as pd
import numpy as np
import baostock as bs
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
from neuralprophet import NeuralProphet, set_log_level
import torch
import torch.nn as nn
from torch.optim import Adam
import os
# Initialize Flask app
app = Flask(__name__)
# Set log level to suppress unnecessary warnings
# (NeuralProphet's set_log_level: only messages at ERROR level or above are shown).
set_log_level("ERROR")
# Baostock API login
# NOTE: this runs at import time, so every process that imports this module
# opens its own Baostock session. The rest of this file treats error_code '0'
# as success; any other code aborts the import before the routes below are
# registered, so the app cannot start without a session.
lg = bs.login()
if lg.error_code != '0':
    raise ConnectionError(f"Baostock login failed. Error code: {lg.error_code}, Error message: {lg.error_msg}")
# Collect historical data
def get_historical_data(start_date, end_date, code="sz.000001"):
    """Download daily bars for one security from Baostock.

    Args:
        start_date: First date to fetch, as a "YYYY-MM-DD" string.
        end_date: Last date to fetch, as a "YYYY-MM-DD" string.
        code: Baostock security code. The default "sz.000001" is the code this
            function has always queried. Note that it is a Shenzhen-listed
            security (Ping An Bank), not the Shanghai Composite Index, whose
            Baostock code is "sh.000001".

    Returns:
        DataFrame with the fields requested below
        ("date,open,high,low,close,volume"). The price and volume columns are
        converted to numbers, and any row containing a missing or unparseable
        value is dropped.

    Raises:
        ValueError: if Baostock reports an error for the query, either up
            front or while paging through the results.
    """
    result = bs.query_history_k_data_plus(
        code,
        "date,open,high,low,close,volume",
        start_date=start_date,
        end_date=end_date,
        frequency="d",
    )
    if result.error_code != '0':
        raise ValueError(f"Error in fetching data: {result.error_msg}")

    rows = []
    # Check the error code on every step (as Baostock's own examples do) so a
    # failure while paging raises instead of silently truncating the history.
    while result.error_code == '0' and result.next():
        rows.append(result.get_row_data())
    if result.error_code != '0':
        raise ValueError(f"Error in fetching data: {result.error_msg}")

    data_df = pd.DataFrame(rows, columns=result.fields)
    numeric_cols = ['open', 'close', 'high', 'low', 'volume']
    # Convert to numbers; anything unparseable becomes NaN and is dropped below.
    data_df[numeric_cols] = data_df[numeric_cols].apply(pd.to_numeric, errors='coerce')
    return data_df.dropna()
# Filter stocks based on conditions
def filter_stocks(data_df):
    """Keep only the rows that pass both the opening-gap and flat-bar checks.

    A row survives when all of the following hold:

    * its open lies within [0.98, 1.02] times the previous row's close. The
      first row compares against 0, so it can only pass with an open of 0,
      which the non-zero check rejects;
    * high, low and close are all equal;
    * neither open nor close is 0.

    Args:
        data_df: frame with numeric ``open``, ``high``, ``low`` and ``close``
            columns. The "previous close" is taken from the preceding row in
            the frame's current order.

    Returns:
        The matching rows, with the original index and columns preserved.
    """
    prev_close = data_df["close"].shift(1).fillna(0)
    open_px = data_df["open"]
    close_px = data_df["close"]

    gap_ok = (open_px >= 0.98 * prev_close) & (open_px <= 1.02 * prev_close)
    flat_bar = (data_df["high"] == close_px) & (data_df["low"] == close_px)
    nonzero = (open_px != 0) & (close_px != 0)

    return data_df[gap_ok & flat_bar & nonzero]
# Prepare the training and validation data
# NOTE: this runs at import time and downloads the full history from Baostock.
data_df = get_historical_data("2005-05-30", "2024-01-31")
filtered_df = filter_stocks(data_df)
if filtered_df.empty:
    raise ValueError("Filtered dataset is empty. Please adjust the filtering conditions.")
# Chronological split: with shuffle=False the last 20% of rows become the
# validation set. This assumes Baostock returns rows in ascending date order
# (TODO confirm). A shuffled split would leak future bars into the training
# data of a time-series forecaster.
# NOTE(review): train_data / val_data are not used by the code visible here.
train_data, val_data = train_test_split(filtered_df, test_size=0.2, shuffle=False)
# Define custom model
class CustomModel(nn.Module):
    """Wrapper around a NeuralProphet one-step-ahead forecaster.

    NOTE(review): this subclasses ``nn.Module`` but registers no parameters and
    defines no ``forward``. All modelling is done by the wrapped
    ``NeuralProphet`` instance in ``self.neural_prophet``.
    """

    def __init__(self):
        super().__init__()
        self.neural_prophet = NeuralProphet(
            n_forecasts=1,
            n_lags=30,
            n_changepoints=10,
            changepoints_range=0.8,
            learning_rate=1e-3,
            optimizer=Adam,
        )
        # A NeuralProphet instance can only be fitted once, so track it here.
        self._is_fitted = False

    @staticmethod
    def _to_prophet_frame(df):
        """Return *df* as the two-column ``ds``/``y`` frame NeuralProphet expects.

        A frame that already has ``ds`` and ``y`` columns is used as-is.
        Otherwise ``date`` becomes ``ds`` and ``close`` becomes ``y``. Other
        columns are dropped, because no extra regressors are configured.
        Rows with a missing/unparseable date or value are removed, and the
        result is sorted by date. The input frame is not modified.
        """
        if {"ds", "y"}.issubset(df.columns):
            frame = df[["ds", "y"]].copy()
        else:
            frame = df[["date", "close"]].rename(columns={"date": "ds", "close": "y"}).copy()
        frame["ds"] = pd.to_datetime(frame["ds"], errors="coerce")
        return frame.dropna().sort_values("ds").reset_index(drop=True)

    def fit(self, df, **fit_kwargs):
        """Fit the wrapped NeuralProphet model on *df* and return its metrics.

        *df* may be in the raw Baostock layout (``date``/``close``) or already
        in ``ds``/``y`` form. Extra keyword arguments go to
        ``NeuralProphet.fit``.
        """
        metrics = self.neural_prophet.fit(self._to_prophet_frame(df), **fit_kwargs)
        self._is_fitted = True
        return metrics

    def predict(self, df):
        """Forecast one step past the end of *df* and return the ``yhat1`` values.

        If the model has not been fitted yet, it is fitted once on *df* before
        forecasting. Without that step NeuralProphet would be asked to predict
        with an untrained model.
        NOTE(review): this lazy fit is not guarded against concurrent first
        calls. NeuralProphet also needs more history rows than ``n_lags`` (30)
        to fit and forecast.
        """
        prophet_df = self._to_prophet_frame(df)
        if not self._is_fitted:
            self.fit(prophet_df)
        future = self.neural_prophet.make_future_dataframe(prophet_df, periods=1)
        forecast = self.neural_prophet.predict(future)
        return forecast['yhat1'].values
# Instantiate model
# A single module-level instance, used by the /predict route for every request.
model = CustomModel()
# Prepare data for prediction
def prepare_data(date, start_date="2005-05-30"):
    """Fetch, filter and min-max scale the history used for a prediction.

    Args:
        date: Last date of history to fetch, as a "YYYY-MM-DD" string.
        start_date: First date of history to fetch. Defaults to the start date
            this function has always used.

    Returns:
        The rows kept by ``filter_stocks``, with ``open``, ``high``, ``low``,
        ``close`` and ``volume`` each scaled to [0, 1] over the returned rows.
        Returns an empty DataFrame when no row passes the filter.
        NOTE(review): because ``close`` is scaled here and the scaler is not
        returned, downstream forecasts come out in scaled units, not prices.
    """
    history = get_historical_data(start_date, date)
    filtered_df = filter_stocks(history)
    if filtered_df.empty:
        return pd.DataFrame()  # Return empty DataFrame if no data matches the filter
    # filter_stocks returns a boolean-indexed slice. Copy it so the in-place
    # column assignment below does not raise SettingWithCopyWarning.
    filtered_df = filtered_df.copy()
    numeric_cols = ['open', 'high', 'low', 'close', 'volume']
    scaler = MinMaxScaler()
    filtered_df[numeric_cols] = scaler.fit_transform(filtered_df[numeric_cols])
    return filtered_df
# Define a route that returns up to five forecast values
@app.route('/predict', methods=['POST'])
def predict():
    """Forecast from the history ending at the posted date.

    Expects a JSON body ``{"date": "YYYY-MM-DD"}`` and responds with
    ``{"top_5_stocks": [...]}``.
    NOTE(review): despite the key name, the list holds the first (up to) five
    finite values returned by ``model.predict`` for a single security. It is
    not a ranking of stock codes.

    Status codes: 400 for a missing/invalid date or when no rows pass the
    filter; 500 (with the exception text) for unexpected failures, which are
    also logged.
    """
    from datetime import datetime

    payload = request.get_json(silent=True)
    date = payload.get('date') if isinstance(payload, dict) else None
    if not isinstance(date, str):
        return jsonify({'error': "Request body must be JSON with a 'date' string"}), 400
    try:
        # Baostock expects "YYYY-MM-DD"; reject anything else as a client error.
        datetime.strptime(date, '%Y-%m-%d')
    except ValueError:
        return jsonify({'error': "'date' must be formatted as YYYY-MM-DD"}), 400

    try:
        data_df = prepare_data(date)
        if data_df.empty:
            return jsonify({'error': 'No data available for the given date'}), 400
        y_pred = np.asarray(model.predict(data_df), dtype=float)
    except Exception as e:
        # Top-level request boundary: log the traceback, report the message.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

    # Drop NaN/inf entries: they are not forecasts, and NaN would serialize
    # as the invalid JSON token `NaN`.
    top_5_stocks = y_pred[np.isfinite(y_pred)][:5]
    return jsonify({'top_5_stocks': top_5_stocks.tolist()})
# Run the Flask development server when executed as a script.
if __name__ == '__main__':
    try:
        # debug is deliberately not hard-coded to True. Werkzeug's interactive
        # debugger allows arbitrary code execution, and the reloader re-imports
        # this module (repeating the Baostock login and data download).
        # Flask's app.run() still enables debug mode when FLASK_DEBUG is set.
        app.run()
    finally:
        # Logout from Baostock API once the server stops. This sits inside the
        # __main__ guard so that importing the module (e.g. from a WSGI server)
        # does not end the session before any request is served.
        bs.logout()