import random
import openai
import json


from langchain.docstore.document import Document as LangChainDocument
from langchain.embeddings.openai import OpenAIEmbeddings
from fastapi import HTTPException
from uuid import UUID, uuid4
from langchain.text_splitter import (
    CharacterTextSplitter,
    MarkdownTextSplitter,
)
from sqlmodel import (
    Session,
    text,
)
from util import (
    sanitize_input,
    sanitize_output,
)
from langchain import OpenAI
from langchain.llms import OpenAIChat
from typing import (
    List,
    Union,
    Optional,
    Dict,
    Tuple,
    Any,
)
from helpers import (
    get_user_by_uuid_or_identifier,
    get_chat_session_by_uuid,
)
from models import (
    User,
    Organization,
    Project,
    Node,
    ChatSession,
    get_engine,
)
from config import (
    CHANNEL_TYPE,
    DOCUMENT_TYPE,
    LLM_MODELS,
    LLM_DISTANCE_THRESHOLD,
    LLM_DEFAULT_TEMPERATURE,
    LLM_MAX_OUTPUT_TOKENS,
    LLM_CHUNK_SIZE,
    LLM_CHUNK_OVERLAP,
    LLM_MIN_NODE_LIMIT,
    LLM_DEFAULT_DISTANCE_STRATEGY,
    VECTOR_EMBEDDINGS_COUNT,
    DISTANCE_STRATEGY,
    AGENT_NAMES,
    logger,
)


def chat_query(
    query_str: str,
    session_id: Optional[Union[str, UUID]] = None,
    meta: Optional[Dict[str, Any]] = None,
    channel: Optional[CHANNEL_TYPE] = None,
    identifier: Optional[str] = None,
    project: Optional[Project] = None,
    organization: Optional[Organization] = None,
    session: Optional[Session] = None,
    user_data: Optional[Dict[str, Any]] = None,
    distance_strategy: Optional[DISTANCE_STRATEGY] = DISTANCE_STRATEGY.EUCLIDEAN,
    distance_threshold: Optional[float] = LLM_DISTANCE_THRESHOLD,
    node_limit: Optional[int] = LLM_MIN_NODE_LIMIT,
    model: Optional[LLM_MODELS] = LLM_MODELS.GPT_35_TURBO,
    max_output_tokens: Optional[int] = LLM_MAX_OUTPUT_TOKENS,
) -> ChatSession:
    """
    Steps:
    1. ✅ Clean user input
    2. ✅ Create input embeddings
    3. ✅ Search for similar nodes
    4. ✅ Create prompt template w/ similar nodes
    5. ✅ Submit prompt template to LLM
    6. ✅ Get response from LLM
    7. Create ChatSession
        - Store embeddings
        - Store tags
        - Store is_escalate
    8. Return response
    """
    # Copy the caller's meta instead of relying on a mutable default argument,
    # which would be shared (and mutated) across calls.
    meta = dict(meta) if meta else {}
    agent_name = None
    embeddings = []
    tags = []
    is_escalate = False
    response_message = None
    prompt = None
    context_str = None
    # `model` is an LLM_MODELS enum member, never a langchain OpenAI instance,
    # so check against the enum before reading its token limit.
    MODEL_TOKEN_LIMIT = (
        model.token_limit if isinstance(model, LLM_MODELS) else LLM_MAX_OUTPUT_TOKENS
    )

    prev_chat_session = (
        get_chat_session_by_uuid(session_id=session_id, session=session)
        if session_id
        else None
    )

    if session_id and not prev_chat_session:
        # Raise (not return) so FastAPI converts this into a 404 response.
        raise HTTPException(
            status_code=404, detail=f"Chat session with ID {session_id} not found."
        )
    elif session_id and prev_chat_session and prev_chat_session.meta.get("agent"):
        # Keep the same agent persona across a multi-turn session.
        agent_name = prev_chat_session.meta["agent"]
    else:
        session_id = str(uuid4())

    # Pick a persona once and record it, so the prompt below sees the same
    # agent name that is stored in meta.
    agent_name = agent_name or get_random_agent()
    meta["agent"] = agent_name

    # Step 1: clean user input
    query_str = sanitize_input(query_str)
    logger.debug(f"💬 Query received: {query_str}")

    query_token_count = get_token_count(query_str)
    # Refined below once the prompt template has been rendered.
    prompt_token_count = 0

    # Step 2: create input embeddings
    _, embeddings = get_embeddings(query_str)
    query_embeddings = embeddings[0]

    # Step 3: search for similar nodes
    nodes = get_nodes_by_embedding(
        query_embeddings,
        node_limit,
        distance_strategy=distance_strategy
        if isinstance(distance_strategy, DISTANCE_STRATEGY)
        else LLM_DEFAULT_DISTANCE_STRATEGY,
        distance_threshold=distance_threshold,
        session=session,
    )

    if len(nodes) > 0:
        # Resolve project/organization from the best-matching node if the
        # caller did not supply them.
        if (not project or not organization) and session:
            document = session.get(Node, nodes[0].id).document
            project = document.project
            organization = project.organization

        context_str = "\n\n".join([node.text for node in nodes])
        context_token_count = get_token_count(context_str)

        if (
            context_token_count + query_token_count + prompt_token_count
        ) > MODEL_TOKEN_LIMIT:
            logger.debug("🚧 Exceeded token limit, truncating context")
            # NOTE: token_delta is a token budget, but the slice below is by
            # character, so this is a rough, conservative truncation.
            token_delta = max(
                0, MODEL_TOKEN_LIMIT - (query_token_count + prompt_token_count)
            )
            context_str = context_str[:token_delta]

        # Step 4: create prompt template w/ similar nodes
        system_prompt, user_prompt = get_prompt_template(
            user_query=query_str,
            context_str=context_str,
            project=project,
            organization=organization,
            agent=agent_name,
        )
        # Keep the full prompt text around so the token accounting below
        # counts what was actually sent, not a stale None.
        prompt = system_prompt[0]["content"] + user_prompt
        prompt_token_count = get_token_count(prompt)

        # Steps 5-6: submit prompt to LLM and parse its JSON reply
        raw_response = retrieve_llm_response(
            user_prompt,
            model=model,
            max_output_tokens=max_output_tokens,
            prefix_messages=system_prompt,
        )
        try:
            llm_response = json.loads(raw_response)
        except json.JSONDecodeError:
            # The model is instructed to reply in JSON but may not comply;
            # fall back to treating the raw text as the message.
            logger.warning("🚧 LLM reply was not valid JSON, using raw text")
            llm_response = {"message": raw_response}
        tags = llm_response.get("tags", [])
        is_escalate = llm_response.get("is_escalate", False)
        response_message = llm_response.get("message", None)
    else:
        logger.info("🚫📄 No similar nodes found, returning default response")
        # Default response mirrors the fallback line the prompt instructs the
        # agent to use when it cannot answer.
        response_message = "I apologize, I'm not sure how to help with that"

    # Step 7a: resolve or create the user
    user = get_user_by_uuid_or_identifier(
        identifier, session=session, should_except=False
    )

    if not user:
        logger.debug("🚫👤 User not found, creating new user")
        user_params = {
            "identifier": identifier,
            "identifier_type": channel.value
            if isinstance(channel, CHANNEL_TYPE)
            else channel,
        }
        if user_data:
            user_params = {**user_params, **user_data}

        user = User.create(user_params)
    else:
        logger.debug(f"👤 User found: {user}")

    # Step 7b: persist the exchange
    token_count = get_token_count(prompt) + get_token_count(response_message)

    if tags:
        meta["tags"] = tags

    meta["is_escalate"] = is_escalate

    if session_id:
        meta["session_id"] = session_id

    chat_session = ChatSession(
        user_id=user.id,
        session_id=session_id,
        project_id=project.id if project else None,
        channel=channel.value if isinstance(channel, CHANNEL_TYPE) else channel,
        user_message=query_str,
        embeddings=query_embeddings,
        token_count=token_count if token_count > 0 else None,
        response=response_message,
        meta=meta,
    )

    if session:
        session.add(chat_session)
        session.commit()
        session.refresh(chat_session)
    else:
        # No caller-provided session: open a short-lived one of our own.
        with Session(get_engine()) as own_session:
            own_session.add(chat_session)
            own_session.commit()
            own_session.refresh(chat_session)

    # Step 8: return response
    return chat_session


def get_random_agent():
    return random.choice(AGENT_NAMES)


def get_prompt_template(
    user_query: Optional[str] = None,
    context_str: Optional[str] = None,
    project: Optional[Project] = None,
    organization: Optional[Organization] = None,
    agent: Optional[str] = None,
) -> Tuple[List[Dict[str, str]], str]:
    agent = f"{agent}, " if agent else ""
    user_query = user_query if user_query else ""
    context_str = context_str if context_str else ""
    organization = (
        project.organization.display_name
        if project
        else organization.display_name
        if organization
        else None
    )

    if not context_str or not user_query:
        raise ValueError("Missing required arguments: context_str and user_query")

    system_prompt = [
        {
            "role": "system",
            "content": f"""[AGENT]:
I am {agent}a very kind and enthusiastic customer support agent who loves to help customers. I am working on behalf of "{organization}".

Given the following document from "{organization}", I will answer the [USER]'s questions using only the [DOCUMENT] and following the [RULES].

[DOCUMENT]:
{context_str}

[RULES]:
I will answer the user's questions using only the [DOCUMENT] provided. I will abide by the following rules:
- I am a kind and helpful human, the best customer support agent in existence
- I never lie or invent answers not explicitly provided in [DOCUMENT]
- If I am unsure of the answer or the answer is not explicitly contained in [DOCUMENT], I will say: "I apologize, I'm not sure how to help with that"
- I always keep my answers short, relevant and concise
- I will always respond in JSON format with the following keys: "message" my response to the user, "tags" an array of short labels categorizing user input, "is_escalate" a boolean that is true if I am unsure and the question should be escalated to a human, false otherwise
""",
        }
    ]

    return (system_prompt, f"[USER]:\n{user_query}")
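
# Example of the (system_prompt, user_prompt) pair this template produces, and
# the JSON payload the [RULES] ask the model to return. Values are
# illustrative, not taken from real data:
#
#   system_prompt = [{"role": "system", "content": '[AGENT]:\nI am Sam, ...'}]
#   user_prompt = "[USER]:\nHow do I reset my password?"
#
#   expected LLM reply:
#   {"message": "You can reset it from Settings > Security.",
#    "tags": ["account", "password"],
#    "is_escalate": false}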


def get_token_count(text: str) -> int:
    if not text:
        return 0

    # Uses the langchain OpenAI wrapper's tokenizer; instantiating the client
    # on every call is simple but not free, so avoid calling this in hot loops.
    return OpenAI().get_num_tokens(text=text)


def get_nodes_by_embedding(
    embeddings: List[float],
    k: int = LLM_MIN_NODE_LIMIT,
    distance_strategy: Optional[DISTANCE_STRATEGY] = LLM_DEFAULT_DISTANCE_STRATEGY,
    distance_threshold: Optional[float] = LLM_DISTANCE_THRESHOLD,
    session: Optional[Session] = None,
) -> List[Node]:
    embeddings_str = str(embeddings)

    if distance_strategy == DISTANCE_STRATEGY.EUCLIDEAN:
        distance_fn = "match_node_euclidean"
    elif distance_strategy == DISTANCE_STRATEGY.COSINE:
        distance_fn = "match_node_cosine"
    elif distance_strategy == DISTANCE_STRATEGY.MAX_INNER_PRODUCT:
        distance_fn = "match_node_max_inner_product"
    else:
        raise ValueError(f"Invalid distance strategy {distance_strategy}")

    # The interpolated values are all numeric (vector literal, float, int),
    # so this f-string SQL is safe from injection, if inelegant.
    sql = f"""SELECT * FROM {distance_fn}(
        '{embeddings_str}'::vector({VECTOR_EMBEDDINGS_COUNT}),
        {float(distance_threshold)}::double precision,
        {int(k)});"""

    if session:
        nodes = session.exec(text(sql)).all()
    else:
        with Session(get_engine()) as session:
            nodes = session.exec(text(sql)).all()

    # Each row's first column is the matching node's UUID.
    return [Node.by_uuid(str(node[0])) for node in nodes] if nodes else []
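
# The match_node_* helpers above are assumed to be Postgres functions built on
# the pgvector extension (the ::vector cast implies it), each taking
# (query_embedding vector, threshold double precision, match_count int) and
# returning the closest node rows, UUID first. A sketch of a call as this
# module generates it, with illustrative values:
#
#   SELECT * FROM match_node_cosine(
#       '[0.01, -0.02, ...]'::vector(1536),
#       0.8::double precision,
#       5);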


def retrieve_llm_response(
    query_str: str,
    model: Optional[LLM_MODELS] = LLM_MODELS.GPT_35_TURBO,
    temperature: Optional[float] = LLM_DEFAULT_TEMPERATURE,
    max_output_tokens: Optional[int] = LLM_MAX_OUTPUT_TOKENS,
    prefix_messages: Optional[List[dict]] = None,
):
    # OpenAIChat (rather than the completions-only OpenAI wrapper) is the
    # legacy langchain class that accepts prefix_messages for chat models.
    llm = OpenAIChat(
        temperature=temperature,
        model_name=model.model_name
        if isinstance(model, LLM_MODELS)
        else LLM_MODELS.GPT_35_TURBO.model_name,
        max_tokens=max_output_tokens,
        prefix_messages=prefix_messages or [],
    )
    try:
        result = llm(prompt=query_str)
    except openai.error.InvalidRequestError as e:
        logger.error(f"🚨 LLM error: {e}")
        raise HTTPException(status_code=500, detail=f"LLM error: {e}") from e
    logger.debug(f"💬 LLM result: {str(result)}")
    return sanitize_output(result)


def get_embeddings(
    document_data: str,
    document_type: DOCUMENT_TYPE = DOCUMENT_TYPE.PLAINTEXT,
) -> Tuple[List[str], List[List[float]]]:
    documents = [LangChainDocument(page_content=document_data)]
    logger.debug(documents)

    # Markdown-aware splitting preserves heading/code-block boundaries;
    # everything else falls back to plain character splitting.
    if document_type == DOCUMENT_TYPE.MARKDOWN:
        doc_splitter = MarkdownTextSplitter(
            chunk_size=LLM_CHUNK_SIZE, chunk_overlap=LLM_CHUNK_OVERLAP
        )
    else:
        doc_splitter = CharacterTextSplitter(
            chunk_size=LLM_CHUNK_SIZE, chunk_overlap=LLM_CHUNK_OVERLAP
        )

    split_documents = doc_splitter.split_documents(documents)
    arr_documents = [doc.page_content for doc in split_documents]

    embed_func = OpenAIEmbeddings()
    # One embedding vector per chunk, in the same order as arr_documents.
    embeddings = embed_func.embed_documents(
        texts=arr_documents, chunk_size=LLM_CHUNK_SIZE
    )

    return arr_documents, embeddings
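

# A minimal, illustrative way to exercise the pipeline end to end. Assumes the
# database and OpenAI credentials are configured; the identifier and query
# below are made up. chat_query opens its own DB session when none is passed.
if __name__ == "__main__":
    demo_session = chat_query(
        query_str="How do I reset my password?",
        identifier="demo-user@example.com",
    )
    print(demo_session.response)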