Spaces:
Runtime error
Runtime error
| from langchain_qdrant import FastEmbedSparse, RetrievalMode | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from utils.initMethods import getConfig, readYaml | |
| from langchain_qdrant import QdrantVectorStore | |
| from utils.exceptions import CustomException | |
| from langchain_cerebras import ChatCerebras | |
| from qdrant_client import QdrantClient | |
| from utils.logger import logger | |
| import os | |
| config = getConfig(os.path.join(os.getcwd(), "config.ini")) | |
| modelName = config.get("RAGAGENT", "denseEmbeddings") | |
| modelKwargs = {'device': 'cpu'} | |
| encodeKwargs = {'normalize_embeddings': True} | |
| embeddings = HuggingFaceEmbeddings( | |
| model_name=modelName, | |
| model_kwargs=modelKwargs, | |
| encode_kwargs=encodeKwargs | |
| ) | |
| sparseEmbeddings = FastEmbedSparse(model_name=config.get("RAGAGENT", "sparseEmbeddings")) | |
| class RAGAgent: | |
| def __init__(self) -> None: | |
| try: | |
| logger.info("INITIALIZING RAG AGENT") | |
| client = QdrantClient( | |
| url=os.environ.get("QDRANT_URL"), | |
| api_key=os.environ.get("QDRANT_API_KEY"), | |
| ) | |
| vectorStore = QdrantVectorStore( | |
| client=client, | |
| collection_name="sampleCollection", | |
| embedding=embeddings, | |
| vector_name="semantic-search", | |
| sparse_vector_name="syntactic-search", | |
| retrieval_mode=RetrievalMode.SPARSE, | |
| sparse_embedding=sparseEmbeddings | |
| ) | |
| promptTemplate = ChatPromptTemplate.from_template(readYaml(os.path.join(os.getcwd(), "prompts.yaml")).get("ragTemplate")) | |
| retriever = vectorStore.as_retriever(search_kwargs = {"k": 5}) | |
| llm = ChatCerebras( | |
| model = config.get("RAGAGENT", "modelName"), | |
| temperature = config.getfloat("RAGAGENT", "temperature"), | |
| max_tokens = config.getint("RAGAGENT", "maxTokens") | |
| ) | |
| chain = {"query": RunnablePassthrough(), "context": RunnablePassthrough() | retriever} | promptTemplate | llm | StrOutputParser() | |
| self.chain = chain | |
| except Exception as e: | |
| exception = CustomException(e) | |
| logger.error(exception) | |
| raise exception | |
| def query(self, query) -> str: | |
| try: | |
| output = self.chain.invoke(query) | |
| return output | |
| except Exception as e: | |
| exception = CustomException(e) | |
| logger.error(exception) | |
| raise exception |