File size: 1,695 Bytes
7e24b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from src.components.vectors.vectorstore import VectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from src.utils.exceptions import CustomException
from src.utils.functions import getConfig
from src.utils.functions import loadYaml
from src.utils.logging import logger
from langchain_groq import ChatGroq


class Chain:
    """Build a retrieval-augmented LCEL chain: question -> retrieved context -> prompt -> Groq LLM -> string.

    Settings are read from ``config.ini`` (section ``LLM``: ``llmModel``, ``temperature``,
    ``maxTokens``). The prompt template comes from the ``prompt`` key of ``params.yaml``.
    """

    def __init__(self):
        """Load the config, create the vector store wrapper, and compile the prompt template."""
        self.config = getConfig(path = "config.ini")
        self.store = VectorStore()
        prompt = loadYaml(path = "params.yaml")["prompt"]
        # The template must reference the {context} and {question} variables that
        # returnChain feeds into it.
        self.prompt = ChatPromptTemplate.from_template(prompt)

    def formatDocs(self, docs) -> str:
        """Join retrieved documents into a single context string for the prompt.

        Each document is rendered with ``str()`` and followed by three newlines.
        An empty ``docs`` yields the placeholder ``"No Context Found"``.

        NOTE(review): if ``docs`` are LangChain ``Document`` objects, ``str(doc)`` includes
        the metadata repr. Consider ``doc.page_content`` instead; confirm against the retriever.
        """
        # str.join avoids repeated string reallocation from ``+=`` in a loop.
        context = "".join(f"{doc}\n\n\n" for doc in docs)
        return context or "No Context Found"

    def returnChain(self, text: str):
        """Assemble the runnable chain over a store built from ``text``.

        The returned chain expects an input dict with a ``"question"`` key and produces a
        string (via ``StrOutputParser``).

        Args:
            text: Source passed to ``VectorStore.setupStore``. Its exact form (raw text,
                path, ...) is defined by that method; confirm there.

        Returns:
            The composed runnable, or ``None`` if any step fails. The error is logged as a
            ``CustomException`` and not re-raised, so callers must check for ``None``.
        """
        try:
            logger.info("preparing chain")
            # Presumably a retriever (it is piped between the question and formatDocs);
            # verify in VectorStore.setupStore.
            retriever = self.store.setupStore(text = text)
            # Both prompt variables start from the same "question" field of the input dict.
            pickQuestion = RunnableLambda(lambda x: x["question"])
            contextBranch = pickQuestion | retriever | RunnableLambda(self.formatDocs)
            llm = ChatGroq(
                model_name = self.config.get("LLM", "llmModel"),
                temperature = self.config.getfloat("LLM", "temperature"),
                max_tokens = self.config.getint("LLM", "maxTokens"),
            )
            chain = (
                {"context": contextBranch, "question": pickQuestion}
                | self.prompt
                | llm
                | StrOutputParser()
            )
            return chain
        except Exception as e:
            logger.error(CustomException(e))
            return None