Spaces:
Sleeping
Sleeping
| package tools | |
import (
	"context"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/sipeed/picoclaw/pkg/logger"
	"github.com/sipeed/picoclaw/pkg/metrics"
	"github.com/sipeed/picoclaw/pkg/providers"
)
| type ToolRegistry struct { | |
| tools map[string]Tool | |
| breakers map[string]*CircuitBreaker | |
| mu sync.RWMutex | |
| cache *ToolCache | |
| } | |
| func NewToolRegistry() *ToolRegistry { | |
| return &ToolRegistry{ | |
| tools: make(map[string]Tool), | |
| breakers: make(map[string]*CircuitBreaker), | |
| cache: NewToolCache(5 * time.Minute), | |
| } | |
| } | |
| func (r *ToolRegistry) Register(tool Tool) { | |
| r.mu.Lock() | |
| defer r.mu.Unlock() | |
| r.tools[tool.Name()] = tool | |
| } | |
| func (r *ToolRegistry) Get(name string) (Tool, bool) { | |
| r.mu.RLock() | |
| defer r.mu.RUnlock() | |
| tool, ok := r.tools[name] | |
| return tool, ok | |
| } | |
| func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult { | |
| return r.ExecuteWithContext(ctx, name, args, "", "", nil) | |
| } | |
| // ExecuteWithContext executes a tool with channel/chatID context and optional async callback. | |
| // If the tool implements AsyncTool and a non-nil callback is provided, | |
| // the callback will be set on the tool before execution. | |
| func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult { | |
| logger.InfoCF("tool", "Tool execution started", | |
| map[string]interface{}{ | |
| "tool": name, | |
| "args": args, | |
| }) | |
| tool, ok := r.Get(name) | |
| if !ok { | |
| logger.ErrorCF("tool", "Tool not found", | |
| map[string]interface{}{ | |
| "tool": name, | |
| }) | |
| return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) | |
| } | |
| // Check Circuit Breaker | |
| r.mu.Lock() | |
| if r.breakers[name] == nil { | |
| r.breakers[name] = NewCircuitBreaker(5, 1*time.Minute) | |
| } | |
| breaker := r.breakers[name] | |
| r.mu.Unlock() | |
| if !breaker.AllowRequest() { | |
| logger.WarnCF("tool", "Circuit breaker open for tool", | |
| map[string]interface{}{ | |
| "tool": name, | |
| }) | |
| return ErrorResult("Circuit breaker open due to multiple failures").WithError(fmt.Errorf("circuit breaker open")) | |
| } | |
| // If tool implements ContextualTool, set context | |
| if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" { | |
| contextualTool.SetContext(channel, chatID) | |
| } | |
| // If tool implements AsyncTool and callback is provided, set callback | |
| if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { | |
| asyncTool.SetCallback(asyncCallback) | |
| logger.DebugCF("tool", "Async callback injected", | |
| map[string]interface{}{ | |
| "tool": name, | |
| }) | |
| } | |
| // Check cache for idempotent tools | |
| var cacheKey string | |
| if r.cache != nil && IsCacheable(name) { | |
| cacheKey = r.cache.GenerateKey(name, args) | |
| if result, found := r.cache.Get(cacheKey); found { | |
| logger.InfoCF("tool", "Tool result retrieved from cache", | |
| map[string]interface{}{ | |
| "tool": name, | |
| }) | |
| return result | |
| } | |
| } | |
| start := time.Now() | |
| result := tool.Execute(ctx, args) | |
| duration := time.Since(start) | |
| if result.IsError { | |
| breaker.RecordFailure() | |
| } else { | |
| breaker.RecordSuccess() | |
| // Cache successful results for idempotent tools | |
| if cacheKey != "" && r.cache != nil { | |
| r.cache.Set(cacheKey, result, 0) // Use default TTL | |
| } | |
| } | |
| // Log based on result type | |
| status := "success" | |
| if result.IsError { | |
| status = "error" | |
| logger.ErrorCF("tool", "Tool execution failed", | |
| map[string]interface{}{ | |
| "tool": name, | |
| "duration": duration.Milliseconds(), | |
| "error": result.ForLLM, | |
| }) | |
| } else if result.Async { | |
| logger.InfoCF("tool", "Tool started (async)", | |
| map[string]interface{}{ | |
| "tool": name, | |
| "duration": duration.Milliseconds(), | |
| }) | |
| } else { | |
| logger.InfoCF("tool", "Tool execution completed", | |
| map[string]interface{}{ | |
| "tool": name, | |
| "duration_ms": duration.Milliseconds(), | |
| "result_length": len(result.ForLLM), | |
| }) | |
| } | |
| metrics.ToolExecutions.WithLabelValues(name, status).Inc() | |
| return result | |
| } | |
| func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { | |
| r.mu.RLock() | |
| defer r.mu.RUnlock() | |
| definitions := make([]map[string]interface{}, 0, len(r.tools)) | |
| for _, tool := range r.tools { | |
| definitions = append(definitions, ToolToSchema(tool)) | |
| } | |
| return definitions | |
| } | |
| // ToProviderDefs converts tool definitions to provider-compatible format. | |
| // This is the format expected by LLM provider APIs. | |
| func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { | |
| r.mu.RLock() | |
| defer r.mu.RUnlock() | |
| definitions := make([]providers.ToolDefinition, 0, len(r.tools)) | |
| for _, tool := range r.tools { | |
| schema := ToolToSchema(tool) | |
| // Safely extract nested values with type checks | |
| fn, ok := schema["function"].(map[string]interface{}) | |
| if !ok { | |
| continue | |
| } | |
| name, _ := fn["name"].(string) | |
| desc, _ := fn["description"].(string) | |
| params, _ := fn["parameters"].(map[string]interface{}) | |
| definitions = append(definitions, providers.ToolDefinition{ | |
| Type: "function", | |
| Function: providers.ToolFunctionDefinition{ | |
| Name: name, | |
| Description: desc, | |
| Parameters: params, | |
| }, | |
| }) | |
| } | |
| return definitions | |
| } | |
| // List returns a list of all registered tool names. | |
| func (r *ToolRegistry) List() []string { | |
| r.mu.RLock() | |
| defer r.mu.RUnlock() | |
| names := make([]string, 0, len(r.tools)) | |
| for name := range r.tools { | |
| names = append(names, name) | |
| } | |
| return names | |
| } | |
| // Count returns the number of registered tools. | |
| func (r *ToolRegistry) Count() int { | |
| r.mu.RLock() | |
| defer r.mu.RUnlock() | |
| return len(r.tools) | |
| } | |
| // GetSummaries returns human-readable summaries of all registered tools. | |
| // Returns a slice of "name - description" strings. | |
| func (r *ToolRegistry) GetSummaries() []string { | |
| r.mu.RLock() | |
| defer r.mu.RUnlock() | |
| summaries := make([]string, 0, len(r.tools)) | |
| for _, tool := range r.tools { | |
| summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description())) | |
| } | |
| return summaries | |
| } | |