Update tools/tool_agent.py
Browse files- tools/tool_agent.py +27 -13
tools/tool_agent.py
CHANGED
|
@@ -10,22 +10,36 @@ class ToolCallingAgent:
|
|
| 10 |
)
|
| 11 |
|
| 12 |
def generate(self, prompt, tools):
|
| 13 |
-
# Format the tools specification
|
| 14 |
tools_json = json.dumps(tools, ensure_ascii=False)
|
| 15 |
-
|
| 16 |
-
# Create the tool-calling prompt
|
| 17 |
system_msg = f"""You are an AI assistant that can call tools.
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
#
|
|
|
|
|
|
|
| 22 |
response = self.model(
|
| 23 |
-
|
| 24 |
max_new_tokens=200,
|
| 25 |
-
do_sample=
|
| 26 |
)
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
try:
|
| 29 |
-
return json.loads(
|
| 30 |
-
except:
|
| 31 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
)
|
| 11 |
|
| 12 |
def generate(self, prompt, tools):
|
|
|
|
| 13 |
tools_json = json.dumps(tools, ensure_ascii=False)
|
|
|
|
|
|
|
| 14 |
system_msg = f"""You are an AI assistant that can call tools.
|
| 15 |
+
Available tools: {tools_json}
|
| 16 |
+
Respond ONLY with a valid JSON containing keys 'tool_name' and 'parameters'."""
|
| 17 |
+
|
| 18 |
+
# Construct prompt with system and user tokens (assuming model supports these)
|
| 19 |
+
full_prompt = f"<|system|>{system_msg}</s><|user|>{prompt}</s>"
|
| 20 |
+
|
| 21 |
response = self.model(
|
| 22 |
+
full_prompt,
|
| 23 |
max_new_tokens=200,
|
| 24 |
+
do_sample=False # deterministic output for better JSON consistency
|
| 25 |
)
|
| 26 |
+
|
| 27 |
+
text = response[0]['generated_text']
|
| 28 |
+
|
| 29 |
+
# Extract JSON substring between first '{' and last '}'
|
| 30 |
+
json_start = text.find("{")
|
| 31 |
+
json_end = text.rfind("}") + 1
|
| 32 |
+
if json_start == -1 or json_end == -1:
|
| 33 |
+
return {"error": "No JSON found in model output", "raw_output": text}
|
| 34 |
+
|
| 35 |
+
json_text = text[json_start:json_end]
|
| 36 |
+
|
| 37 |
try:
|
| 38 |
+
return json.loads(json_text)
|
| 39 |
+
except json.JSONDecodeError as e:
|
| 40 |
+
return {
|
| 41 |
+
"error": "Failed to parse JSON",
|
| 42 |
+
"message": str(e),
|
| 43 |
+
"raw_output": text,
|
| 44 |
+
"extracted_json": json_text
|
| 45 |
+
}
|