Spaces:
Runtime error
Runtime error
File size: 2,643 Bytes
5ec1ba2 ffc9104 5ec1ba2 eebb03e 5ec1ba2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from langchain_qdrant import FastEmbedSparse, RetrievalMode
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from utils.initMethods import getConfig, readYaml
from langchain_qdrant import QdrantVectorStore
from utils.exceptions import CustomException
from langchain_cerebras import ChatCerebras
from qdrant_client import QdrantClient
from utils.logger import logger
import os
# Settings come from config.ini in the current working directory (not the
# directory containing this module); model names live in the [RAGAGENT] section.
config = getConfig(os.path.join(os.getcwd(), "config.ini"))

# Dense embedding model: name from config, run on CPU, with normalised
# output embeddings.
modelName = config.get("RAGAGENT", "denseEmbeddings")
modelKwargs = dict(device="cpu")
encodeKwargs = dict(normalize_embeddings=True)
embeddings = HuggingFaceEmbeddings(
    model_name=modelName,
    model_kwargs=modelKwargs,
    encode_kwargs=encodeKwargs,
)

# Sparse embedding model, named by the `sparseEmbeddings` config key.
sparseEmbeddings = FastEmbedSparse(
    model_name=config.get("RAGAGENT", "sparseEmbeddings"),
)
class RAGAgent:
    """Retrieval-augmented generation agent over a Qdrant collection.

    Construction connects to Qdrant, wraps the collection in a
    ``QdrantVectorStore`` and builds the LangChain pipeline.

    - The Qdrant URL and API key are read from the ``QDRANT_URL`` /
      ``QDRANT_API_KEY`` environment variables.
    - The vector store uses the module-level ``embeddings`` and
      ``sparseEmbeddings`` models.
    - The ``ragTemplate`` prompt is loaded from ``prompts.yaml`` in the
      current working directory.
    - The chain runs: retrieve -> fill prompt -> Cerebras chat model ->
      plain string.
    """

    def __init__(self, collection_name: str = "sampleCollection", top_k: int = 5) -> None:
        """Build the retrieval chain and store it on ``self.chain``.

        Args:
            collection_name: Name of the Qdrant collection to search.
            top_k: Number of documents the retriever returns per query.

        Raises:
            CustomException: Wraps any error raised while building the chain.
                The original error is kept as ``__cause__``.
        """
        try:
            logger.info("INITIALIZING RAG AGENT")
            client = QdrantClient(
                url=os.environ.get("QDRANT_URL"),
                api_key=os.environ.get("QDRANT_API_KEY"),
            )
            # NOTE(review): retrieval runs in SPARSE mode. The dense `embeddings`
            # model and the "semantic-search" vector are configured but
            # presumably unused for search. Confirm whether HYBRID was intended.
            vectorStore = QdrantVectorStore(
                client=client,
                collection_name=collection_name,
                embedding=embeddings,
                vector_name="semantic-search",
                sparse_vector_name="syntactic-search",
                retrieval_mode=RetrievalMode.SPARSE,
                sparse_embedding=sparseEmbeddings,
            )
            prompts = readYaml(os.path.join(os.getcwd(), "prompts.yaml"))
            promptTemplate = ChatPromptTemplate.from_template(prompts.get("ragTemplate"))
            retriever = vectorStore.as_retriever(search_kwargs={"k": top_k})
            llm = ChatCerebras(
                model=config.get("RAGAGENT", "modelName"),
                temperature=config.getfloat("RAGAGENT", "temperature"),
                max_tokens=config.getint("RAGAGENT", "maxTokens"),
            )
            # The input is passed unchanged as the prompt's `query` variable and,
            # via the retriever, as `context`.
            # NOTE(review): the retrieved documents reach the prompt as-is, with
            # no formatting step. Consider joining their page_content.
            self.chain = (
                {"query": RunnablePassthrough(), "context": RunnablePassthrough() | retriever}
                | promptTemplate
                | llm
                | StrOutputParser()
            )
        except Exception as e:
            exception = CustomException(e)
            logger.error(exception)
            # Explicit chaining keeps the root cause visible in tracebacks.
            raise exception from e

    def query(self, query: str) -> str:
        """Run ``query`` through the RAG chain and return the model's answer text.

        Raises:
            CustomException: Wraps any error raised by the chain. The original
                error is kept as ``__cause__``.
        """
        try:
            return self.chain.invoke(query)
        except Exception as e:
            exception = CustomException(e)
            logger.error(exception)
            raise exception from e