Davit6174 commited on
Commit
b35d9f3
·
verified ·
1 Parent(s): edc75dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -6
app.py CHANGED
@@ -37,10 +37,19 @@ class ZephyrAPI:
37
  self.headers = {
38
  "Authorization": f"Bearer {os.getenv('HF_TOKEN')}"
39
  }
 
40
  print("ZephyrAPI initialized using Inference API.")
41
 
 
 
 
42
  def __call__(self, question: str) -> str:
43
- prompt = f"<|system|>\nYou are a helpful assistant.\n<|user|>\n{question}\n<|assistant|>\n"
 
 
 
 
 
44
  payload = {
45
  "inputs": prompt,
46
  "parameters": {
@@ -61,8 +70,9 @@ class ZephyrAPI:
61
 
62
 
63
  class LangGraphAgent:
64
- def __init__(self):
65
- self.model = ZephyrAPI()
 
66
 
67
  builder = StateGraph(dict)
68
 
@@ -72,8 +82,29 @@ class LangGraphAgent:
72
  if not user_msg:
73
  return {"messages": messages + [AIMessage(content="❌ No user input found.")]}
74
 
75
- response = self.model(user_msg.content)
76
- return {"messages": messages + [AIMessage(content=response)]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  builder.add_node("chat", call_model)
79
  builder.set_entry_point("chat")
@@ -254,7 +285,7 @@ with gr.Blocks() as demo:
254
 
255
  def test_agent_response(question: str) -> str:
256
  # agent = BasicAgent()
257
- agent = LangGraphAgent()
258
  print(agent("What's the capital of France?"))
259
  return agent(question)
260
 
 
37
  self.headers = {
38
  "Authorization": f"Bearer {os.getenv('HF_TOKEN')}"
39
  }
40
+ self.tool_descriptions = self.format_tools(tools or [])
41
  print("ZephyrAPI initialized using Inference API.")
42
 
43
+ def format_tools(self, tools):
44
+ return "\n".join([f"- {tool.name}: {tool.description}" for tool in tools])
45
+
46
  def __call__(self, question: str) -> str:
47
+ prompt = f"<|system|>\n"
48
+ f"You are a helpful AI agent. You can use the following tools when needed:\n"
49
+ f"{self.tool_descriptions}\n"
50
+ f"\nRespond to the user. If a tool is needed, use this format:\n"
51
+ f"Action: tool_name\nAction Input: input_for_tool\n"
52
+ f"\n<|user|>\n{question}\n<|assistant|>\n"
53
  payload = {
54
  "inputs": prompt,
55
  "parameters": {
 
70
 
71
 
72
  class LangGraphAgent:
73
+ def __init__(self, tools=None):
74
+ self.tools = {tool.name: tool for tool in tools} if tools else {}
75
+ self.model = ZephyrAPI(tools=tools)
76
 
77
  builder = StateGraph(dict)
78
 
 
82
  if not user_msg:
83
  return {"messages": messages + [AIMessage(content="❌ No user input found.")]}
84
 
85
+ content = user_msg.content.strip()
86
+ raw_response = self.model(content)
87
+
88
+ # Check if model issued a tool call
89
+ match = re.search(r"Action:\s*(\w+)\s*Action Input:\s*(.+)", raw_response, re.IGNORECASE)
90
+ if match:
91
+ tool_name, tool_input = match.groups()
92
+ tool_fn = self.tools.get(tool_name)
93
+ if tool_fn:
94
+ try:
95
+ tool_output = tool_fn(tool_input.strip('"'))
96
+ follow_up = self.model(f"User asked: {content}\nTool [{tool_name}] returned: {tool_output}")
97
+ return {"messages": messages + [
98
+ AIMessage(content=raw_response),
99
+ AIMessage(content=tool_output),
100
+ AIMessage(content=follow_up),
101
+ ]}
102
+ except Exception as e:
103
+ return {"messages": messages + [AIMessage(content=f"⚠️ Tool error: {e}")]}
104
+ else:
105
+ return {"messages": messages + [AIMessage(content=f"⚠️ Unknown tool: {tool_name}")]}
106
+
107
+ return {"messages": messages + [AIMessage(content=raw_response)]}
108
 
109
  builder.add_node("chat", call_model)
110
  builder.set_entry_point("chat")
 
285
 
286
def test_agent_response(question: str) -> str:
    """Answer *question* with a LangGraphAgent.

    Builds the agent with the module-level ``tools``, first prints the
    agent's reply to a fixed smoke-test question, then returns its reply
    to the caller's *question*.
    """
    # agent = BasicAgent()
    graph_agent = LangGraphAgent(tools=tools)
    # Smoke test: exercise the agent once on a known question before the real call.
    print(graph_agent("What's the capital of France?"))
    return graph_agent(question)
291