Commit
·
98b67431
Parent(s):
b56a840
final
Browse files:
- api.py: +5 −20
- requirements.txt: +0 −7
api.py
CHANGED
|
@@ -10,24 +10,9 @@ from pytorch_forecasting import TemporalFusionTransformer
|
|
| 10 |
from bs4 import BeautifulSoup
|
| 11 |
import requests
|
| 12 |
import torch
|
| 13 |
-
from llama_index.core import StorageContext, load_index_from_storage
|
| 14 |
-
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
-
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
|
| 17 |
-
import os
|
| 18 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
-
|
| 20 |
-
load_dotenv()
|
| 21 |
-
|
| 22 |
-
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
| 23 |
-
storage_context = StorageContext.from_defaults(persist_dir="rag_index")
|
| 24 |
-
index = load_index_from_storage(storage_context, embed_model=embed_model)
|
| 25 |
|
| 26 |
-
|
| 27 |
-
model_name="HuggingFaceH4/zephyr-7b-alpha", token=os.getenv('HF_API')
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
query_engine = index.as_query_engine(llm=llm)
|
| 31 |
|
| 32 |
MODEL_PATH = "lib/20_lstm_model.h5"
|
| 33 |
model = tf.keras.models.load_model(MODEL_PATH)
|
|
@@ -354,9 +339,9 @@ async def predict_prices(request: TickerRequest):
|
|
| 354 |
raise HTTPException(status_code=500, detail=str(e))
|
| 355 |
|
| 356 |
|
| 357 |
-
@app.get("/query-rag/{user_query}")
|
| 358 |
-
def query_rag(user_query: str):
|
| 359 |
|
| 360 |
-
|
| 361 |
|
| 362 |
-
|
|
|
|
| 10 |
from bs4 import BeautifulSoup
|
| 11 |
import requests
|
| 12 |
import torch
|
|
|
|
|
|
|
| 13 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
MODEL_PATH = "lib/20_lstm_model.h5"
|
| 18 |
model = tf.keras.models.load_model(MODEL_PATH)
|
|
|
|
| 339 |
raise HTTPException(status_code=500, detail=str(e))
|
| 340 |
|
| 341 |
|
| 342 |
+
# @app.get("/query-rag/{user_query}")
|
| 343 |
+
# def query_rag(user_query: str):
|
| 344 |
|
| 345 |
+
# response = query_engine.query(user_query)
|
| 346 |
|
| 347 |
+
# return {'message': response}
|
requirements.txt
CHANGED
|
@@ -4,13 +4,6 @@ fastapi==0.110.2
|
|
| 4 |
huggingface-hub==0.23.5
|
| 5 |
lightning==2.4.0
|
| 6 |
lightning-utilities==0.11.8
|
| 7 |
-
llama-hub==0.0.79.post1
|
| 8 |
-
llama-index
|
| 9 |
-
llama-index-core
|
| 10 |
-
llama-index-embeddings-huggingface
|
| 11 |
-
llama-index-legacy
|
| 12 |
-
llama-index-llms-huggingface
|
| 13 |
-
llama-index-llms-huggingface-api
|
| 14 |
multidict==6.0.5
|
| 15 |
multiprocess==0.70.16
|
| 16 |
nest-asyncio==1.6.0
|
|
|
|
| 4 |
huggingface-hub==0.23.5
|
| 5 |
lightning==2.4.0
|
| 6 |
lightning-utilities==0.11.8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
multidict==6.0.5
|
| 8 |
multiprocess==0.70.16
|
| 9 |
nest-asyncio==1.6.0
|