| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import re |
| |
|
| | from ...utils.constants import IGNORE_INDEX |
| | from ...utils.helper import get_tokenizer |
| | from ...utils.plugin import BasePlugin |
| | from ...utils.types import Message, ModelInput, Processor, ToolCall |
| |
|
| |
|
class RenderingPlugin(BasePlugin):
    """Plugin front-end that forwards rendering and parsing to registered callables.

    Each method fetches a callable through ``self[<key>]`` and invokes it with
    the arguments it received.
    """

    def render_messages(
        self,
        processor: Processor,
        messages: list[Message],
        tools: str | None = None,
        is_generate: bool = False,
    ) -> ModelInput:
        """Render messages in the template format.

        Args:
            processor: Processor forwarded unchanged to the registered renderer.
            messages: Conversation messages to render.
            tools: Optional tool definitions string, forwarded unchanged.
            is_generate: Flag forwarded unchanged to the registered renderer.

        Returns:
            Whatever the callable stored under ``"render_messages"`` returns.
        """
        renderer = self["render_messages"]
        return renderer(processor, messages, tools, is_generate)

    def parse_messages(self, generated_text: str) -> Message:
        """Parse messages in the template format.

        Args:
            generated_text: Raw generated text handed to the registered parser.

        Returns:
            Whatever the callable stored under ``"parse_messages"`` returns.
        """
        # NOTE(review): the qwen3_nothink parser in this file is registered
        # under "parse_message" (singular), not "parse_messages" -- confirm
        # which key callers are expected to use.
        parser = self["parse_messages"]
        return parser(generated_text)
| |
|
| |
|
| | def _update_model_input( |
| | processor: Processor, |
| | input_ids: list[int], |
| | labels: list[int], |
| | loss_weights: list[int], |
| | temp_str: str, |
| | temp_weight: float, |
| | ) -> str: |
| | """Update model input with temporary string.""" |
| | if not temp_str: |
| | return "" |
| |
|
| | tokenizer = get_tokenizer(processor) |
| | temp_ids = tokenizer.encode(temp_str, add_special_tokens=False) |
| | input_ids.extend(temp_ids) |
| | loss_weights.extend([temp_weight] * len(temp_ids)) |
| | if temp_weight > 1e-6: |
| | labels.extend(temp_ids) |
| | else: |
| | labels.extend([IGNORE_INDEX] * len(temp_ids)) |
| |
|
| | return "" |
| |
|
| |
|
def _concat_text_content(message: Message) -> str:
    """Concatenate the text parts of a message, rejecting any other content type."""
    parts = []
    for content in message["content"]:
        if content["type"] != "text":
            raise ValueError(f"Unsupported content type: {content['type']}")

        parts.append(content["value"])

    return "".join(parts)


def _render_qwen3_nothink_system(messages: list[Message], tools: str | None) -> tuple[str, float]:
    """Build the leading system block (system prompt and/or tool spec) and its loss weight."""
    # Guard against an empty conversation instead of failing on messages[0].
    has_system = bool(messages) and messages[0]["role"] == "system"
    weight = messages[0].get("loss_weight", 0.0) if has_system else 0.0

    if tools:
        text = "<|im_start|>system\n"
        if has_system:
            text += _concat_text_content(messages[0]) + "\n\n"

        text += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )
        try:
            tool_list = json.loads(tools)
        except json.JSONDecodeError as err:
            raise ValueError(f"Invalid tools format: {str(tools)}.") from err

        # A single tool definition is accepted as well as a list of them.
        if not isinstance(tool_list, list):
            tool_list = [tool_list]

        for tool in tool_list:
            text += "\n" + json.dumps(tool, ensure_ascii=False)

        text += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
        return text, weight

    if has_system:
        return "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n", weight

    return "", 0.0


def _render_qwen3_nothink_assistant(message: Message) -> str:
    """Render one assistant turn, including reasoning and tool-call content parts."""
    contents = message["content"]
    text = "<|im_start|>" + message["role"] + "\n"
    for val_idx, content in enumerate(contents):
        content_type = content["type"]
        if content_type == "text":
            text += content["value"]
        elif content_type == "reasoning":
            # NOTE(review): Qwen3 chat templates use <think> tags; this emits
            # <thinking> (and the parser in this file matches it) -- confirm intended.
            text += "<thinking>\n" + content["value"] + "\n</thinking>\n\n"
        elif content_type == "tool_call":
            # Separate a tool call from preceding text or another tool call.
            if val_idx != 0 and contents[val_idx - 1]["type"] in ["text", "tool_call"]:
                text += "\n"

            try:
                tool_call: ToolCall = json.loads(content["value"])
            except json.JSONDecodeError as err:
                raise ValueError(f"Invalid tool call format: {content['value']}.") from err

            text += (
                '<tool_call>\n{"name": "'
                + tool_call["name"]
                + '", "arguments": '
                + json.dumps(tool_call["arguments"], ensure_ascii=False)
                + "}\n</tool_call>"
            )
        else:
            raise ValueError(f"Unsupported content type: {content['type']}")

    return text + "<|im_end|>\n"


@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    """Render messages in the Qwen3 nothink template format.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507

    The conversation is tokenized chunk by chunk: first the system/tools
    block, then one chunk per message, then the optional generation prompt.
    Each chunk carries its own loss weight.

    Args:
        processor: Processor handed to ``_update_model_input`` for tokenization.
        messages: Conversation messages. A system message at index 0 is
            rendered as part of the leading system block. An empty list is
            accepted.
        tools: Optional JSON string holding one tool definition or a list of them.
        is_generate: When true, append the ``<|im_start|>assistant`` prompt
            with zero loss weight.

    Returns:
        ModelInput with ``input_ids``, ``attention_mask`` (all ones),
        ``labels`` and ``loss_weights``.

    Raises:
        ValueError: On unsupported content types, or when ``tools`` or a tool
            call value is not valid JSON.
    """
    input_ids, labels, loss_weights = [], [], []

    prefix, prefix_weight = _render_qwen3_nothink_system(messages, tools)
    _update_model_input(processor, input_ids, labels, loss_weights, prefix, prefix_weight)

    last_idx = len(messages) - 1
    for turn_idx, message in enumerate(messages):
        role = message["role"]
        if role == "user" or (role == "system" and turn_idx != 0):
            chunk = "<|im_start|>" + role + "\n" + _concat_text_content(message) + "<|im_end|>\n"
            weight = message.get("loss_weight", 0.0)
        elif role == "assistant":
            chunk = _render_qwen3_nothink_assistant(message)
            # Assistant turns are the only ones trained on by default.
            weight = message.get("loss_weight", 1.0)
        elif role == "tool":
            chunk = ""
            # Consecutive tool messages share a single user turn: open it on
            # the first one and close it on the last one.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                chunk += "<|im_start|>user"

            chunk += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
            if turn_idx == last_idx or messages[turn_idx + 1]["role"] != "tool":
                chunk += "<|im_end|>\n"

            weight = message.get("loss_weight", 0.0)
        else:
            # The leading system message was already rendered above; other
            # roles contribute nothing.
            continue

        _update_model_input(processor, input_ids, labels, loss_weights, chunk, weight)

    if is_generate:
        _update_model_input(processor, input_ids, labels, loss_weights, "<|im_start|>assistant\n", 0.0)

    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
| |
|
| |
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
    """Parse generated text in the Qwen3 nothink template format into a message.

    ``<thinking>...</thinking>`` spans become ``reasoning`` parts and
    ``<tool_call>...</tool_call>`` spans become ``tool_call`` parts. Any
    non-blank text before, between or after them becomes ``text`` parts, so
    reasoning and tool calls may be interleaved with plain text.

    Args:
        generated_text (str): The generated text in the Qwen3 nothink template format.

    Returns:
        Message: An assistant message whose content lists the parts in order.

    Raises:
        ValueError: If the body of a ``<tool_call>`` span is not valid JSON.
    """
    tag_re = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    parts = []

    def add_text(segment: str) -> None:
        # Whitespace-only gaps between tags are dropped.
        stripped = segment.strip()
        if stripped:
            parts.append({"type": "text", "value": stripped})

    cursor = 0
    for found in tag_re.finditer(generated_text):
        add_text(generated_text[cursor : found.start()])

        kind = found.group(1)
        body = found.group(2).strip()
        if kind == "thinking":
            parts.append({"type": "reasoning", "value": body})
        elif kind == "tool_call":
            # The payload is only validated here; it is stored as the raw string.
            try:
                json.loads(body)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {body}.")

            parts.append({"type": "tool_call", "value": body})

        cursor = found.end()

    add_text(generated_text[cursor:])
    return Message(role="assistant", content=parts)
| |
|