Spaces:
Running
Running
FetchForecast: now able to direct the input df to from_data endpoint
Browse files- src/fetch_forecast.py +53 -4
src/fetch_forecast.py
CHANGED
|
@@ -7,7 +7,7 @@ import requests
|
|
| 7 |
|
| 8 |
|
| 9 |
class FetchForecast:
|
| 10 |
-
def __init__(self, ticker: str, debug=False) -> None:
|
| 11 |
if debug:
|
| 12 |
self.logger_level = logging.DEBUG
|
| 13 |
else:
|
|
@@ -17,9 +17,16 @@ class FetchForecast:
|
|
| 17 |
|
| 18 |
# args
|
| 19 |
self.ticker = ticker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# constants
|
| 21 |
self.past_horizon = 5 # number of past business days
|
| 22 |
-
self.endpoint = "v1/forecast/from_symbol"
|
| 23 |
# build the api-url based on env variables
|
| 24 |
self.api_env = os.environ.get("FORECAST_API_ENV")
|
| 25 |
api_url_temp = os.environ.get("API_URL_TEMPLATE")
|
|
@@ -30,8 +37,12 @@ class FetchForecast:
|
|
| 30 |
return past_df, fcst_df
|
| 31 |
|
| 32 |
def call_api(self) -> tuple:
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
resp = requests.post(f"{self.api_url}/{self.endpoint}", json=pl_in, timeout=30)
|
| 36 |
if resp.status_code == 200:
|
| 37 |
data = resp.json()
|
|
@@ -53,6 +64,44 @@ class FetchForecast:
|
|
| 53 |
return past_df, fcst_df
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
if __name__ == "__main__":
|
| 57 |
# load the env variables fom .env file
|
| 58 |
dotenv.load_dotenv(dotenv.find_dotenv())
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class FetchForecast:
|
| 10 |
+
def __init__(self, ticker: str, df_hist: pd.DataFrame, debug=False) -> None:
|
| 11 |
if debug:
|
| 12 |
self.logger_level = logging.DEBUG
|
| 13 |
else:
|
|
|
|
| 17 |
|
| 18 |
# args
|
| 19 |
self.ticker = ticker
|
| 20 |
+
self.df_hist = df_hist
|
| 21 |
+
if df_hist is None:
|
| 22 |
+
self.endpoint = "v1/forecast/from_symbol"
|
| 23 |
+
logdatasuffix = "without data"
|
| 24 |
+
else:
|
| 25 |
+
self.endpoint = "v1/forecast/from_data"
|
| 26 |
+
logdatasuffix = "with historic data"
|
| 27 |
+
self.logger.info(f"Initialized FetchForecast for ticker: {self.ticker} {logdatasuffix}")
|
| 28 |
# constants
|
| 29 |
self.past_horizon = 5 # number of past business days
|
|
|
|
| 30 |
# build the api-url based on env variables
|
| 31 |
self.api_env = os.environ.get("FORECAST_API_ENV")
|
| 32 |
api_url_temp = os.environ.get("API_URL_TEMPLATE")
|
|
|
|
| 37 |
return past_df, fcst_df
|
| 38 |
|
| 39 |
def call_api(self) -> tuple:
|
| 40 |
+
if self.endpoint.split("/")[-1] in ["from_symbol"]:
|
| 41 |
+
self.logger.info(f"Sending the ticker symmbol to the forecast API ({self.api_env}) {self.endpoint} endpoint")
|
| 42 |
+
pl_in = {"ticker": self.ticker, "past_horizon": self.past_horizon}
|
| 43 |
+
elif self.endpoint.split("/")[-1] in ["from_data"]:
|
| 44 |
+
self.logger.info(f"Formatting and sending ticker data to the forecast API ({self.api_env}) {self.endpoint} endpoint")
|
| 45 |
+
pl_in = self.build_payload_with_data(ticker=self.ticker, past_horizon=self.past_horizon)
|
| 46 |
resp = requests.post(f"{self.api_url}/{self.endpoint}", json=pl_in, timeout=30)
|
| 47 |
if resp.status_code == 200:
|
| 48 |
data = resp.json()
|
|
|
|
| 64 |
return past_df, fcst_df
|
| 65 |
|
| 66 |
|
| 67 |
+
|
| 68 |
+
def build_payload_with_data(self, ticker: str, past_horizon: int) -> dict:
|
| 69 |
+
"""
|
| 70 |
+
Takes a dataframe
|
| 71 |
+
and returns the columnar JSON dict expected by the /from_data endpoint.
|
| 72 |
+
"""
|
| 73 |
+
df = self.df_hist
|
| 74 |
+
|
| 75 |
+
df.columns.name = None
|
| 76 |
+
# Re-introduce the ticker as a regular column
|
| 77 |
+
df["Ticker"] = ticker
|
| 78 |
+
# Make 'Date' a regular column by resetting the index
|
| 79 |
+
df.reset_index(inplace=True)
|
| 80 |
+
|
| 81 |
+
if "Date" not in df.columns or "Close" not in df.columns:
|
| 82 |
+
raise ValueError("DataFrame must contain 'Date' and 'Close' columns.")
|
| 83 |
+
|
| 84 |
+
df = df[["Date", "Close"]].copy()
|
| 85 |
+
df["Date"] = pd.to_datetime(df["Date"], utc=True, errors="coerce")
|
| 86 |
+
|
| 87 |
+
# Clean + order
|
| 88 |
+
df = (
|
| 89 |
+
df.dropna(subset=["Date", "Close"])
|
| 90 |
+
.sort_values("Date")
|
| 91 |
+
.drop_duplicates(subset=["Date"], keep="last")
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
payload = {
|
| 95 |
+
"ticker": ticker,
|
| 96 |
+
"series": {
|
| 97 |
+
"date": df["Date"].dt.strftime("%Y-%m-%d").tolist(),
|
| 98 |
+
"close": df["Close"].astype(float).tolist(),
|
| 99 |
+
},
|
| 100 |
+
"past_horizon": past_horizon,
|
| 101 |
+
}
|
| 102 |
+
return payload
|
| 103 |
+
|
| 104 |
+
|
| 105 |
if __name__ == "__main__":
|
| 106 |
# load the env variables fom .env file
|
| 107 |
dotenv.load_dotenv(dotenv.find_dotenv())
|