Commit
·
b56a840
1
Parent(s):
077d7f6
final
Browse files
api.py
CHANGED
|
@@ -15,6 +15,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
|
| 17 |
import os
|
|
|
|
| 18 |
|
| 19 |
load_dotenv()
|
| 20 |
|
|
@@ -38,10 +39,18 @@ sentiment_model = AutoModelForSequenceClassification.from_pretrained(
|
|
| 38 |
|
| 39 |
best_model_path = 'lib/tft_pred.ckpt'
|
| 40 |
|
| 41 |
-
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
|
| 42 |
|
| 43 |
app = FastAPI()
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
class TickerRequest(BaseModel):
|
| 47 |
ticker: str
|
|
@@ -49,7 +58,6 @@ class TickerRequest(BaseModel):
|
|
| 49 |
end_date: str
|
| 50 |
interval: str = "1d"
|
| 51 |
|
| 52 |
-
|
| 53 |
def fetch_and_process_ticker_data(ticker, start_date, end_date, interval="1d"):
|
| 54 |
df = pd.DataFrame()
|
| 55 |
try:
|
|
@@ -59,7 +67,8 @@ def fetch_and_process_ticker_data(ticker, start_date, end_date, interval="1d"):
|
|
| 59 |
temp["revenue"] = temp["adjclose"] * temp["volume"]
|
| 60 |
temp["daily_profit"] = temp["adjclose"] - temp["open"]
|
| 61 |
df = pd.concat([df, temp], axis=0)
|
| 62 |
-
df.to_csv("api_test.csv", index=False)
|
|
|
|
| 63 |
except Exception as error:
|
| 64 |
raise HTTPException(
|
| 65 |
status_code=500, detail=f"Error processing ticker {ticker}: {error}")
|
|
@@ -283,14 +292,14 @@ def get_tft_predictions(df):
|
|
| 283 |
return predictions
|
| 284 |
|
| 285 |
|
| 286 |
-
@app.
|
| 287 |
-
async def fetch_ticker_data(
|
| 288 |
try:
|
| 289 |
result_df = fetch_and_process_ticker_data(
|
| 290 |
-
ticker=
|
| 291 |
-
start_date=
|
| 292 |
-
end_date=
|
| 293 |
-
interval=
|
| 294 |
)
|
| 295 |
return result_df.to_dict(orient="records")
|
| 296 |
except Exception as e:
|
|
@@ -332,6 +341,8 @@ async def predict_prices(request: TickerRequest):
|
|
| 332 |
|
| 333 |
sentiment_df['time_idx'] = range(1, len(sentiment_df) + 1)
|
| 334 |
|
|
|
|
|
|
|
| 335 |
predicted_values = get_tft_predictions(sentiment_df)
|
| 336 |
|
| 337 |
final_pred_open_price = predicted_values[0].item()
|
|
|
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
|
| 17 |
import os
|
| 18 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
|
| 20 |
load_dotenv()
|
| 21 |
|
|
|
|
| 39 |
|
| 40 |
best_model_path = 'lib/tft_pred.ckpt'
|
| 41 |
|
| 42 |
+
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path).to('cpu')
|
| 43 |
|
| 44 |
app = FastAPI()
|
| 45 |
|
| 46 |
+
app.add_middleware(
|
| 47 |
+
CORSMiddleware,
|
| 48 |
+
allow_origins=["*"],
|
| 49 |
+
allow_credentials=True,
|
| 50 |
+
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
| 51 |
+
allow_headers=["*"],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
|
| 55 |
class TickerRequest(BaseModel):
|
| 56 |
ticker: str
|
|
|
|
| 58 |
end_date: str
|
| 59 |
interval: str = "1d"
|
| 60 |
|
|
|
|
| 61 |
def fetch_and_process_ticker_data(ticker, start_date, end_date, interval="1d"):
|
| 62 |
df = pd.DataFrame()
|
| 63 |
try:
|
|
|
|
| 67 |
temp["revenue"] = temp["adjclose"] * temp["volume"]
|
| 68 |
temp["daily_profit"] = temp["adjclose"] - temp["open"]
|
| 69 |
df = pd.concat([df, temp], axis=0)
|
| 70 |
+
df.to_csv("api_test.csv", index=False)
|
| 71 |
+
|
| 72 |
except Exception as error:
|
| 73 |
raise HTTPException(
|
| 74 |
status_code=500, detail=f"Error processing ticker {ticker}: {error}")
|
|
|
|
| 292 |
return predictions
|
| 293 |
|
| 294 |
|
| 295 |
+
@app.get("/fetch-ticker-data/{ticker_name}/{start_date}/{end_date}/{interval}")
|
| 296 |
+
async def fetch_ticker_data(ticker_name: str, start_date: str, end_date: str, interval: str):
|
| 297 |
try:
|
| 298 |
result_df = fetch_and_process_ticker_data(
|
| 299 |
+
ticker=ticker_name,
|
| 300 |
+
start_date=start_date,
|
| 301 |
+
end_date=end_date,
|
| 302 |
+
interval=interval
|
| 303 |
)
|
| 304 |
return result_df.to_dict(orient="records")
|
| 305 |
except Exception as e:
|
|
|
|
| 341 |
|
| 342 |
sentiment_df['time_idx'] = range(1, len(sentiment_df) + 1)
|
| 343 |
|
| 344 |
+
print(sentiment_df)
|
| 345 |
+
|
| 346 |
predicted_values = get_tft_predictions(sentiment_df)
|
| 347 |
|
| 348 |
final_pred_open_price = predicted_values[0].item()
|