import os
from typing import TypedDict, Annotated, Optional, Any

from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition

from search_tools import vector_store, arvix_search, question_search, web_search, wiki_search
from math_tools import add, subtract, modulus, multiply, divide

hf_token = os.environ.get("HF_TOKEN")
together_token = os.environ.get("TOGETHER_API_TOKEN")
tools = [
    add,
    arvix_search,
    divide,
    modulus,
    multiply,
    question_search,
    subtract,
    web_search,
    wiki_search,
]
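
# Each entry above is expected to be a LangChain tool so ToolNode can call it.
# A minimal sketch of what one of the math tools presumably looks like; the
# real definitions live in math_tools.py, and `example_add` is a hypothetical
# name used only for illustration:
from langchain_core.tools import tool

@tool
def example_add(a: float, b: float) -> float:
    """Add two numbers and return the sum."""
    return a + b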
def build_system_prompt():
    """Read the system prompt from disk."""
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        sys_msg = f.read()
    print("System Prompt: " + sys_msg)
    return sys_msg
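
# build_system_prompt() is not invoked in this module; presumably the host
# app uses it to seed the conversation. A hypothetical example:
#
#     from langchain_core.messages import SystemMessage
#     initial_messages = [SystemMessage(content=build_system_prompt())]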
class Bariq:
    def __init__(self):
        base_llm = ChatOpenAI(
            model="ServiceNow-AI/Apriel-1.6-15b-Thinker",
            openai_api_key=together_token,
            openai_api_base="https://api.together.xyz/v1",
            max_tokens=1024,
            temperature=0.1,
            top_p=0.9,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["</s>", "User:", "Assistant:", "Final answer:"],
        )
        # Tool binding is intentionally disabled for now. Without it the model
        # cannot emit tool calls, so tools_condition below always routes to
        # END. Re-enable with:
        # self.llm = base_llm.bind(tools=tools)
        self.llm = base_llm
class AgentState(TypedDict):
    # messages accumulate via the add_messages reducer; previous_messages is
    # replaced wholesale on each update, so it takes no reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    previous_messages: Optional[list[AnyMessage]]
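
# Illustrative shape of the state this graph operates on (values are examples):
#
#     {
#         "messages": [HumanMessage(content="What is 2 + 2?")],
#         "previous_messages": None,
#     }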
agent = Bariq()

# ------------------------
# Nodes
# ------------------------
# General Assistant Node
def assistant(state: AgentState) -> AgentState:
    """Assistant Node: replay the conversation to the LLM and append its reply."""
    try:
        prompts = []
        # Replay the full conversation, including any retrieved context the
        # retriever node appended, as OpenAI-style chat messages.
        for m in state["messages"]:
            prompt_content = getattr(m, "content", None)
            message_type = getattr(m, "type", None)
            print("prompt_content " + str(prompt_content))
            # Map LangChain message types onto OpenAI chat roles.
            if message_type == "human":
                role = "user"
            elif message_type == "ai":
                role = "assistant"
            else:
                role = "system"
            prompts.append({"role": role, "content": prompt_content or ""})
        # print("Assistant Node Prompts: " + str(prompts))
        llm_response = agent.llm.invoke(prompts)
        # Wrap the reply as an AIMessage so tools_condition can inspect it
        # for tool calls.
        reply = AIMessage(content=llm_response.content)
        print("reply " + str(reply))
        return {
            "previous_messages": state.get("previous_messages"),
            "messages": state["messages"] + [reply],
        }
    except Exception as e:
        print("Exception in Assistant Node:", e)
        return state
# Retrieval Node
def retriever(state: AgentState) -> AgentState:
    """Retriever Node: append the most similar stored document as context."""
    try:
        user_messages = [m for m in state["messages"] if getattr(m, "type", None) == "human"]
        if not user_messages:
            return state
        query = getattr(user_messages[-1], "content", None)
        # Vector search over the two nearest documents.
        similar_docs = vector_store.similarity_search(query, k=2)
        print("similar documents:", str(similar_docs))
        if similar_docs:
            context = similar_docs[0].page_content
            new_messages = state["messages"] + [HumanMessage(content=context)]
            return {
                "previous_messages": state["messages"],
                "messages": new_messages,
            }
        else:
            return {
                "previous_messages": state["messages"] if state.get("previous_messages") is None else state["previous_messages"],
                "messages": state["messages"],
            }
    except Exception as e:
        print("Exception in retriever node:", e)
        return state
# ------------------------
# Workflow
# ------------------------
def build_workflow() -> Any:
    graph = StateGraph(AgentState)

    # Add nodes
    graph.add_node("retriever", retriever)
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(tools))

    # Add edges
    graph.add_edge(START, "retriever")
    graph.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last AIMessage carries tool
    # calls and to END otherwise, so no separate assistant -> END edge is
    # needed (adding one would create a parallel branch to END).
    graph.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    graph.add_edge("tools", "assistant")
    return graph.compile()
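
# Minimal smoke test, assuming system_prompt.txt, search_tools, math_tools,
# and TOGETHER_API_TOKEN are all available; the question and printout below
# are illustrative only, not part of the original app.
if __name__ == "__main__":
    workflow = build_workflow()
    result = workflow.invoke({
        "messages": [HumanMessage(content="What is 15 modulo 4?")],
        "previous_messages": None,
    })
    for message in result["messages"]:
        print(f"{message.type}: {message.content}")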