astacn commited on
Commit
abea06b
·
verified ·
1 Parent(s): 93d8576

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -5
app.py CHANGED
@@ -1,7 +1,53 @@
1
- from fastapi import FastAPI
 
 
 
2
 
3
- app = FastAPI()
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ffrom flask import Flask, request, jsonify
2
+ from sklearn.preprocessing import MinMaxScaler
3
+ import pandas as pd
4
+ import os
5
 
6
+ app = Flask(__name__)
7
 
8
+ # Load the prediction model
9
+ model = CustomModel()
10
+
11
+ # Define a function to prepare the data for prediction
12
+ def prepare_data(date):
13
+ # Get the historical data for the given date
14
+ data = bs.query_history_k_data_plus(
15
+ "sz.000001", # Shanghai Composite Index
16
+ "date,open,high,low,close,volume",
17
+ start_date="2005-05-30",
18
+ end_date=date,
19
+ frequency="d"
20
+ )
21
+ data_list = []
22
+ while (data.error_code == '0') & data.next():
23
+ data_list.append(data.get_row_data())
24
+ data_df = pd.DataFrame(data_list, columns=data.fields)
25
+
26
+ # Convert 'open' and 'close' columns to numeric type
27
+ data_df['open'] = pd.to_numeric(data_df['open'])
28
+ data_df['close'] = pd.to_numeric(data_df['close'])
29
+
30
+ # Filter out stocks that meet the conditions
31
+ data_df = data_df[(data_df["open"] >= 0.98 * data_df["close"].shift(1).fillna(0)) & (data_df["open"] <= 1.02 * data_df["close"].shift(1).fillna(0))]
32
+ data_df = data_df[(data_df["high"] == data_df["close"]) & (data_df["low"] == data_df["close"])] # limit-up condition
33
+ data_df = data_df[(data_df["open"]!= 0) & (data_df["close"]!= 0)] # exclude zero prices
34
+
35
+ # Scale the data using MinMaxScaler
36
+ scaler = MinMaxScaler()
37
+ data_df[['open', 'high', 'low', 'close', 'volume']] = scaler.fit_transform(data_df[['open', 'high', 'low', 'close', 'volume']])
38
+
39
+ return data_df
40
+
41
+ # Define a route to predict the top 5 stock codes
42
+ @app.route('/predict', methods=['POST'])
43
+ def predict():
44
+ date = request.json['date']
45
+ data_df = prepare_data(date)
46
+ if data_df.empty:
47
+ return jsonify({'error': 'No data available for the given date'}), 400
48
+ y_pred = model.predict(data_df)
49
+ top_5_stocks = data_df.iloc[y_pred.argsort()[-5:]]
50
+ return jsonify({'top_5_stocks': top_5_stocks['code'].tolist()})
51
+
52
+ if __name__ == '__main__':
53
+ app.run(debug=True)