diff --git a/common/chat.cpp b/common/chat.cpp
index 9639af9..971e6d6 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -1979,6 +1979,127 @@ static common_chat_params common_chat_params_init_deepseek_v3_2(const common_cha
     return data;
 }
 
+// Cohere Command (Cohere2 / command-a) tool-call + reasoning format:
+//   <|START_THINKING|>...plan...<|END_THINKING|>
+//   <|START_ACTION|>[{"tool_call_id":"0","tool_name":"X","parameters":{...}}]<|END_ACTION|>
+//   <|START_RESPONSE|>...text...<|END_RESPONSE|>
+// The action block is a JSON array of objects keyed by tool_name/parameters, with a
+// generated integer tool_call_id, which maps directly onto standard_json_tools().
+//
+// Builds the prompt, the PEG output parser and, when tool calls are possible, the grammar:
+//   - no tools or tool_choice == none: output is parsed as plain content, optionally wrapped
+//     in <|START_RESPONSE|>...<|END_RESPONSE|>; no grammar is built.
+//   - tool_choice == required: optional thinking block, then a mandatory action block.
+//   - otherwise: optional thinking block, then content, then an optional action block.
+static common_chat_params common_chat_params_init_cohere2(const common_chat_template & tmpl,
+                                                          const autoparser::generation_params & inputs) {
+    const std::string THINK_START  = "<|START_THINKING|>";
+    const std::string THINK_END    = "<|END_THINKING|>";
+    const std::string ACTION_START = "<|START_ACTION|>";
+    const std::string ACTION_END   = "<|END_ACTION|>";
+    const std::string RESP_START   = "<|START_RESPONSE|>";
+    const std::string RESP_END     = "<|END_RESPONSE|>";
+
+    common_chat_params data;
+
+    data.prompt             = common_chat_template_direct_apply_impl(tmpl, inputs);
+    data.generation_prompt  = common_chat_template_generation_prompt_impl(tmpl, inputs);
+    data.format             = COMMON_CHAT_FORMAT_PEG_NATIVE;
+    data.supports_thinking  = true;
+    data.thinking_start_tag = THINK_START;
+    data.thinking_end_tag   = THINK_END;
+    // Section markers that the parser below matches as literals.
+    data.preserved_tokens   = {
+        THINK_START,  THINK_END,
+        ACTION_START, ACTION_END,
+        RESP_START,   RESP_END,
+    };
+
+    const bool has_tools         = inputs.tools.is_array() && !inputs.tools.empty();
+    const bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    // Tool calls can only appear when tools are supplied and not disabled by tool_choice.
+    const bool include_grammar   = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
+
+    // With tools the template leaves generation open after <|CHATBOT_TOKEN|>; without tools
+    // it primes <|START_RESPONSE|> directly.
+    const bool response_primed =
+        data.generation_prompt.size() >= RESP_START.size() &&
+        data.generation_prompt.compare(data.generation_prompt.size() - RESP_START.size(),
+                                       RESP_START.size(), RESP_START) == 0;
+
+    auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
+        auto gen = p.literal(data.generation_prompt);
+        auto end = p.end();
+
+        // No tool calls possible: plain content only.
+        if (!include_grammar) {
+            if (response_primed) {
+                return gen + p.content(p.until(RESP_END)) + p.literal(RESP_END) + end;
+            }
+            return gen + p.choice({
+                p.literal(RESP_START) + p.content(p.until(RESP_END)) + p.literal(RESP_END),
+                p.content(p.rest()),
+            }) + end;
+        }
+
+        // Optional thinking block: tagged as reasoning when thinking is enabled, matched
+        // without being tagged when only extraction is requested, not matched here otherwise.
+        auto reasoning = p.eps();
+        if (extract_reasoning && inputs.enable_thinking) {
+            reasoning = p.optional(p.literal(THINK_START) + p.reasoning(p.until(THINK_END)) + p.literal(THINK_END));
+        } else if (extract_reasoning) {
+            reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END));
+        }
+
+        auto tools_parser = p.standard_json_tools(
+            ACTION_START, ACTION_END, inputs.tools,
+            /* parallel_tool_calls = */ true,
+            /* force_tool_calls = */ inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED,
+            /* name_key = */ "tool_name",
+            /* args_key = */ "parameters",
+            /* array_wrapped = */ true,
+            /* function_is_key = */ false,
+            /* call_id_key = */ "",
+            /* gen_call_id_key = */ "tool_call_id",
+            /* parameters_order = */ {});
+
+        if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) {
+            return gen + reasoning + p.space() + tools_parser + end;
+        }
+        // Content-first, then optional tool action (avoids non-backtracking choice on the
+        // tool-section trigger). The response may be <|START_RESPONSE|>-wrapped or bare:
+        //  - a complete wrapper is consumed on both sides so that <|END_RESPONSE|> does not
+        //    end up inside the content;
+        //  - otherwise strip a leading wrapper token, if any, and let content run up to any
+        //    tool action.
+        auto content_before = p.choice({
+            p.literal(RESP_START) + p.content(p.until(RESP_END)) + p.literal(RESP_END) + p.space(),
+            p.optional(p.literal(RESP_START)) + p.content(p.until(ACTION_START)),
+        });
+        return gen + reasoning + p.space() + content_before + p.optional(tools_parser) + end;
+    });
+
+    data.parser = parser.save();
+
+    if (include_grammar) {
+        // Lazy only for tool_choice == auto; <|START_ACTION|> is registered as the trigger word.
+        data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+        data.grammar      = build_grammar([&](const common_grammar_builder & builder) {
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
+                builder.resolve_refs(schema);
+            });
+            parser.build_grammar(builder, data.grammar_lazy);
+        });
+        data.grammar_triggers = {
+            { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ACTION_START },
+        };
+    }
+
+    return data;
+}
+
 namespace workaround {
 
 static void map_developer_role_to_system(json & messages) {
@@ -2256,6 +2377,13 @@ std::optional common_chat_try_specialized_template(
         return common_chat_params_init_deepseek_v3_2(tmpl, params);
     }
 
+    // Cohere Command (Cohere2 / command-a) format detection: thinking + action-array tool calls.
+    if (src.find("<|START_ACTION|>") != std::string::npos &&
+        src.find("<|START_THINKING|>") != std::string::npos &&
+        src.find("tool_name") != std::string::npos) {
+        return common_chat_params_init_cohere2(tmpl, params);
+    }
+
     // Gemma4 format detection
     if (src.find("'<|tool_call>call:'") != std::string::npos) {
         if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {