adpro commited on
Commit
17ac23b
verified
1 Parent(s): b980261

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -91
app.py CHANGED
@@ -4,9 +4,6 @@ from vnstock import Vnstock
4
  from datetime import datetime, timedelta
5
  import pandas as pd
6
  import numpy as np
7
- from sklearn.preprocessing import MinMaxScaler
8
- import torch
9
- from torch_geometric.data import Data
10
 
11
  vn = Vnstock()
12
  app = FastAPI(title="mFund VNStock API", version="1.0")
@@ -43,21 +40,6 @@ def calc_bollinger(series, window=20, num_std=2):
43
  std = series.rolling(window).std()
44
  return sma, sma + num_std * std, sma - num_std * std
45
 
46
- # ============================
47
- # M么 h矛nh GNN 膽啤n gi岷
48
- # ============================
49
- class StockGCN(torch.nn.Module):
50
- def __init__(self, num_features, hidden=16):
51
- super().__init__()
52
- from torch_geometric.nn import GCNConv
53
- self.conv1 = GCNConv(num_features, hidden)
54
- self.conv2 = GCNConv(hidden, 1)
55
-
56
- def forward(self, data):
57
- x, edge = data.x, data.edge_index
58
- x = torch.relu(self.conv1(x, edge))
59
- return self.conv2(x, edge)
60
-
61
  # ============================
62
  # 1) GET /stock/history
63
  # ============================
@@ -68,6 +50,9 @@ def get_history(symbol: str, start: str, end: str):
68
  return {"error": "end must be >= start"}
69
 
70
  stock = vn.stock(symbol=symbol)
 
 
 
71
  df = stock.quote.history(start=start, end=end, interval="1D")
72
 
73
  if df is None or df.empty:
@@ -94,6 +79,9 @@ def get_ta(symbol: str, start: str, end: str):
94
  return {"error": "end must be >= start"}
95
 
96
  stock = vn.stock(symbol=symbol)
 
 
 
97
  df = stock.quote.history(start=start, end=end, interval="1D")
98
 
99
  if df is None or df.empty:
@@ -109,78 +97,6 @@ def get_ta(symbol: str, start: str, end: str):
109
  except Exception as e:
110
  return {"error": str(e)}
111
 
112
- # ============================
113
- # 3) GET /stock/gnn
114
- # ============================
115
- @app.get("/stock/gnn")
116
- def get_gnn(symbol: str, days: int = 7):
117
- try:
118
- end = datetime.today()
119
- start = end - timedelta(days=365)
120
-
121
- stock = vn.stock(symbol=symbol)
122
- df = stock.quote.history(
123
- start=start.strftime("%Y-%m-%d"),
124
- end=end.strftime("%Y-%m-%d"),
125
- interval="1D"
126
- )
127
-
128
- if df is None or df.empty:
129
- return {"error": "no data"}
130
-
131
- df = df.rename(columns={"close": "Close"}).dropna()
132
- scaler = MinMaxScaler()
133
- df_scaled = scaler.fit_transform(df[["Close"]])
134
-
135
- edge_index = torch.tensor(
136
- [[i, i + 1] for i in range(len(df_scaled) - 1)],
137
- dtype=torch.long
138
- ).t()
139
-
140
- x = torch.tensor(df_scaled, dtype=torch.float)
141
- data_obj = Data(x=x, edge_index=edge_index)
142
-
143
- model = StockGCN(num_features=1)
144
- model.eval()
145
-
146
- preds_scaled = []
147
- last_value = torch.tensor([[df_scaled[-1][0]]], dtype=torch.float)
148
-
149
- for _ in range(days):
150
- new_obj = Data(
151
- x=torch.cat([data_obj.x, last_value]),
152
- edge_index=torch.tensor(
153
- [[i, i + 1] for i in range(len(data_obj.x))],
154
- dtype=torch.long
155
- ).t()
156
- )
157
- out = model(new_obj)
158
- last_value = out[-1].view(1, 1)
159
- preds_scaled.append(last_value.item())
160
-
161
- preds_real = scaler.inverse_transform(
162
- np.array(preds_scaled).reshape(-1, 1)
163
- ).flatten()
164
-
165
- dates = [
166
- (end + timedelta(days=i + 1)).strftime("%Y-%m-%d")
167
- for i in range(days)
168
- ]
169
-
170
- return {
171
- "symbol": symbol,
172
- "today_close": float(df["Close"].iloc[-1]),
173
- "predictions": [
174
- {"date": d, "price": float(p)}
175
- for d, p in zip(dates, preds_real)
176
- ],
177
- }
178
- except Exception as e:
179
- return {"error": str(e)}
180
-
181
- # ============================
182
- # Root
183
- # ============================
184
  @app.get("/")
185
  def root():
186
  return {
@@ -188,6 +104,5 @@ def root():
188
  "endpoints": [
189
  "/stock/history?symbol=FPT&start=2023-01-01&end=2023-12-31",
190
  "/stock/ta?symbol=HPG&start=2023-01-01&end=2023-12-31",
191
- "/stock/gnn?symbol=VNM&days=7",
192
  ],
193
  }
 
4
  from datetime import datetime, timedelta
5
  import pandas as pd
6
  import numpy as np
 
 
 
7
 
8
  vn = Vnstock()
9
  app = FastAPI(title="mFund VNStock API", version="1.0")
 
40
  std = series.rolling(window).std()
41
  return sma, sma + num_std * std, sma - num_std * std
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # ============================
44
  # 1) GET /stock/history
45
  # ============================
 
50
  return {"error": "end must be >= start"}
51
 
52
  stock = vn.stock(symbol=symbol)
53
+ if stock is None:
54
+ return {"error": "invalid symbol"}
55
+
56
  df = stock.quote.history(start=start, end=end, interval="1D")
57
 
58
  if df is None or df.empty:
 
79
  return {"error": "end must be >= start"}
80
 
81
  stock = vn.stock(symbol=symbol)
82
+ if stock is None:
83
+ return {"error": "invalid symbol"}
84
+
85
  df = stock.quote.history(start=start, end=end, interval="1D")
86
 
87
  if df is None or df.empty:
 
97
  except Exception as e:
98
  return {"error": str(e)}
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @app.get("/")
101
  def root():
102
  return {
 
104
  "endpoints": [
105
  "/stock/history?symbol=FPT&start=2023-01-01&end=2023-12-31",
106
  "/stock/ta?symbol=HPG&start=2023-01-01&end=2023-12-31",
 
107
  ],
108
  }