architojha commited on
Commit
b56a840
·
1 Parent(s): 077d7f6
Files changed (1) hide show
  1. api.py +20 -9
api.py CHANGED
@@ -15,6 +15,7 @@ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
15
  from dotenv import load_dotenv
16
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
17
  import os
 
18
 
19
  load_dotenv()
20
 
@@ -38,10 +39,18 @@ sentiment_model = AutoModelForSequenceClassification.from_pretrained(
38
 
39
  best_model_path = 'lib/tft_pred.ckpt'
40
 
41
- best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
42
 
43
  app = FastAPI()
44
 
 
 
 
 
 
 
 
 
45
 
46
  class TickerRequest(BaseModel):
47
  ticker: str
@@ -49,7 +58,6 @@ class TickerRequest(BaseModel):
49
  end_date: str
50
  interval: str = "1d"
51
 
52
-
53
  def fetch_and_process_ticker_data(ticker, start_date, end_date, interval="1d"):
54
  df = pd.DataFrame()
55
  try:
@@ -59,7 +67,8 @@ def fetch_and_process_ticker_data(ticker, start_date, end_date, interval="1d"):
59
  temp["revenue"] = temp["adjclose"] * temp["volume"]
60
  temp["daily_profit"] = temp["adjclose"] - temp["open"]
61
  df = pd.concat([df, temp], axis=0)
62
- df.to_csv("api_test.csv", index=False) # Save locally for reference
 
63
  except Exception as error:
64
  raise HTTPException(
65
  status_code=500, detail=f"Error processing ticker {ticker}: {error}")
@@ -283,14 +292,14 @@ def get_tft_predictions(df):
283
  return predictions
284
 
285
 
286
- @app.post("/fetch-ticker-data/")
287
- async def fetch_ticker_data(request: TickerRequest):
288
  try:
289
  result_df = fetch_and_process_ticker_data(
290
- ticker=request.ticker,
291
- start_date=request.start_date,
292
- end_date=request.end_date,
293
- interval=request.interval
294
  )
295
  return result_df.to_dict(orient="records")
296
  except Exception as e:
@@ -332,6 +341,8 @@ async def predict_prices(request: TickerRequest):
332
 
333
  sentiment_df['time_idx'] = range(1, len(sentiment_df) + 1)
334
 
 
 
335
  predicted_values = get_tft_predictions(sentiment_df)
336
 
337
  final_pred_open_price = predicted_values[0].item()
 
15
  from dotenv import load_dotenv
16
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
17
  import os
18
+ from fastapi.middleware.cors import CORSMiddleware
19
 
20
  load_dotenv()
21
 
 
39
 
40
  best_model_path = 'lib/tft_pred.ckpt'
41
 
42
+ best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path).to('cpu')
43
 
44
  app = FastAPI()
45
 
46
+ app.add_middleware(
47
+ CORSMiddleware,
48
+ allow_origins=["*"],
49
+ allow_credentials=True,
50
+ allow_methods=["GET", "POST", "PUT", "DELETE"],
51
+ allow_headers=["*"],
52
+ )
53
+
54
 
55
  class TickerRequest(BaseModel):
56
  ticker: str
 
58
  end_date: str
59
  interval: str = "1d"
60
 
 
61
  def fetch_and_process_ticker_data(ticker, start_date, end_date, interval="1d"):
62
  df = pd.DataFrame()
63
  try:
 
67
  temp["revenue"] = temp["adjclose"] * temp["volume"]
68
  temp["daily_profit"] = temp["adjclose"] - temp["open"]
69
  df = pd.concat([df, temp], axis=0)
70
+ df.to_csv("api_test.csv", index=False)
71
+
72
  except Exception as error:
73
  raise HTTPException(
74
  status_code=500, detail=f"Error processing ticker {ticker}: {error}")
 
292
  return predictions
293
 
294
 
295
+ @app.get("/fetch-ticker-data/{ticker_name}/{start_date}/{end_date}/{interval}")
296
+ async def fetch_ticker_data(ticker_name: str, start_date: str, end_date: str, interval: str):
297
  try:
298
  result_df = fetch_and_process_ticker_data(
299
+ ticker=ticker_name,
300
+ start_date=start_date,
301
+ end_date=end_date,
302
+ interval=interval
303
  )
304
  return result_df.to_dict(orient="records")
305
  except Exception as e:
 
341
 
342
  sentiment_df['time_idx'] = range(1, len(sentiment_df) + 1)
343
 
344
+ print(sentiment_df)
345
+
346
  predicted_values = get_tft_predictions(sentiment_df)
347
 
348
  final_pred_open_price = predicted_values[0].item()