Trouter-Library committed on
Commit
a2975bd
·
verified ·
1 Parent(s): b6202d7

Create tools_system.py

Browse files
Files changed (1) hide show
  1. tools_system.py +466 -0
tools_system.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helion-V1.5 Tools and Function Calling System
3
+ Enables structured function calls and tool integration
4
+ """
5
+
6
+ import json
7
+ import logging
8
+ from typing import List, Dict, Any, Callable, Optional
9
+ from dataclasses import dataclass, asdict
10
+ from enum import Enum
11
+
12
# Module-level logging setup: configure the root logger at INFO and create
# this module's logger.
# NOTE(review): calling basicConfig() at import time also affects the host
# application's root logger; kept as-is because callers may rely on it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
14
+
15
+
16
class ToolType(Enum):
    """Kinds of tools that can be registered.

    The member values are the plain strings used when a tool type is
    printed or serialized.
    """

    FUNCTION = "function"
    API = "api"
    RETRIEVAL = "retrieval"
    CODE_INTERPRETER = "code_interpreter"
22
+
23
+
24
@dataclass
class ToolParameter:
    """Specification of a single parameter accepted by a tool.

    Attributes:
        name: Parameter name, used as the keyword when the tool is called.
        type: Type name written into the generated schema (e.g. "string").
        description: Human-readable explanation of the parameter.
        required: Whether the parameter must be supplied (default True).
        enum: Optional list of allowed string values.
        default: Optional default value for the parameter.
    """

    name: str
    type: str
    description: str
    required: bool = True
    enum: Optional[List[str]] = None
    default: Optional[Any] = None
33
+
34
+
35
@dataclass
class Tool:
    """Tool/function definition that can be advertised to the model.

    Attributes:
        name: Unique tool name; also the key used by the registry.
        description: Human-readable summary of what the tool does.
        parameters: Ordered parameter specifications.
        tool_type: Category of the tool (defaults to ``ToolType.FUNCTION``).
        implementation: Callable invoked with the parameters as keyword
            arguments, or ``None`` when the tool is declared but has no
            local implementation.
    """

    name: str
    description: str
    parameters: List[ToolParameter]
    tool_type: ToolType = ToolType.FUNCTION
    implementation: Optional[Callable] = None

    def to_openai_format(self) -> Dict:
        """Convert the tool to the OpenAI function-calling schema.

        Returns:
            A dict of the form ``{"type": "function", "function": {...}}``
            whose ``parameters`` entry is an object schema listing every
            parameter under ``properties`` and the required ones under
            ``required``.
        """
        properties: Dict[str, Dict[str, Any]] = {}
        required: List[str] = []

        for param in self.parameters:
            prop: Dict[str, Any] = {
                "type": param.type,
                "description": param.description,
            }
            if param.enum:
                prop["enum"] = param.enum
            # Surface declared defaults so the model knows what happens when
            # an optional parameter is omitted (previously silently dropped).
            if param.default is not None:
                prop["default"] = param.default

            properties[param.name] = prop

            if param.required:
                required.append(param.name)

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
74
+
75
+
76
class ToolRegistry:
    """Registry for managing available tools.

    A new registry registers the built-in ``web_search``, ``calculator``
    and ``execute_python`` tools on construction.
    """

    # Largest exponent magnitude the calculator evaluates. Without a bound an
    # input such as ``9**9**9`` keeps the interpreter busy indefinitely.
    _MAX_EXPONENT = 10_000

    def __init__(self):
        self.tools: Dict[str, "Tool"] = {}
        self._register_default_tools()

    def register(self, tool: "Tool") -> None:
        """Register a tool under its name, replacing any existing entry."""
        self.tools[tool.name] = tool
        logger.info("Registered tool: %s", tool.name)

    def get(self, name: str) -> Optional["Tool"]:
        """Return the tool registered as ``name``, or ``None`` if unknown."""
        return self.tools.get(name)

    def list_tools(self) -> List[str]:
        """Return the names of all registered tools."""
        return list(self.tools)

    def get_tools_schema(self) -> List[Dict]:
        """Return every registered tool in OpenAI schema format."""
        return [tool.to_openai_format() for tool in self.tools.values()]

    def _register_default_tools(self):
        """Register the default built-in tools."""

        # Web search tool: declared without an implementation.
        search_tool = Tool(
            name="web_search",
            description="Search the web for current information",
            parameters=[
                ToolParameter(
                    name="query",
                    type="string",
                    description="Search query",
                ),
                ToolParameter(
                    name="num_results",
                    type="integer",
                    description="Number of results to return",
                    required=False,
                    default=5,
                ),
            ],
            tool_type=ToolType.API,
        )
        self.register(search_tool)

        # Calculator tool
        calc_tool = Tool(
            name="calculator",
            description="Perform mathematical calculations",
            parameters=[
                ToolParameter(
                    name="expression",
                    type="string",
                    description="Mathematical expression to evaluate",
                ),
            ],
            tool_type=ToolType.FUNCTION,
            implementation=self._calculator_impl,
        )
        self.register(calc_tool)

        # Code execution tool
        code_tool = Tool(
            name="execute_python",
            description="Execute Python code and return the result",
            parameters=[
                ToolParameter(
                    name="code",
                    type="string",
                    description="Python code to execute",
                ),
            ],
            tool_type=ToolType.CODE_INTERPRETER,
            implementation=self._python_executor,
        )
        self.register(code_tool)

    def _calculator_impl(self, expression: str) -> Dict:
        """Evaluate an arithmetic expression without using ``eval``.

        The expression is parsed with ``ast`` and only numeric constants,
        the binary operators ``+ - * / // % **`` and unary ``+``/``-`` are
        evaluated; any other syntax is rejected.

        Args:
            expression: Arithmetic expression, e.g. ``"2 + 3 * 4"``.

        Returns:
            ``{"result": value, "success": True}`` on success, otherwise
            ``{"error": message, "success": False}``. Never raises.
        """
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {
            ast.USub: operator.neg,
            ast.UAdd: operator.pos,
        }
        max_exponent = self._MAX_EXPONENT

        def eval_expr(node):
            # ast.Constant replaces the deprecated ast.Num node (Python 3.8+).
            if isinstance(node, ast.Constant):
                value = node.value
                # bool is an int subclass; reject it so "True + 1" is an error.
                if isinstance(value, bool) or not isinstance(
                    value, (int, float, complex)
                ):
                    raise ValueError(f"Unsupported constant: {value!r}")
                return value
            if isinstance(node, ast.BinOp):
                op = binary_ops.get(type(node.op))
                if op is None:
                    raise ValueError(
                        f"Unsupported operator: {type(node.op).__name__}"
                    )
                left = eval_expr(node.left)
                right = eval_expr(node.right)
                if (
                    op is operator.pow
                    and isinstance(right, (int, float))
                    and abs(right) > max_exponent
                ):
                    raise ValueError(
                        f"Exponent too large (limit {max_exponent})"
                    )
                return op(left, right)
            if isinstance(node, ast.UnaryOp):
                op = unary_ops.get(type(node.op))
                if op is None:
                    raise ValueError(
                        f"Unsupported operator: {type(node.op).__name__}"
                    )
                return op(eval_expr(node.operand))
            raise ValueError(
                f"Unsupported expression element: {type(node).__name__}"
            )

        # Tool boundary: every failure (syntax error, division by zero,
        # unsupported syntax, ...) is reported as an error result rather
        # than propagated to the caller.
        try:
            result = eval_expr(ast.parse(expression, mode="eval").body)
        except Exception as e:
            return {"error": str(e), "success": False}
        return {"result": result, "success": True}

    def _python_executor(self, code: str) -> Dict:
        """Execute Python code with a reduced set of builtins.

        NOTE(security): this is NOT a real sandbox. Replacing
        ``__builtins__`` removes ``import``/``open`` but does not stop
        attribute-walking escapes (for example via
        ``().__class__.__base__.__subclasses__()``), and there is no CPU,
        memory or time limit. Do not expose this to untrusted input without
        process-level isolation.

        Args:
            code: Python source to execute.

        Returns:
            ``{"output": text, "success": True}`` where ``text`` is every
            ``print`` call's output joined by newlines, or
            ``{"error": message, "success": False}`` if execution raised.
        """
        captured: List[str] = []

        def capture_print(*args, sep=" ", **_ignored):
            # Collect output instead of writing to stdout. ``sep`` is
            # honoured; other print() keywords are accepted and ignored.
            separator = " " if sep is None else sep
            captured.append(separator.join(str(arg) for arg in args))

        restricted_builtins = {
            "print": capture_print,
            "range": range,
            "len": len,
            "str": str,
            "int": int,
            "float": float,
            "list": list,
            "dict": dict,
            "sum": sum,
            "max": max,
            "min": min,
        }

        # A single namespace serves as both globals and locals. With separate
        # dicts, top-level names land in the locals dict and are invisible to
        # functions defined by the snippet, which look names up in globals.
        namespace = {"__builtins__": restricted_builtins}

        try:
            exec(code, namespace)
        except Exception as e:
            return {"error": str(e), "success": False}

        return {
            "output": "\n".join(captured),
            "success": True,
        }
226
+
227
+
228
class FunctionCallParser:
    """Parse function calls from model output."""

    @staticmethod
    def extract_function_calls(text: str) -> List[Dict]:
        """Extract function calls from model output.

        Every ``{`` in ``text`` is tried as the start of a JSON object. An
        object counts as a function call when it has a string ``"function"``
        and a dict ``"parameters"``. Decoding real JSON rather than matching
        a regex means empty parameters, nested objects, braces inside string
        values and pretty-printed JSON are all handled. Where strict JSON
        decoding does not yield a call, the original regex is tried at the
        same position so loosely quoted calls (single-quoted keys with
        JSON-valid parameters) are still recognised.

        Args:
            text: Model output text.

        Returns:
            List of ``{"function": name, "parameters": dict}`` dicts in the
            order they appear in ``text``.
        """
        import re

        # Pattern: {"function": "name", "parameters": {...}}
        legacy_pattern = re.compile(
            r'\{["\']function["\']\s*:\s*["\']([^"\']+)["\']\s*,\s*["\']parameters["\']\s*:\s*(\{[^}]+\})\}'
        )
        decoder = json.JSONDecoder()
        function_calls: List[Dict] = []

        pos = text.find("{")
        while pos != -1:
            # By default move on one character so objects nested inside a
            # non-call wrapper are still examined.
            next_pos = pos + 1

            try:
                candidate, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                candidate, end = None, pos

            if (
                isinstance(candidate, dict)
                and isinstance(candidate.get("function"), str)
                and isinstance(candidate.get("parameters"), dict)
            ):
                function_calls.append({
                    "function": candidate["function"],
                    "parameters": candidate["parameters"],
                })
                next_pos = end
            else:
                legacy = legacy_pattern.match(text, pos)
                if legacy:
                    try:
                        params = json.loads(legacy.group(2))
                    except json.JSONDecodeError:
                        params = None
                    if params is not None:
                        function_calls.append({
                            "function": legacy.group(1),
                            "parameters": params,
                        })
                        next_pos = legacy.end()

            pos = text.find("{", next_pos)

        return function_calls

    @staticmethod
    def format_function_result(
        function_name: str,
        result: Dict
    ) -> str:
        """Format a function result as text to feed back to the model.

        Args:
            function_name: Name of the tool that was called.
            result: Result returned by the tool.

        Returns:
            ``"\\n[Function <name> returned: <json>]\\n"``.
        """
        # default=str: custom tools may return values json cannot encode
        # (datetimes, paths, ...); stringify them instead of raising and
        # losing the whole generation.
        return f"\n[Function {function_name} returned: {json.dumps(result, default=str)}]\n"
274
+
275
+
276
class HelionToolSystem:
    """
    Complete tool system for Helion-V1.5.
    Manages tool registration, execution, and integration.
    """

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer pair and create a registry and parser.

        Args:
            model: Generation model; ``generate`` and ``device`` are used.
            tokenizer: Tokenizer; ``apply_chat_template``, ``decode``,
                ``pad_token_id`` and ``eos_token_id`` are used.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.registry = ToolRegistry()
        self.parser = FunctionCallParser()

    def add_tool(self, tool: Tool):
        """Add a custom tool."""
        self.registry.register(tool)

    def generate_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[str]] = None,
        max_iterations: int = 5,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate response with tool calling capability.

        Each iteration generates a reply, extracts any function calls from
        it, executes them and appends the results to the conversation. The
        loop stops at the first reply that contains no function call.

        Args:
            messages: Chat messages
            tools: List of tool names to make available (None = all).
                Names not present in the registry are ignored. Only the
                listed tools can be executed.
            max_iterations: Max tool calling iterations
            **kwargs: Generation parameters. ``max_new_tokens``,
                ``temperature``, ``top_p`` and ``do_sample`` are honoured.

        Returns:
            Dict with ``response`` (final text, or "Max iterations reached"),
            ``tool_calls`` (list of executed calls with their results) and
            ``iterations``.
        """
        import torch

        # "is None" rather than truthiness: an explicit empty list means
        # "no tools", not "all tools".
        if tools is None:
            requested = self.registry.list_tools()
        else:
            requested = tools
        available_tools = [
            name for name in requested if self.registry.get(name)
        ]
        allowed = set(available_tools)
        tool_schemas = [
            self.registry.get(name).to_openai_format()
            for name in available_tools
        ]

        # Add tools to system message
        system_msg = {
            "role": "system",
            "content": (
                "You have access to the following tools:\n\n"
                f"{json.dumps(tool_schemas, indent=2)}\n\n"
                "To use a tool, output JSON in this format:\n"
                '{"function": "tool_name", "parameters": {"param": "value"}}\n\n'
                "After receiving tool results, continue your response."
            ),
        }

        messages_with_tools = [system_msg] + messages
        tool_calls = []

        # Some tokenizers define no pad token; fall back to EOS in that case.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        for iteration in range(max_iterations):
            # Generate
            input_ids = self.tokenizer.apply_chat_template(
                messages_with_tools,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)

            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    max_new_tokens=kwargs.get('max_new_tokens', 512),
                    temperature=kwargs.get('temperature', 0.7),
                    top_p=kwargs.get('top_p', 0.9),
                    do_sample=kwargs.get('do_sample', True),
                    pad_token_id=pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens, not the prompt.
            response = self.tokenizer.decode(
                output[0][input_ids.shape[1]:],
                skip_special_tokens=True
            )

            # Check for function calls
            calls = self.parser.extract_function_calls(response)

            if not calls:
                # No more function calls, return final response
                return {
                    "response": response.strip(),
                    "tool_calls": tool_calls,
                    "iterations": iteration + 1
                }

            # Record the assistant turn once, even when it holds several
            # calls; appending it per call duplicated it in the context.
            messages_with_tools.append({
                "role": "assistant",
                "content": response
            })

            # Execute function calls and add each result to the conversation.
            # NOTE(review): results are sent with role "system"; confirm the
            # model's chat template accepts system messages mid-conversation.
            for call in calls:
                record = self._execute_call(call, allowed)
                tool_calls.append(record)
                messages_with_tools.append({
                    "role": "system",
                    "content": self.parser.format_function_result(
                        record["function"], record["result"]
                    )
                })

        return {
            "response": "Max iterations reached",
            "tool_calls": tool_calls,
            "iterations": max_iterations
        }

    def _execute_call(self, call: Dict, allowed_tools) -> Dict[str, Any]:
        """Run one parsed function call and return its execution record.

        Args:
            call: Dict with ``"function"`` and ``"parameters"`` keys.
            allowed_tools: Container of tool names that may be executed.

        Returns:
            ``{"function": name, "parameters": params, "result": result}``.
            Unknown, disallowed or failing tools produce an error result
            instead of raising.
        """
        func_name = call["function"]
        params = call["parameters"]

        tool = self.registry.get(func_name) if func_name in allowed_tools else None
        if tool is None or not tool.implementation:
            result = {"error": f"Tool {func_name} not found or not executable"}
        else:
            # Tool boundary: a bad argument list or a bug in a tool must not
            # abort the whole generation; report it back to the model.
            try:
                result = tool.implementation(**params)
            except Exception as exc:
                logger.exception("Tool %s raised an exception", func_name)
                result = {
                    "error": f"Tool {func_name} failed: {exc}",
                    "success": False
                }

        return {
            "function": func_name,
            "parameters": params,
            "result": result
        }
404
+
405
+
406
# Example custom tool
def create_weather_tool() -> Tool:
    """Example: create a weather lookup tool backed by a mock function.

    Returns:
        A ``Tool`` named ``get_weather`` with a required ``location``
        parameter and an optional ``units`` parameter.
    """

    def get_weather(location: str, units: str = "celsius") -> Dict:
        """Mock weather implementation returning fixed conditions.

        The temperature is 22 for ``"celsius"`` and 72 for any other
        ``units`` value.
        """
        return {
            "location": location,
            "temperature": 22 if units == "celsius" else 72,
            "conditions": "Partly cloudy",
            "units": units,
        }

    return Tool(
        name="get_weather",
        description="Get current weather for a location",
        parameters=[
            ToolParameter(
                name="location",
                type="string",
                description="City name or location",
            ),
            ToolParameter(
                name="units",
                type="string",
                description="Temperature units",
                required=False,
                enum=["celsius", "fahrenheit"],
                default="celsius",
            ),
        ],
        tool_type=ToolType.API,
        implementation=get_weather,
    )
440
+
441
+
442
def main() -> None:
    """Demo: list the registered tools, then run the calculator and executor."""
    registry = ToolRegistry()

    print("Registered Tools:")
    print("="*60)
    for tool_name in registry.list_tools():
        tool = registry.get(tool_name)
        print(f"\n{tool.name}:")
        print(f" Description: {tool.description}")
        print(f" Type: {tool.tool_type.value}")
        print(f" Parameters: {[p.name for p in tool.parameters]}")

    # Test calculator
    print("\n" + "="*60)
    print("Testing Calculator:")
    calc = registry.get("calculator")
    result = calc.implementation(expression="2 + 3 * 4")
    print(f" 2 + 3 * 4 = {result}")

    # Test code executor
    print("\nTesting Code Executor:")
    executor = registry.get("execute_python")
    result = executor.implementation(code="print('Hello'); print(sum([1,2,3]))")
    print(f" Output: {result}")


# Run the demo only when executed as a script; keeping it in main() avoids
# leaking demo variables into the module namespace.
if __name__ == "__main__":
    main()