|
|
import math |
|
|
import re |
|
|
|
|
|
import numexpr |
|
|
from langchain_chroma import Chroma |
|
|
from langchain_core.tools import BaseTool, tool |
|
|
from core.embeddings import get_embeddings |
|
|
from core.settings import settings |
|
|
|
|
|
|
|
|
def calculator_func(expression: str) -> str: |
|
|
"""Calculates a math expression using numexpr. |
|
|
|
|
|
Useful for when you need to answer questions about math using numexpr. |
|
|
This tool is only for math questions and nothing else. Only input |
|
|
math expressions. |
|
|
|
|
|
Args: |
|
|
expression (str): A valid numexpr formatted math expression. |
|
|
|
|
|
Returns: |
|
|
str: The result of the math expression. |
|
|
""" |
|
|
|
|
|
try: |
|
|
local_dict = {"pi": math.pi, "e": math.e} |
|
|
output = str( |
|
|
numexpr.evaluate( |
|
|
expression.strip(), |
|
|
global_dict={}, |
|
|
local_dict=local_dict, |
|
|
) |
|
|
) |
|
|
return re.sub(r"^\[|\]$", "", output) |
|
|
except Exception as e: |
|
|
raise ValueError( |
|
|
f'calculator("{expression}") raised error: {e}.' |
|
|
" Please try again with a valid numerical expression" |
|
|
) |
|
|
|
|
|
|
|
|
calculator: BaseTool = tool(calculator_func) |
|
|
calculator.name = "Calculator" |
|
|
|
|
|
|
|
|
|
|
|
def format_contexts(docs): |
|
|
return "\n\n".join(doc.page_content for doc in docs) |
|
|
|
|
|
|
|
|
def load_chroma_db(): |
|
|
|
|
|
try: |
|
|
embeddings = get_embeddings(settings.DEFAULT_EMBEDDING_MODEL) |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
f"Failed to initialize embeddings for {settings.DEFAULT_EMBEDDING_MODEL}. " |
|
|
"Ensure the appropriate API key is set." |
|
|
) from e |
|
|
|
|
|
|
|
|
chroma_db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings) |
|
|
retriever = chroma_db.as_retriever(search_kwargs={"k": 5}) |
|
|
return retriever |
|
|
|
|
|
|
|
|
def database_search_func(query: str) -> str: |
|
|
"""Searches chroma_db for information in the company's handbook.""" |
|
|
|
|
|
retriever = load_chroma_db() |
|
|
|
|
|
|
|
|
documents = retriever.invoke(query) |
|
|
|
|
|
|
|
|
context_str = format_contexts(documents) |
|
|
|
|
|
return context_str |
|
|
|
|
|
|
|
|
database_search: BaseTool = tool(database_search_func) |
|
|
database_search.name = "Database_Search" |
|
|
|