WilliamGazeley
commited on
Commit
·
9efba8b
1
Parent(s):
391d6e2
Integrate working langchain ollama model
Browse files- Dockerfile +1 -0
- src/agents/__init__.py +34 -0
- src/agents/format_scratchpad/functions.py +63 -0
- src/agents/functions_agent/base.py +48 -0
- src/agents/output_parsers/functions.py +77 -0
- src/agents/output_parsers/utils.py +64 -0
- src/functioncall.py +3 -2
- src/functions.py +21 -17
- src/prompts/prompt.py +17 -0
- src/prompts/rag_template.yaml +12 -0
Dockerfile
CHANGED
|
@@ -45,6 +45,7 @@ RUN pyenv install ${PYTHON_VERSION} && \
|
|
| 45 |
COPY --chown=1000 ./requirements.txt /tmp/requirements.txt
|
| 46 |
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
| 47 |
pip install flash-attn --no-build-isolation
|
|
|
|
| 48 |
|
| 49 |
COPY --chown=1000 src ${HOME}/app
|
| 50 |
|
|
|
|
| 45 |
COPY --chown=1000 ./requirements.txt /tmp/requirements.txt
|
| 46 |
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
| 47 |
pip install flash-attn --no-build-isolation
|
| 48 |
+
RUN ollama pull ${OLLAMA_MODEL}
|
| 49 |
|
| 50 |
COPY --chown=1000 src ${HOME}/app
|
| 51 |
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.chat_models import ChatOllama
|
| 2 |
+
from prompts.prompt import rag_agent_prompt
|
| 3 |
+
from agents.functions_agent.base import create_functions_agent
|
| 4 |
+
from langchain.agents import AgentExecutor
|
| 5 |
+
from langchain.memory import ChatMessageHistory
|
| 6 |
+
from functions import get_openai_functions, tools, get_openai_tools
|
| 7 |
+
from config import config
|
| 8 |
+
|
| 9 |
+
llm = ChatOllama(model = config.ollama_model, temperature = 0.55)
|
| 10 |
+
|
| 11 |
+
tools_dict = get_openai_tools()
|
| 12 |
+
|
| 13 |
+
history = ChatMessageHistory()
|
| 14 |
+
|
| 15 |
+
functions_agent = create_functions_agent(llm=llm, prompt=rag_agent_prompt)
|
| 16 |
+
functions_agent_executor = AgentExecutor(agent=functions_agent, tools=tools, verbose=True, return_intermediate_steps=True)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
|
| 20 |
+
while True:
|
| 21 |
+
try:
|
| 22 |
+
inp = input("User:")
|
| 23 |
+
if inp == "/bye":
|
| 24 |
+
break
|
| 25 |
+
|
| 26 |
+
response = functions_agent_executor.invoke({"input": inp, "chat_history": history, "tools" : tools_dict})
|
| 27 |
+
response['output'] = response['output'].replace("<|im_end|>", "")
|
| 28 |
+
history.add_user_message(inp)
|
| 29 |
+
history.add_ai_message(response['output'])
|
| 30 |
+
|
| 31 |
+
print(response['output'])
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(e)
|
| 34 |
+
continue
|
src/agents/format_scratchpad/functions.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import List, Sequence, Tuple
|
| 3 |
+
|
| 4 |
+
from langchain_core.agents import AgentAction, AgentActionMessageLog
|
| 5 |
+
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
|
| 6 |
+
|
| 7 |
+
def _convert_agent_action_to_messages(
|
| 8 |
+
agent_action: AgentAction, observation: str
|
| 9 |
+
) -> List[BaseMessage]:
|
| 10 |
+
"""Convert an agent action to a message.
|
| 11 |
+
This code is used to reconstruct the original AI message from the agent action.
|
| 12 |
+
Args:
|
| 13 |
+
agent_action: Agent action to convert.
|
| 14 |
+
Returns:
|
| 15 |
+
AIMessage that corresponds to the original tool invocation.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
if isinstance(agent_action, AgentActionMessageLog):
|
| 19 |
+
return list(agent_action.message_log) + [f"<tool_response>\n{_create_function_message(agent_action, observation)}\n</tool_response>"]
|
| 20 |
+
else:
|
| 21 |
+
return [AIMessage(content=agent_action.log)]
|
| 22 |
+
|
| 23 |
+
def _create_function_message(
|
| 24 |
+
agent_action: AgentAction, observation: str
|
| 25 |
+
) -> str:
|
| 26 |
+
"""Convert agent action and observation into a function message.
|
| 27 |
+
Args:
|
| 28 |
+
agent_action: the tool invocation request from the agent
|
| 29 |
+
observation: the result of the tool invocation
|
| 30 |
+
Returns:
|
| 31 |
+
FunctionMessage that corresponds to the original tool invocation
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
if not isinstance(observation, str):
|
| 35 |
+
try:
|
| 36 |
+
content = json.dumps(observation, ensure_ascii=False)
|
| 37 |
+
except Exception:
|
| 38 |
+
content = str(observation)
|
| 39 |
+
else:
|
| 40 |
+
content = observation
|
| 41 |
+
tool_response = {
|
| 42 |
+
"name": agent_action.tool,
|
| 43 |
+
"content": content,
|
| 44 |
+
}
|
| 45 |
+
return json.dumps(tool_response)
|
| 46 |
+
|
| 47 |
+
def format_to_function_messages(
|
| 48 |
+
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
| 49 |
+
) -> List[BaseMessage]:
|
| 50 |
+
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
|
| 51 |
+
Args:
|
| 52 |
+
intermediate_steps: Steps the LLM has taken to date, along with observations
|
| 53 |
+
Returns:
|
| 54 |
+
list of messages to send to the LLM for the next prediction
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
messages = []
|
| 58 |
+
for agent_action, observation in intermediate_steps:
|
| 59 |
+
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
| 60 |
+
return messages
|
| 61 |
+
|
| 62 |
+
# Backwards compatibility
|
| 63 |
+
format_to_functions = format_to_function_messages
|
src/agents/functions_agent/base.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Sequence
|
| 2 |
+
|
| 3 |
+
from langchain_core.language_models import BaseLanguageModel
|
| 4 |
+
from langchain_core.prompts.chat import ChatPromptTemplate
|
| 5 |
+
from langchain_core.runnables import Runnable, RunnablePassthrough
|
| 6 |
+
from langchain_core.tools import BaseTool
|
| 7 |
+
|
| 8 |
+
from agents.format_scratchpad.functions import (
|
| 9 |
+
format_to_function_messages,
|
| 10 |
+
)
|
| 11 |
+
from agents.output_parsers.functions import (
|
| 12 |
+
FunctionsAgentOutputParser,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
def create_functions_agent(
|
| 16 |
+
llm: BaseLanguageModel, prompt: ChatPromptTemplate
|
| 17 |
+
) -> Runnable:
|
| 18 |
+
"""Create an agent that uses function calling.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
llm: LLM to use as the agent. Should work with Nous Hermes function calling,
|
| 22 |
+
so either be an Nous Hermes based model that supports that or a wrapper of
|
| 23 |
+
a different model that adds in equivalent support.
|
| 24 |
+
prompt: The prompt to use. See Prompt section below for more.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
A Runnable sequence representing an agent. It takes as input all the same input
|
| 28 |
+
variables as the prompt passed in does. It returns as output either an
|
| 29 |
+
AgentAction or AgentFinish.
|
| 30 |
+
"""
|
| 31 |
+
if "agent_scratchpad" not in (
|
| 32 |
+
prompt.input_variables + list(prompt.partial_variables)
|
| 33 |
+
):
|
| 34 |
+
raise ValueError(
|
| 35 |
+
"Prompt must have input variable `agent_scratchpad`, but wasn't found."
|
| 36 |
+
f"Found {prompt.input_variables} instead."
|
| 37 |
+
)
|
| 38 |
+
agent = (
|
| 39 |
+
RunnablePassthrough.assign(
|
| 40 |
+
agent_scratchpad=lambda x: format_to_function_messages(
|
| 41 |
+
x["intermediate_steps"]
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
| prompt
|
| 45 |
+
| llm
|
| 46 |
+
| FunctionsAgentOutputParser()
|
| 47 |
+
)
|
| 48 |
+
return agent
|
src/agents/output_parsers/functions.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from json import JSONDecodeError
|
| 3 |
+
from typing import List, Union
|
| 4 |
+
|
| 5 |
+
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
| 6 |
+
from langchain_core.exceptions import OutputParserException
|
| 7 |
+
from langchain_core.messages import (
|
| 8 |
+
AIMessage,
|
| 9 |
+
BaseMessage,
|
| 10 |
+
)
|
| 11 |
+
from langchain_core.outputs import ChatGeneration, Generation
|
| 12 |
+
|
| 13 |
+
from langchain.agents.agent import AgentOutputParser
|
| 14 |
+
from agents.output_parsers.utils import parse_tool_call, check_tool_call
|
| 15 |
+
import ast
|
| 16 |
+
|
| 17 |
+
class FunctionsAgentOutputParser(AgentOutputParser):
|
| 18 |
+
"""Parses a message into agent action/finish.
|
| 19 |
+
|
| 20 |
+
Is meant to be used with a model with Nous Hermes 2 Pro as the base, as it relies on the specific
|
| 21 |
+
function_call parameter from Nous Research to convey what tools to use.
|
| 22 |
+
|
| 23 |
+
If a function_call parameter is passed, then that is used to get
|
| 24 |
+
the tool and tool input.
|
| 25 |
+
|
| 26 |
+
If one is not passed, then the AIMessage is assumed to be the final output.
|
| 27 |
+
It was add a
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def _type(self) -> str:
|
| 32 |
+
return "functions-agent"
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def _parse_ai_message(message: BaseMessage):
|
| 36 |
+
"""Parse an AI message."""
|
| 37 |
+
if not isinstance(message, AIMessage):
|
| 38 |
+
raise TypeError(f"Expected an AI message got {type(message)}")
|
| 39 |
+
|
| 40 |
+
actions = []
|
| 41 |
+
|
| 42 |
+
pattern = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
| 43 |
+
try:
|
| 44 |
+
tool_calls = [parse_tool_call(t.strip()) for t in pattern.findall(message.content)]
|
| 45 |
+
except:
|
| 46 |
+
raise OutputParserException(
|
| 47 |
+
f"Could not parse tool calls from message content: {message.content}. Please ensure that the tool calls are valid JSON."
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if not tool_calls:
|
| 51 |
+
return AgentFinish(
|
| 52 |
+
return_values={"output": message.content}, log=str(message.content)
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
for tool_call in tool_calls:
|
| 56 |
+
tool_name, tool_input = check_tool_call(tool_call)
|
| 57 |
+
content_msg = f"\n{message.content}\n" if message.content else "\n"
|
| 58 |
+
log = f"\nInvoking: `{tool_name}` with `{tool_input}`\n{content_msg}\n"
|
| 59 |
+
actions.append(AgentActionMessageLog(
|
| 60 |
+
tool=tool_name,
|
| 61 |
+
tool_input=tool_input,
|
| 62 |
+
log=log,
|
| 63 |
+
message_log=[message],
|
| 64 |
+
))
|
| 65 |
+
|
| 66 |
+
return actions
|
| 67 |
+
|
| 68 |
+
def parse_result(
|
| 69 |
+
self, result: List[Generation], *, partial: bool = False
|
| 70 |
+
) -> Union[AgentAction, AgentFinish]:
|
| 71 |
+
if not isinstance(result[0], ChatGeneration):
|
| 72 |
+
raise ValueError("This output parser only works on ChatGeneration output")
|
| 73 |
+
message = result[0].message
|
| 74 |
+
return self._parse_ai_message(message)
|
| 75 |
+
|
| 76 |
+
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
| 77 |
+
raise ValueError("Can only parse messages")
|
src/agents/output_parsers/utils.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.utils.function_calling import convert_to_openai_function
|
| 2 |
+
from functions import tools
|
| 3 |
+
import re
|
| 4 |
+
import ast
|
| 5 |
+
|
| 6 |
+
def parse_args(args: str):
|
| 7 |
+
args = args.strip()
|
| 8 |
+
args = args.replace("true", "True")
|
| 9 |
+
args = args.replace("false", "False")
|
| 10 |
+
args = args.replace("null", "None")
|
| 11 |
+
args = args.replace("\"", "\"\"\"")
|
| 12 |
+
i = 0
|
| 13 |
+
while args[i] != "\"" and args[i] != "\'" and i < len(args) - 1:
|
| 14 |
+
i += 1
|
| 15 |
+
args = args[i:]
|
| 16 |
+
if args[-4:] != "True" and args[-5:] != "False":
|
| 17 |
+
i = len(args) - 1
|
| 18 |
+
while args[i] != "\"" and args[i] != "\'" and i > 0:
|
| 19 |
+
i -= 1
|
| 20 |
+
args = args[:i + 1]
|
| 21 |
+
print(args)
|
| 22 |
+
return ast.literal_eval("{" + args + "}")
|
| 23 |
+
|
| 24 |
+
def parse_tool_call(call: str):
|
| 25 |
+
call = call.strip()
|
| 26 |
+
name: bool = "\"name\": " in call or "\'name\':" in call
|
| 27 |
+
args: bool = "\"arguments\": " in call or "\'arguments\':" in call
|
| 28 |
+
if not name:
|
| 29 |
+
print({"arguments": {}, "name": "missing_function_call"})
|
| 30 |
+
return {"arguments": {}, "name": "missing_function_call"}
|
| 31 |
+
if not args:
|
| 32 |
+
pattern = re.compile(r"\"name\": \"(.*?)\"|\'name\': \'(.*?)\'", re.DOTALL)
|
| 33 |
+
match = pattern.findall(call)
|
| 34 |
+
for n in match:
|
| 35 |
+
if isinstance(n, tuple):
|
| 36 |
+
n = n[0]
|
| 37 |
+
print({"arguments": {}, "name": n})
|
| 38 |
+
return {"arguments": {}, "name": n}
|
| 39 |
+
args_pattern = re.compile(r"\"arguments\": {(.*?)}|\'arguments\': {(.*?)}", re.DOTALL)
|
| 40 |
+
args_match = args_pattern.findall(call)
|
| 41 |
+
for a in args_match:
|
| 42 |
+
print(a, "\n")
|
| 43 |
+
print(a[0])
|
| 44 |
+
args = parse_args(a[0])
|
| 45 |
+
name_pattern = re.compile(r"\"name\": \"(.*?)\"", re.DOTALL)
|
| 46 |
+
name_match = name_pattern.findall(call)
|
| 47 |
+
for n in name_match:
|
| 48 |
+
if isinstance(n, tuple):
|
| 49 |
+
n = n[0]
|
| 50 |
+
print({"arguments": args, "name": n})
|
| 51 |
+
return {"arguments": args, "name": n}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def check_tool_call(call: dict):
|
| 55 |
+
global tools
|
| 56 |
+
tools = [convert_to_openai_function(t) for t in tools]
|
| 57 |
+
if call["name"] not in [t["name"] for t in tools]:
|
| 58 |
+
return "handle_tools_error", {"error": {"error": {"name": call["name"]}}}
|
| 59 |
+
tool = next((t for t in tools if t["name"] == call["name"]), None)
|
| 60 |
+
|
| 61 |
+
if set(list(tool["parameters"]["properties"])) != set(list(call["arguments"])):
|
| 62 |
+
print({"tool_response": {"error": {"expected": list(tool["parameters"]["properties"]), "received": list(call["arguments"])}, "name": call["name"]}})
|
| 63 |
+
return "handle_tools_error", {"error": {"error": {"expected": list(tool["parameters"]["properties"]), "received": list(call["arguments"])}, "name": call["name"]}}
|
| 64 |
+
return call["name"], call["arguments"]
|
src/functioncall.py
CHANGED
|
@@ -11,6 +11,7 @@ import functions
|
|
| 11 |
from prompter import PromptManager
|
| 12 |
from validator import validate_function_call_schema
|
| 13 |
from langchain_community.chat_models import ChatOllama
|
|
|
|
| 14 |
from langchain.prompts import PromptTemplate
|
| 15 |
from langchain_core.output_parsers import StrOutputParser
|
| 16 |
|
|
@@ -23,7 +24,7 @@ class ModelInference:
|
|
| 23 |
def __init__(self, chat_template: str):
|
| 24 |
self.prompter = PromptManager()
|
| 25 |
|
| 26 |
-
self.model =
|
| 27 |
template = PromptTemplate(template="""<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": "get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\\n\\n Args:\\n symbol (str): The stock symbol.\\n\\n Returns:\\n dict: A dictionary containing fundamental data.\\n Keys:\\n - \'symbol\': The stock symbol.\\n - \'company_name\': The long name of the company.\\n - \'sector\': The sector to which the company belongs.\\n - \'industry\': The industry to which the company belongs.\\n - \'market_cap\': The market capitalization of the company.\\n - \'pe_ratio\': The forward price-to-earnings ratio.\\n - \'pb_ratio\': The price-to-book ratio.\\n - \'dividend_yield\': The dividend yield.\\n - \'eps\': The trailing earnings per share.\\n - \'beta\': The beta value of the stock.\\n - \'52_week_high\': The 52-week high price of the stock.\\n - \'52_week_low\': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n<tool_call>\n{"arguments": <args-dict>, "name": <function-name>}\n</tool_call><|im_end|>\n""", input_variables=["question"])
|
| 28 |
chain = template | self.model | StrOutputParser()
|
| 29 |
|
|
@@ -69,6 +70,7 @@ class ModelInference:
|
|
| 69 |
add_generation_prompt=True,
|
| 70 |
tokenize=False,
|
| 71 |
)
|
|
|
|
| 72 |
completion = self.model.invoke(inputs, format='json')
|
| 73 |
return completion.content
|
| 74 |
|
|
@@ -84,7 +86,6 @@ class ModelInference:
|
|
| 84 |
|
| 85 |
def recursive_loop(prompt, completion, depth):
|
| 86 |
nonlocal max_depth
|
| 87 |
-
breakpoint()
|
| 88 |
tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
|
| 89 |
prompt.append({"role": "assistant", "content": assistant_message})
|
| 90 |
|
|
|
|
| 11 |
from prompter import PromptManager
|
| 12 |
from validator import validate_function_call_schema
|
| 13 |
from langchain_community.chat_models import ChatOllama
|
| 14 |
+
from langchain_community.llms import Ollama
|
| 15 |
from langchain.prompts import PromptTemplate
|
| 16 |
from langchain_core.output_parsers import StrOutputParser
|
| 17 |
|
|
|
|
| 24 |
def __init__(self, chat_template: str):
|
| 25 |
self.prompter = PromptManager()
|
| 26 |
|
| 27 |
+
self.model = Ollama(model=config.ollama_model, temperature=0.0, format='json')
|
| 28 |
template = PromptTemplate(template="""<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": "get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\\n\\n Args:\\n symbol (str): The stock symbol.\\n\\n Returns:\\n dict: A dictionary containing fundamental data.\\n Keys:\\n - \'symbol\': The stock symbol.\\n - \'company_name\': The long name of the company.\\n - \'sector\': The sector to which the company belongs.\\n - \'industry\': The industry to which the company belongs.\\n - \'market_cap\': The market capitalization of the company.\\n - \'pe_ratio\': The forward price-to-earnings ratio.\\n - \'pb_ratio\': The price-to-book ratio.\\n - \'dividend_yield\': The dividend yield.\\n - \'eps\': The trailing earnings per share.\\n - \'beta\': The beta value of the stock.\\n - \'52_week_high\': The 52-week high price of the stock.\\n - \'52_week_low\': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n<tool_call>\n{"arguments": <args-dict>, "name": <function-name>}\n</tool_call><|im_end|>\n""", input_variables=["question"])
|
| 29 |
chain = template | self.model | StrOutputParser()
|
| 30 |
|
|
|
|
| 70 |
add_generation_prompt=True,
|
| 71 |
tokenize=False,
|
| 72 |
)
|
| 73 |
+
inputs = inputs.replace("<|begin_of_text|>", "") # Something wrong with the chat template, hotfix
|
| 74 |
completion = self.model.invoke(inputs, format='json')
|
| 75 |
return completion.content
|
| 76 |
|
|
|
|
| 86 |
|
| 87 |
def recursive_loop(prompt, completion, depth):
|
| 88 |
nonlocal max_depth
|
|
|
|
| 89 |
tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
|
| 90 |
prompt.append({"role": "assistant", "content": assistant_message})
|
| 91 |
|
src/functions.py
CHANGED
|
@@ -11,7 +11,7 @@ from bs4 import BeautifulSoup
|
|
| 11 |
from logger import logger
|
| 12 |
from openai import AzureOpenAI
|
| 13 |
from langchain.tools import tool
|
| 14 |
-
from langchain_core.utils.function_calling import convert_to_openai_tool
|
| 15 |
from config import config
|
| 16 |
|
| 17 |
from azure.core.credentials import AzureKeyCredential
|
|
@@ -281,20 +281,24 @@ def get_company_profile(symbol: str) -> dict:
|
|
| 281 |
print(f"Error fetching company profile for {symbol}: {e}")
|
| 282 |
return {}
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
def get_openai_tools() -> List[dict]:
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
# get_stock_fundamentals,
|
| 292 |
-
# get_financial_statements,
|
| 293 |
-
get_key_financial_ratios,
|
| 294 |
-
# get_analyst_recommendations,
|
| 295 |
-
# get_dividend_data,
|
| 296 |
-
# get_technical_indicators
|
| 297 |
-
]
|
| 298 |
-
|
| 299 |
-
tools = [convert_to_openai_tool(f) for f in functions]
|
| 300 |
-
return tools
|
|
|
|
| 11 |
from logger import logger
|
| 12 |
from openai import AzureOpenAI
|
| 13 |
from langchain.tools import tool
|
| 14 |
+
from langchain_core.utils.function_calling import convert_to_openai_tool, convert_to_openai_function
|
| 15 |
from config import config
|
| 16 |
|
| 17 |
from azure.core.credentials import AzureKeyCredential
|
|
|
|
| 281 |
print(f"Error fetching company profile for {symbol}: {e}")
|
| 282 |
return {}
|
| 283 |
|
| 284 |
+
tools = [
|
| 285 |
+
get_analysis,
|
| 286 |
+
# google_search_and_scrape,
|
| 287 |
+
get_current_stock_price,
|
| 288 |
+
get_company_news,
|
| 289 |
+
# get_company_profile,
|
| 290 |
+
# get_stock_fundamentals,
|
| 291 |
+
# get_financial_statements,
|
| 292 |
+
get_key_financial_ratios,
|
| 293 |
+
# get_analyst_recommendations,
|
| 294 |
+
# get_dividend_data,
|
| 295 |
+
# get_technical_indicators
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
def get_openai_tools() -> List[dict]:
|
| 299 |
+
tools_ = [convert_to_openai_tool(f) for f in tools]
|
| 300 |
+
return tools_
|
| 301 |
+
|
| 302 |
+
def get_openai_functions() -> List[str]:
|
| 303 |
+
functions = [convert_to_openai_function(f) for f in tools]
|
| 304 |
+
return functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/prompts/prompt.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 3 |
+
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
| 4 |
+
import yaml
|
| 5 |
+
|
| 6 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 7 |
+
with open(f"{current_dir}/rag_template.yaml", "r") as yaml_file:
|
| 8 |
+
templates = yaml.safe_load(yaml_file)
|
| 9 |
+
|
| 10 |
+
# RAG Agent
|
| 11 |
+
sys_msg_template: str = templates["sys_msg"]
|
| 12 |
+
human_msg_template: str = templates["human_msg"]
|
| 13 |
+
rag_agent_prompt = ChatPromptTemplate.from_messages([
|
| 14 |
+
SystemMessagePromptTemplate.from_template(sys_msg_template),
|
| 15 |
+
HumanMessagePromptTemplate.from_template(human_msg_template),
|
| 16 |
+
MessagesPlaceholder(variable_name = "agent_scratchpad")
|
| 17 |
+
])
|
src/prompts/rag_template.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sys_msg: "
|
| 2 |
+
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
|
| 3 |
+
<tools>
|
| 4 |
+
{tools}
|
| 5 |
+
</tools>
|
| 6 |
+
Use the following pydantic model json schema for each tool call you will make: {{\"properties\": {{\"arguments\": {{\"title\": \"Arguments\", \"type\": \"object\"}}, \"name\": {{\"title\": \"Name\", \"type\": \"string\"}}}}, \"required\": [\"arguments\", \"name\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}
|
| 7 |
+
For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
| 8 |
+
<tool_call>
|
| 9 |
+
{{\"arguments\": <args-dict>, \"name\": <function-name>}}
|
| 10 |
+
</tool_call>"
|
| 11 |
+
human_msg: "
|
| 12 |
+
{input}"
|