Update app.py
Browse files
app.py
CHANGED
|
@@ -4,9 +4,6 @@ from vnstock import Vnstock
|
|
| 4 |
from datetime import datetime, timedelta
|
| 5 |
import pandas as pd
|
| 6 |
import numpy as np
|
| 7 |
-
from sklearn.preprocessing import MinMaxScaler
|
| 8 |
-
import torch
|
| 9 |
-
from torch_geometric.data import Data
|
| 10 |
|
| 11 |
vn = Vnstock()
|
| 12 |
app = FastAPI(title="mFund VNStock API", version="1.0")
|
|
@@ -43,21 +40,6 @@ def calc_bollinger(series, window=20, num_std=2):
|
|
| 43 |
std = series.rolling(window).std()
|
| 44 |
return sma, sma + num_std * std, sma - num_std * std
|
| 45 |
|
| 46 |
-
# ============================
|
| 47 |
-
# M么 h矛nh GNN 膽啤n gi岷
|
| 48 |
-
# ============================
|
| 49 |
-
class StockGCN(torch.nn.Module):
|
| 50 |
-
def __init__(self, num_features, hidden=16):
|
| 51 |
-
super().__init__()
|
| 52 |
-
from torch_geometric.nn import GCNConv
|
| 53 |
-
self.conv1 = GCNConv(num_features, hidden)
|
| 54 |
-
self.conv2 = GCNConv(hidden, 1)
|
| 55 |
-
|
| 56 |
-
def forward(self, data):
|
| 57 |
-
x, edge = data.x, data.edge_index
|
| 58 |
-
x = torch.relu(self.conv1(x, edge))
|
| 59 |
-
return self.conv2(x, edge)
|
| 60 |
-
|
| 61 |
# ============================
|
| 62 |
# 1) GET /stock/history
|
| 63 |
# ============================
|
|
@@ -68,6 +50,9 @@ def get_history(symbol: str, start: str, end: str):
|
|
| 68 |
return {"error": "end must be >= start"}
|
| 69 |
|
| 70 |
stock = vn.stock(symbol=symbol)
|
|
|
|
|
|
|
|
|
|
| 71 |
df = stock.quote.history(start=start, end=end, interval="1D")
|
| 72 |
|
| 73 |
if df is None or df.empty:
|
|
@@ -94,6 +79,9 @@ def get_ta(symbol: str, start: str, end: str):
|
|
| 94 |
return {"error": "end must be >= start"}
|
| 95 |
|
| 96 |
stock = vn.stock(symbol=symbol)
|
|
|
|
|
|
|
|
|
|
| 97 |
df = stock.quote.history(start=start, end=end, interval="1D")
|
| 98 |
|
| 99 |
if df is None or df.empty:
|
|
@@ -109,78 +97,6 @@ def get_ta(symbol: str, start: str, end: str):
|
|
| 109 |
except Exception as e:
|
| 110 |
return {"error": str(e)}
|
| 111 |
|
| 112 |
-
# ============================
|
| 113 |
-
# 3) GET /stock/gnn
|
| 114 |
-
# ============================
|
| 115 |
-
@app.get("/stock/gnn")
|
| 116 |
-
def get_gnn(symbol: str, days: int = 7):
|
| 117 |
-
try:
|
| 118 |
-
end = datetime.today()
|
| 119 |
-
start = end - timedelta(days=365)
|
| 120 |
-
|
| 121 |
-
stock = vn.stock(symbol=symbol)
|
| 122 |
-
df = stock.quote.history(
|
| 123 |
-
start=start.strftime("%Y-%m-%d"),
|
| 124 |
-
end=end.strftime("%Y-%m-%d"),
|
| 125 |
-
interval="1D"
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
if df is None or df.empty:
|
| 129 |
-
return {"error": "no data"}
|
| 130 |
-
|
| 131 |
-
df = df.rename(columns={"close": "Close"}).dropna()
|
| 132 |
-
scaler = MinMaxScaler()
|
| 133 |
-
df_scaled = scaler.fit_transform(df[["Close"]])
|
| 134 |
-
|
| 135 |
-
edge_index = torch.tensor(
|
| 136 |
-
[[i, i + 1] for i in range(len(df_scaled) - 1)],
|
| 137 |
-
dtype=torch.long
|
| 138 |
-
).t()
|
| 139 |
-
|
| 140 |
-
x = torch.tensor(df_scaled, dtype=torch.float)
|
| 141 |
-
data_obj = Data(x=x, edge_index=edge_index)
|
| 142 |
-
|
| 143 |
-
model = StockGCN(num_features=1)
|
| 144 |
-
model.eval()
|
| 145 |
-
|
| 146 |
-
preds_scaled = []
|
| 147 |
-
last_value = torch.tensor([[df_scaled[-1][0]]], dtype=torch.float)
|
| 148 |
-
|
| 149 |
-
for _ in range(days):
|
| 150 |
-
new_obj = Data(
|
| 151 |
-
x=torch.cat([data_obj.x, last_value]),
|
| 152 |
-
edge_index=torch.tensor(
|
| 153 |
-
[[i, i + 1] for i in range(len(data_obj.x))],
|
| 154 |
-
dtype=torch.long
|
| 155 |
-
).t()
|
| 156 |
-
)
|
| 157 |
-
out = model(new_obj)
|
| 158 |
-
last_value = out[-1].view(1, 1)
|
| 159 |
-
preds_scaled.append(last_value.item())
|
| 160 |
-
|
| 161 |
-
preds_real = scaler.inverse_transform(
|
| 162 |
-
np.array(preds_scaled).reshape(-1, 1)
|
| 163 |
-
).flatten()
|
| 164 |
-
|
| 165 |
-
dates = [
|
| 166 |
-
(end + timedelta(days=i + 1)).strftime("%Y-%m-%d")
|
| 167 |
-
for i in range(days)
|
| 168 |
-
]
|
| 169 |
-
|
| 170 |
-
return {
|
| 171 |
-
"symbol": symbol,
|
| 172 |
-
"today_close": float(df["Close"].iloc[-1]),
|
| 173 |
-
"predictions": [
|
| 174 |
-
{"date": d, "price": float(p)}
|
| 175 |
-
for d, p in zip(dates, preds_real)
|
| 176 |
-
],
|
| 177 |
-
}
|
| 178 |
-
except Exception as e:
|
| 179 |
-
return {"error": str(e)}
|
| 180 |
-
|
| 181 |
-
# ============================
|
| 182 |
-
# Root
|
| 183 |
-
# ============================
|
| 184 |
@app.get("/")
|
| 185 |
def root():
|
| 186 |
return {
|
|
@@ -188,6 +104,5 @@ def root():
|
|
| 188 |
"endpoints": [
|
| 189 |
"/stock/history?symbol=FPT&start=2023-01-01&end=2023-12-31",
|
| 190 |
"/stock/ta?symbol=HPG&start=2023-01-01&end=2023-12-31",
|
| 191 |
-
"/stock/gnn?symbol=VNM&days=7",
|
| 192 |
],
|
| 193 |
}
|
|
|
|
| 4 |
from datetime import datetime, timedelta
|
| 5 |
import pandas as pd
|
| 6 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
vn = Vnstock()
|
| 9 |
app = FastAPI(title="mFund VNStock API", version="1.0")
|
|
|
|
| 40 |
std = series.rolling(window).std()
|
| 41 |
return sma, sma + num_std * std, sma - num_std * std
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
# ============================
|
| 44 |
# 1) GET /stock/history
|
| 45 |
# ============================
|
|
|
|
| 50 |
return {"error": "end must be >= start"}
|
| 51 |
|
| 52 |
stock = vn.stock(symbol=symbol)
|
| 53 |
+
if stock is None:
|
| 54 |
+
return {"error": "invalid symbol"}
|
| 55 |
+
|
| 56 |
df = stock.quote.history(start=start, end=end, interval="1D")
|
| 57 |
|
| 58 |
if df is None or df.empty:
|
|
|
|
| 79 |
return {"error": "end must be >= start"}
|
| 80 |
|
| 81 |
stock = vn.stock(symbol=symbol)
|
| 82 |
+
if stock is None:
|
| 83 |
+
return {"error": "invalid symbol"}
|
| 84 |
+
|
| 85 |
df = stock.quote.history(start=start, end=end, interval="1D")
|
| 86 |
|
| 87 |
if df is None or df.empty:
|
|
|
|
| 97 |
except Exception as e:
|
| 98 |
return {"error": str(e)}
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
@app.get("/")
|
| 101 |
def root():
|
| 102 |
return {
|
|
|
|
| 104 |
"endpoints": [
|
| 105 |
"/stock/history?symbol=FPT&start=2023-01-01&end=2023-12-31",
|
| 106 |
"/stock/ta?symbol=HPG&start=2023-01-01&end=2023-12-31",
|
|
|
|
| 107 |
],
|
| 108 |
}
|