Update tools/tool_agent.py
Browse files- tools/tool_agent.py +14 -10
tools/tool_agent.py
CHANGED
|
@@ -3,25 +3,29 @@ import json
|
|
| 3 |
|
| 4 |
class ToolCallingAgent:
|
| 5 |
def __init__(self):
|
| 6 |
-
#
|
| 7 |
self.model = pipeline(
|
| 8 |
"text-generation",
|
| 9 |
-
model="
|
| 10 |
-
device=-1,
|
| 11 |
-
torch_dtype="float32"
|
| 12 |
)
|
| 13 |
|
| 14 |
def generate(self, prompt, tools):
|
| 15 |
try:
|
| 16 |
tools_json = json.dumps(tools, ensure_ascii=False)
|
| 17 |
-
prompt = f"""Respond with JSON for one tool call
|
| 18 |
|
| 19 |
response = self.model(
|
| 20 |
prompt,
|
| 21 |
-
max_new_tokens=
|
| 22 |
-
do_sample=False
|
| 23 |
)
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class ToolCallingAgent:
|
| 5 |
def __init__(self):
|
| 6 |
+
# Small CPU-friendly model
|
| 7 |
self.model = pipeline(
|
| 8 |
"text-generation",
|
| 9 |
+
model="gpt2", # Replace with small model you want
|
| 10 |
+
device=-1,
|
| 11 |
+
torch_dtype="float32"
|
| 12 |
)
|
| 13 |
|
| 14 |
def generate(self, prompt, tools):
|
| 15 |
try:
|
| 16 |
tools_json = json.dumps(tools, ensure_ascii=False)
|
| 17 |
+
prompt = f"""Respond ONLY with JSON for one tool call from the following list: {tools_json}\nUser input: {prompt}"""
|
| 18 |
|
| 19 |
response = self.model(
|
| 20 |
prompt,
|
| 21 |
+
max_new_tokens=100,
|
| 22 |
+
do_sample=False
|
| 23 |
)
|
| 24 |
|
| 25 |
+
# Try to find JSON in output
|
| 26 |
+
text = response[0]['generated_text']
|
| 27 |
+
json_start = text.find("{")
|
| 28 |
+
json_end = text.rfind("}") + 1
|
| 29 |
+
return json.loads(text[json_start:json_end])
|
| 30 |
+
except Exception as e:
|
| 31 |
+
return {"tool_name": "error", "parameters": {"message": f"Failed to process request: {str(e)}"}}
|