File size: 5,218 Bytes
b496a3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from fastapi import Depends
import os
import datetime
from langchain_qdrant import QdrantVectorStore
from langchain_pinecone import PineconeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_community.chat_message_histories import RedisChatMessageHistory
import tiktoken
from models.tables import Queries
from config.db import SessionLocal
from sqlalchemy.orm import Session
from typing import Annotated
import logging
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool, StructuredTool, tool
from dotenv import load_dotenv
from langchain_groq import ChatGroq
# Load variables from a .env file into the environment before the os.getenv()
# lookups performed further down in this module.
load_dotenv()
# Configure the root logger at INFO level and create this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level chat model: Groq client for "llama-3.3-70b-versatile",
# authenticated with the GROQ_API_KEY environment variable.
llm = ChatGroq(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
# Alternative backend (DeepInfra via the OpenAI-compatible client), currently disabled:
# llm = ChatOpenAI(api_key=os.getenv("DEEP_INFRA_API_KEY"), model="meta-llama/Meta-Llama-3-70B-Instruct", base_url="https://api.deepinfra.com/v1/openai")

# Module-level embedding model ("multilingual-e5-large" through PineconeEmbeddings).
embeddings = PineconeEmbeddings(model="multilingual-e5-large")


def get_db():
    """
    Yield a database session and make sure it is closed afterwards.

    Written as a generator so it can be wrapped in ``Depends`` (see
    ``db_dependency``): the session is handed to the consumer at the
    ``yield`` and closed when the generator is finalized.

    Yields:
        The session object created by ``SessionLocal()``.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs whether the consumer finished normally or raised.
        session.close()

# Reusable type alias: annotating a parameter with `db_dependency` declares it as a
# Session supplied through Depends(get_db).
db_dependency = Annotated[Session,Depends(get_db)]

# Function for retrieving a session's chat history from Redis
def get_message_history(session_id: str) -> RedisChatMessageHistory:
    """
    Build the Redis-backed chat history object for a session.

    Args:
        session_id (str): Identifier passed to RedisChatMessageHistory.

    Returns:
        RedisChatMessageHistory: History object connected using the URL in
        the REDIS_URL environment variable (None is passed if it is unset).
    """
    redis_url = os.getenv("REDIS_URL")
    return RedisChatMessageHistory(session_id, url=redis_url)





# Module-level text splitter. Sizes are measured with len(), i.e. in characters:
# chunks of up to 200 characters with a 20-character overlap between neighbours.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=20,
    length_function=len,
    # Separators are treated as literal strings, not regular expressions.
    is_separator_regex=False,
)


# Keywords related to scheduling a meeting. NOTE(review): presumably matched against
# user messages to detect meeting intent — the usage is not visible in this section.
meeting_finder = ["schedule", "meeting", "appointment", "call", "discuss", "connect", "book"]


def escape_template_string(template: str) -> str:
    """
    Escapes special characters in the template string.

    Backslashes, single quotes, double quotes, newlines, carriage returns and
    tabs are each replaced by their two-character backslash escape sequence.
    The replacement is done in a single pass over the input, so a backslash
    that is inserted as part of an escape sequence is never escaped again
    (chaining ``str.replace`` calls with the backslash step after the quote
    steps would turn ``'`` into ``\\\\'`` instead of ``\\'``).

    Args:
        template (str): The original template string.

    Returns:
        str: The escaped template string.
    """
    escape_table = str.maketrans({
        "\\": "\\\\",
        "'": "\\'",
        '"': '\\"',
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
    })
    # translate() maps every source character exactly once, which makes the
    # result independent of any replacement ordering.
    return template.translate(escape_table)



def context_retriever(query,session_id,company_id,chatbot_id,db,collection_name, embeddings=PineconeEmbeddings(model="multilingual-e5-large")):
    """
    Retrieves the context for the given query by searching the Qdrant vector database
    and retrieving the most similar documents (up to 20). If no documents are found,
    it returns a default instruction message. If the search itself fails, the returned
    context is an error message instead of an exception.

    The query and the resulting context are always recorded as a ``Queries`` row
    and committed through ``db`` before returning.

    Args:
        query (str): The query to retrieve context for.
        session_id (str): The session ID of the user.
        company_id (str): The ID of the company.
        chatbot_id (str): The ID of the chatbot.
        db (Session): The database session.
        collection_name (str): The name of the Qdrant collection.
        embeddings (Embeddings, optional): The embeddings to use for the search.
            Defaults to a PineconeEmbeddings("multilingual-e5-large") instance that
            is created once, when this function is defined.

    Returns:
        str: The retrieved context.

    Raises:
        Exception: Whatever ``db.add`` / ``db.commit`` raises; the session is
            rolled back before the exception is re-raised.
    """

    try:
        vectorstore = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            collection_name=collection_name,
            url=os.getenv("QDRANT_URL", "http://localhost:6333"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        # NOTE: the search is unfiltered; no metadata filter is applied.
        docs = vectorstore.similarity_search(query, k=20)
        if docs:
            entries = []
            for i, doc in enumerate(docs):
                try:
                    page_content = doc.page_content
                    source = doc.metadata.get('source', "")
                    title = doc.metadata.get('title', "")
                    description = doc.metadata.get('description', "")

                    entries.append(f"""{i+1}. Content: {page_content}.\nContent's Page URL: {source}.\nTitle of the page: {title}.\nDescription of the page: {description}.\n""")
                except Exception as e:
                    # A malformed document must not discard the context already
                    # gathered from the other documents: record the failure as
                    # its own entry and keep going.
                    logger.exception("Failed to process retrieved document %s", i)
                    entries.append(f"An error occurred while processing document {i}: {str(e)}\n")
            content = "".join(entries)
        else:
            content = "Frame a professional answer which shows the positive image of the company and should be relevant to the query, don't answer on your own if you think question is not relevant to companies benfits"

    except Exception as e:
        # Best-effort retrieval: report the failure as the context rather than
        # raising, but leave a traceback in the logs.
        logger.exception("Context retrieval failed for collection %s", collection_name)
        content = f"An error occurred: {str(e)}"

    create_query_model = Queries(
        company_id=company_id,
        chatbot_id=chatbot_id,
        session_id=session_id,
        query_text_user=query,
        query_context=content,
        query_time=datetime.datetime.now(),
    )
    try:
        db.add(create_query_model)
        db.commit()
    except Exception:
        # Leave the session usable for the caller before propagating the error.
        db.rollback()
        raise
    logger.info("Context retrieved: %s", content)
    return content


def count_tokens(text):
    """
    Count the tokens in ``text`` using tiktoken's "cl100k_base" encoding.

    Args:
        text (str): The text to tokenize.

    Returns:
        int: Number of tokens produced by encoding ``text``.
    """
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))