Spaces:
Paused
Paused
| from flask import Flask, request, jsonify | |
| import pandas as pd | |
| import numpy as np | |
| import baostock as bs | |
| from sklearn.preprocessing import MinMaxScaler | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import mean_absolute_error | |
| from neuralprophet import NeuralProphet, set_log_level | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim import Adam | |
| import os | |
# Silence NeuralProphet's informational logging; only errors are reported.
set_log_level("ERROR")

# Flask application object, started by app.run() at the bottom of the module.
app = Flask(__name__)

# Open the Baostock session at import time. Data queries made later in this
# module go through the `bs` client, so fail fast if the login is rejected.
lg = bs.login()
if lg.error_code != '0':
    raise ConnectionError(
        "Baostock login failed. "
        f"Error code: {lg.error_code}, Error message: {lg.error_msg}"
    )
# Collect historical data
def get_historical_data(start_date, end_date, code="sz.000001"):
    """Download daily OHLCV bars for one Baostock security.

    Args:
        start_date: First day of the range, e.g. ``"2005-05-30"``.
        end_date: Last day of the range, same format as *start_date*.
        code: Baostock security code. Defaults to ``"sz.000001"`` so existing
            callers are unchanged. NOTE(review): this code was previously
            labelled "Shanghai Composite Index", but Baostock's code for that
            index is ``"sh.000001"``; ``"sz.000001"`` is a Shenzhen-listed
            security. Confirm which one was intended.

    Returns:
        DataFrame with the ``date`` column as returned by Baostock and numeric
        ``open``, ``high``, ``low``, ``close`` and ``volume`` columns. Rows with
        any missing value (including values that failed numeric conversion)
        are dropped.

    Raises:
        ValueError: If Baostock reports a non-zero error code for the query.
    """
    numeric_cols = ['open', 'close', 'high', 'low', 'volume']
    result = bs.query_history_k_data_plus(
        code,
        "date,open,high,low,close,volume",
        start_date=start_date,
        end_date=end_date,
        frequency="d",
    )
    if result.error_code != '0':
        raise ValueError(f"Error in fetching data: {result.error_msg}")

    # The result set is consumed row by row: next() advances it and
    # get_row_data() returns the current row.
    rows = []
    while result.next():
        rows.append(result.get_row_data())
    data_df = pd.DataFrame(rows, columns=result.fields)
    # Values are presumably delivered as text; anything unparseable becomes
    # NaN and is removed by dropna() below.
    data_df[numeric_cols] = data_df[numeric_cols].apply(pd.to_numeric, errors='coerce')
    return data_df.dropna()
# Filter stocks based on conditions
def filter_stocks(data_df, lower=0.98, upper=1.02):
    """Keep bars that open near the previous close and trade at a single price.

    A row is kept only when all of the following hold:

    * ``open`` lies within ``[lower, upper]`` times the previous row's
      ``close`` (the first row has no previous close and is always dropped);
    * ``high``, ``low`` and ``close`` are all equal;
    * ``open`` and ``close`` are non-zero.

    Args:
        data_df: Frame with ``open``, ``high``, ``low`` and ``close`` columns.
            Rows are expected in chronological order, because the previous
            close is taken with ``shift(1)``.
        lower: Lower bound of the open / previous-close ratio (default 0.98).
        upper: Upper bound of the open / previous-close ratio (default 1.02).

    Returns:
        A new, independent DataFrame (safe to modify) holding the matching
        rows with their original index.
    """
    # The previous close comes from the unfiltered frame, so each bar is
    # compared against its immediate predecessor. A NaN previous close (first
    # row) makes both comparisons False, so that row is dropped.
    prev_close = data_df["close"].shift(1)
    in_band = (data_df["open"] >= lower * prev_close) & (data_df["open"] <= upper * prev_close)
    one_price = (data_df["high"] == data_df["close"]) & (data_df["low"] == data_df["close"])
    non_zero = (data_df["open"] != 0) & (data_df["close"] != 0)
    # .copy() detaches the result from data_df so callers can assign columns
    # without pandas' SettingWithCopyWarning.
    return data_df[in_band & one_price & non_zero].copy()
# Prepare the training and validation data.
# Fetch the full history once at import time and apply the bar filter.
data_df = get_historical_data("2005-05-30", "2024-01-31")
filtered_df = filter_stocks(data_df)
if filtered_df.empty:
    raise ValueError("Filtered dataset is empty. Please adjust the filtering conditions.")
# Chronological split: assuming rows are in date order, the earliest 80% train
# and the latest 20% validate. Shuffling a time series would leak later
# observations into the training set.
# NOTE(review): train_data / val_data are not used elsewhere in the visible
# code; the model is never fitted on them. Confirm whether they are needed.
train_data, val_data = train_test_split(filtered_df, test_size=0.2, shuffle=False)
# Define custom model
class CustomModel(nn.Module):
    """Wrapper around a one-step-ahead NeuralProphet forecaster.

    NeuralProphet only accepts frames with a ``ds`` (timestamp) and ``y``
    (target) column and must be fitted before it can forecast. This wrapper
    converts the ``date``/``close`` frames used in this module to that layout.
    It fits the model lazily on the first frame passed to :meth:`predict`,
    unless :meth:`fit` was called first.
    """

    def __init__(self):
        super().__init__()
        self.neural_prophet = NeuralProphet(
            n_forecasts=1,
            n_lags=30,
            n_changepoints=10,
            changepoints_range=0.8,
            learning_rate=1e-3,
            optimizer=Adam,
        )

    @staticmethod
    def _to_prophet_frame(df):
        """Return a fresh two-column ``ds``/``y`` frame built from *df*.

        Frames that already have ``ds`` and ``y`` keep those columns.
        Otherwise ``date`` becomes ``ds`` and ``close`` becomes ``y``. ``ds``
        is parsed to datetime and the index is reset to 0..n-1.
        """
        if {"ds", "y"}.issubset(df.columns):
            out = df[["ds", "y"]].copy()
        else:
            out = df[["date", "close"]].rename(columns={"date": "ds", "close": "y"})
        out["ds"] = pd.to_datetime(out["ds"])
        return out.reset_index(drop=True)

    def fit(self, df):
        """Fit the wrapped NeuralProphet model on *df*; returns its fit metrics."""
        return self.neural_prophet.fit(self._to_prophet_frame(df))

    def predict(self, df):
        """Forecast one step past the end of *df*.

        Returns:
            NumPy array holding the ``yhat1`` column of NeuralProphet's
            forecast frame.
        """
        prophet_df = self._to_prophet_frame(df)
        # NeuralProphet cannot build a future frame or predict before fit().
        # NOTE(review): the shared model is fitted once, on the first frame it
        # sees; later calls reuse that fit.
        if not getattr(self.neural_prophet, "fitted", False):
            self.neural_prophet.fit(prophet_df)
        future = self.neural_prophet.make_future_dataframe(prophet_df, periods=1)
        forecast = self.neural_prophet.predict(future)
        return forecast['yhat1'].values
# Single module-level model instance; the predict() request handler calls
# model.predict on it.
model: CustomModel = CustomModel()
# Prepare data for prediction
def prepare_data(date):
    """Fetch, filter and min-max scale the price history ending at *date*.

    Args:
        date: End date passed to ``get_historical_data`` as ``end_date``,
            e.g. ``"2024-01-31"``.

    Returns:
        The rows kept by ``filter_stocks``, with ``open``, ``high``, ``low``,
        ``close`` and ``volume`` each rescaled by ``MinMaxScaler`` to [0, 1].
        Returns an empty DataFrame when no row passes the filter.
    """
    price_cols = ['open', 'high', 'low', 'close', 'volume']
    data_df = get_historical_data("2005-05-30", date)
    # Copy so the column assignment below writes to an independent frame
    # rather than a slice of data_df (avoids SettingWithCopyWarning).
    filtered_df = filter_stocks(data_df).copy()
    if filtered_df.empty:
        return pd.DataFrame()  # Return empty DataFrame if no data matches the filter
    # NOTE(review): a new scaler is fit on every call and discarded, so
    # outputs cannot be mapped back to price units. Confirm this is intended.
    scaler = MinMaxScaler()
    filtered_df[price_cols] = scaler.fit_transform(filtered_df[price_cols])
    return filtered_df
# Define a route to predict the top 5 stock codes
@app.route('/predict', methods=['POST'])
def predict():
    """Handle ``POST /predict`` with a JSON body such as ``{"date": "2024-01-31"}``.

    Responses:
        200: ``{"top_5_stocks": [...]}`` with the first five model outputs.
        400: ``{"error": ...}`` when the body is not a JSON object with a
            ``date`` field, or when no data passes the filter.
        500: ``{"error": ...}`` for any other failure (the traceback is logged).
    """
    # silent=True returns None instead of raising on a missing/invalid JSON
    # body, so the client gets a 400 rather than a generic 500.
    payload = request.get_json(silent=True)
    date = payload.get('date') if isinstance(payload, dict) else None
    if not date:
        return jsonify({'error': "Request body must be JSON with a 'date' field"}), 400
    try:
        data_df = prepare_data(date)
        if data_df.empty:
            return jsonify({'error': 'No data available for the given date'}), 400
        y_pred = model.predict(data_df)
        # NOTE(review): y_pred comes from one forecast series for a single
        # security, so these are the first five forecast values, not stock
        # codes. Confirm the intended output.
        top_5_stocks = y_pred[:5]
        return jsonify({'top_5_stocks': top_5_stocks.tolist()})
    except Exception as e:
        # Top-level request boundary: log the traceback, return a JSON error.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500
import atexit

# Log out of Baostock when the process exits. Registering the logout instead
# of calling it inline keeps the session open while requests are served. This
# also holds when the module is imported by a server rather than run directly;
# an inline top-level logout would close the session right after import.
atexit.register(bs.logout)

# Run the Flask app
if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger,
    # which allows code execution from the browser. Never expose it publicly.
    app.run(debug=True)