File size: 3,936 Bytes
2fc58fd
abea06b
c83db7b
 
c037f49
c83db7b
c037f49
c83db7b
 
 
c037f49
abea06b
c83db7b
c037f49
a2f53f8
29b6798
c037f49
c83db7b
 
c037f49
 
 
 
c83db7b
 
c037f49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c83db7b
 
 
 
 
c037f49
 
 
c83db7b
 
 
 
 
c037f49
c83db7b
c037f49
c83db7b
c037f49
c83db7b
 
c037f49
abea06b
c037f49
 
 
 
abea06b
 
 
c037f49
 
abea06b
 
 
 
c037f49
 
 
 
 
 
 
 
 
 
 
 
 
abea06b
c037f49
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import atexit
import datetime
import os

from flask import Flask, request, jsonify
import pandas as pd
import numpy as np
import baostock as bs
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
from neuralprophet import NeuralProphet, set_log_level
import torch
import torch.nn as nn
from torch.optim import Adam

# The Flask application object; the /predict route below is registered on it.
app = Flask(__name__)

# Only let NeuralProphet log messages at ERROR level or above through.
set_log_level("ERROR")

# Open the Baostock session once, at import time; the bs.query_* call made by
# get_historical_data presumably depends on it.
lg = bs.login()
if lg.error_code != '0':
    # Without a session nothing in this module can fetch data, so fail the import.
    raise ConnectionError(
        f"Baostock login failed. Error code: {lg.error_code}, Error message: {lg.error_msg}"
    )

# Collect historical data
def get_historical_data(start_date, end_date, code="sz.000001"):
    """Fetch daily K-line bars for one security from Baostock.

    Args:
        start_date: First date to fetch, as a string such as "2005-05-30".
        end_date: Last date to fetch, in the same format.
        code: Baostock security code. The default "sz.000001" is Ping An Bank
            on the Shenzhen exchange -- NOT the Shanghai Composite Index, whose
            Baostock code is "sh.000001".

    Returns:
        DataFrame with columns date, open, high, low, close, volume. The price
        and volume columns are converted to numbers; rows where any column
        could not be parsed (NaN) are dropped.

    Raises:
        ValueError: if Baostock reports an error, either for the query itself
            or while paging through its results.
    """
    data = bs.query_history_k_data_plus(
        code,
        "date,open,high,low,close,volume",
        start_date=start_date,
        end_date=end_date,
        frequency="d",
    )
    if data.error_code != '0':
        raise ValueError(f"Error in fetching data: {data.error_msg}")

    data_list = []
    # Check error_code on every step (Baostock's documented iteration idiom):
    # next() can presumably hit an error mid-stream, which would otherwise end
    # the loop early and silently return truncated data.
    while data.error_code == '0' and data.next():
        data_list.append(data.get_row_data())
    if data.error_code != '0':
        raise ValueError(f"Error in fetching data: {data.error_msg}")
    data_df = pd.DataFrame(data_list, columns=data.fields)

    # Rows arrive as strings; coerce to numbers, turning unparseable cells into NaN.
    numeric_cols = ['open', 'close', 'high', 'low', 'volume']
    data_df[numeric_cols] = data_df[numeric_cols].apply(pd.to_numeric, errors='coerce')
    return data_df.dropna()

# Keep only flat (one-price) bars that opened close to the previous close.
def filter_stocks(data_df):
    """Return the rows of ``data_df`` that pass two successive filters.

    1. The open lies within [0.98, 1.02] x the previous row's close. The first
       row is compared against a previous close of 0, so it can never pass
       together with the non-zero-open check below.
    2. high, low and close are all equal, and neither open nor close is zero.

    Row order and index labels of the surviving rows are preserved.
    """
    prev_close = data_df["close"].shift(1).fillna(0)
    opens_near_prev = data_df["open"].between(0.98 * prev_close, 1.02 * prev_close)
    candidates = data_df[opens_near_prev]

    close = candidates["close"]
    flat_bar = candidates["high"].eq(close) & candidates["low"].eq(close)
    non_zero = candidates["open"].ne(0) & close.ne(0)
    return candidates[flat_bar & non_zero]

# Prepare the training and validation data
data_df = get_historical_data("2005-05-30", "2024-01-31")
filtered_df = filter_stocks(data_df)

if filtered_df.empty:
    raise ValueError("Filtered dataset is empty. Please adjust the filtering conditions.")

# Split chronologically: the rows form a time series (assumed to be in date
# order as returned by Baostock), so a shuffled split would put days that come
# after the validation days into the training set (look-ahead leakage).
# With shuffle=False the last 20% of rows become the validation set.
# NOTE(review): train_data / val_data are not used anywhere in this file.
train_data, val_data = train_test_split(filtered_df, test_size=0.2, shuffle=False)

# Define custom model
class CustomModel(nn.Module):
    """Wrapper around a single NeuralProphet forecaster.

    NeuralProphet expects a frame with a datetime column ``ds`` and a target
    column ``y``. The frames built in this module carry ``date`` and ``close``
    instead, so they are mapped onto ``ds``/``y`` before fitting or predicting.

    NOTE(review): the class subclasses ``nn.Module``, but NeuralProphet itself
    is presumably not an ``nn.Module``, so it is not registered as a submodule
    and ``parameters()``/``state_dict()`` on this wrapper will not expose its
    weights.
    """

    # Renames applied to frames that do not already have ds/y columns.
    _COLUMN_MAP = {"date": "ds", "close": "y"}

    def __init__(self):
        super().__init__()
        self.neural_prophet = NeuralProphet(
            n_forecasts=1,
            n_lags=30,
            n_changepoints=10,
            changepoints_range=0.8,
            learning_rate=1e-3,
            optimizer=Adam,
        )
        # NeuralProphet has to be fitted before it can forecast; remember whether
        # that has happened so fit() is only triggered once by predict().
        self._is_fitted = False

    @classmethod
    def _to_prophet_frame(cls, df):
        """Return a new frame with only the ``ds``/``y`` columns NeuralProphet needs.

        Frames that already contain ``ds`` and ``y`` are used as they are;
        otherwise ``date`` is renamed to ``ds`` and ``close`` to ``y``. ``ds``
        is parsed to datetime and the index is reset. The input is not modified.

        Raises:
            KeyError: if the required columns are missing after renaming.
        """
        if not {"ds", "y"}.issubset(df.columns):
            df = df.rename(columns=cls._COLUMN_MAP)
        prophet_df = df[["ds", "y"]].copy()
        prophet_df["ds"] = pd.to_datetime(prophet_df["ds"])
        return prophet_df.reset_index(drop=True)

    def fit(self, df, **fit_kwargs):
        """Train the forecaster on ``df`` (``ds``/``y`` or ``date``/``close`` columns).

        Extra keyword arguments are forwarded to ``NeuralProphet.fit``; its
        return value is passed back to the caller.
        """
        result = self.neural_prophet.fit(self._to_prophet_frame(df), **fit_kwargs)
        self._is_fitted = True
        return result

    def predict(self, df):
        """Forecast one step past the end of ``df`` and return the ``yhat1`` values.

        NeuralProphet cannot forecast before it has been fitted, so if no fit
        has happened yet the forecaster is first trained on ``df`` itself.
        That first call can therefore be slow.

        Returns:
            numpy array of the ``yhat1`` column for every row of the forecast
            frame. NOTE(review): rows of the future frame that are pure history
            may carry NaN here -- confirm against the NeuralProphet version in use.
        """
        prophet_df = self._to_prophet_frame(df)
        if not getattr(self, "_is_fitted", False):
            self.fit(prophet_df)
        future = self.neural_prophet.make_future_dataframe(prophet_df, periods=1)
        forecast = self.neural_prophet.predict(future)
        return forecast['yhat1'].values

# One forecaster instance for the whole process; the /predict route uses it for every request.
model = CustomModel()

# Prepare data for prediction
def prepare_data(date, start_date="2005-05-30"):
    """Fetch, filter and min-max scale the price history up to ``date``.

    Args:
        date: Last date (inclusive) to fetch, e.g. "2024-01-31".
        start_date: First date to fetch; defaults to "2005-05-30".

    Returns:
        The rows that pass ``filter_stocks``, with open/high/low/close/volume
        each rescaled by ``MinMaxScaler`` (default range [0, 1]). Returns an
        empty DataFrame when no row passes the filter.
    """
    filtered_df = filter_stocks(get_historical_data(start_date, date))
    if filtered_df.empty:
        return pd.DataFrame()  # Return empty DataFrame if no data matches the filter

    # filter_stocks returns a boolean-indexed selection; take an explicit copy so
    # the column assignment below writes to an independent frame instead of a
    # possible view (avoids pandas' SettingWithCopyWarning).
    filtered_df = filtered_df.copy()

    # NOTE(review): the scaler is re-fitted on every call, so the scaling depends
    # on the requested date range and is not shared with any training-time scaling.
    scaled_cols = ['open', 'high', 'low', 'close', 'volume']
    filtered_df[scaled_cols] = MinMaxScaler().fit_transform(filtered_df[scaled_cols])
    return filtered_df

# Route: forecast from the price history up to the requested date.
@app.route('/predict', methods=['POST'])
def predict():
    """Handle ``POST /predict`` with a JSON body ``{"date": "YYYY-MM-DD"}``.

    Responses:
        200: ``{"top_5_stocks": [...]}`` -- the first five values of the
             model's forecast array; NaN values are sent as ``null`` so the
             body is valid JSON.
        400: the body is not a JSON object with an ISO-format ``date`` string,
             or no data passes filtering for that date.
        500: any other failure while fetching data or forecasting.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it can be reported as a client error (400) rather than a 500.
    payload = request.get_json(silent=True)
    date = payload.get('date') if isinstance(payload, dict) else None
    if not isinstance(date, str):
        return jsonify({'error': "Request body must be a JSON object with a 'date' string"}), 400
    try:
        datetime.date.fromisoformat(date)
    except ValueError:
        return jsonify({'error': "'date' must be in YYYY-MM-DD format"}), 400

    try:
        data_df = prepare_data(date)
        if data_df.empty:
            return jsonify({'error': 'No data available for the given date'}), 400

        y_pred = model.predict(data_df)
        # NOTE(review): despite the key name these are the first five forecast
        # values -- not stock codes and not the five largest. Confirm the intent.
        top_5_stocks = [None if pd.isna(v) else v for v in y_pred[:5].tolist()]
        return jsonify({'top_5_stocks': top_5_stocks})
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

# Close the Baostock session when the interpreter exits. Calling bs.logout() at
# the bottom of the module would close it as soon as the module finished
# importing -- e.g. when loaded by a WSGI server -- before any request is served.
atexit.register(bs.logout)

# Run the Flask app
if __name__ == '__main__':
    # The Werkzeug debugger allows executing arbitrary code from the browser, so
    # debug mode is opt-in via the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(debug=debug_enabled)