Forecast_Agent / src /rag_agent.py
ashantharosary's picture
Update src/rag_agent.py
f8807ba verified
Raw
History Blame Contribute Delete
618 Bytes
import logging

from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from transformers import pipeline
class RAGAgent:
    """Question answering over a vector store via a LangChain RetrievalQA chain.

    The chain is assembled once at construction time from the store's
    retriever and a Hugging Face Hub hosted model, then reused for every
    call to :meth:`answer`.
    """

    def __init__(
        self,
        vectorstore,
        *,
        repo_id="google/flan-t5-base",
        temperature=0,
    ):
        """Build the retriever and the QA chain.

        Args:
            vectorstore: Object exposing ``as_retriever()``; the returned
                retriever is stored on ``self.retriever`` and handed to the
                chain.
            repo_id: Hugging Face Hub repository id passed to
                ``HuggingFaceHub``. Defaults to the previously hard-coded
                ``"google/flan-t5-base"``.
            temperature: Value forwarded as ``model_kwargs["temperature"]``.
                Defaults to the previously hard-coded ``0``.
        """
        self.retriever = vectorstore.as_retriever()
        llm = HuggingFaceHub(
            repo_id=repo_id,
            model_kwargs={"temperature": temperature},
        )
        self.qa = RetrievalQA.from_chain_type(llm=llm, retriever=self.retriever)

    def answer(self, query):
        """Run the QA chain on ``query``.

        Args:
            query: The question passed unchanged to ``self.qa.run``.

        Returns:
            Whatever ``self.qa.run(query)`` returns, or the string
            ``"Error: <message>"`` if the chain raised an exception.
        """
        try:
            return self.qa.run(query)
        except Exception as e:
            # Keep the best-effort contract (callers receive a string, never
            # an exception) but record the traceback so failures can be
            # diagnosed instead of being reduced to a one-line message.
            logging.getLogger(__name__).exception("RAG query failed")
            return f"Error: {str(e)}"