| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Rendering utils. |
| | |
| | How to use: |
| | renderer = Renderer(template, processor) |
| | renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs |
| | renderer.parse_message(text: str) -> Message |
| | renderer.process_samples(samples: list[Sample]) -> list[ModelInput] |
| | """ |
| |
|
| | import numpy as np |
| |
|
| | from ...utils.constants import IGNORE_INDEX |
| | from ...utils.helper import get_tokenizer |
| | from ...utils.types import Message, ModelInput, Processor, Sample |
| |
|
| |
|
def render_chatml_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    """Render *messages* with the ChatML template into a ``ModelInput``.

    Each message is rendered as ``<|im_start|>{role}\\n{text}<|im_end|>\\n`` and
    tokenized as one string. Tokens of segments with a positive loss weight go
    into ``labels``; all other label positions are masked with ``IGNORE_INDEX``.
    When *is_generate* is true, an open ``<|im_start|>assistant\\n`` header is
    appended so the model continues from there.

    NOTE(review): *tools* is accepted for interface parity but is not used by
    this ChatML renderer — confirm whether tool rendering is expected here.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
    """
    tokenizer = get_tokenizer(processor)
    input_ids: list[int] = []
    labels: list[int] = []
    loss_weights: list[float] = []

    def _emit(token_ids: list[int], weight: float) -> None:
        # Extend the three parallel streams for one rendered segment.
        input_ids.extend(token_ids)
        loss_weights.extend([weight] * len(token_ids))
        if weight > 1e-6:
            labels.extend(token_ids)
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))

    for message in messages:
        parts = ["<|im_start|>" + message["role"] + "\n"]
        for content in message["content"]:
            if content["type"] != "text":
                raise ValueError(f"Unsupported content type: {content['type']}")
            parts.append(content["value"])
        parts.append("<|im_end|>\n")

        # Assistant turns are trained on by default; all other roles are masked.
        default_weight = 1.0 if message["role"] == "assistant" else 0.0
        weight = message.get("loss_weight", default_weight)
        _emit(tokenizer.encode("".join(parts), add_special_tokens=False), weight)

    if is_generate:
        # Open assistant header — never trained on, so weight 0.
        _emit(tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False), 0.0)

    return ModelInput(
        input_ids=input_ids,
        attention_mask=[1] * len(input_ids),
        labels=labels,
        loss_weights=loss_weights,
    )
| |
|
| |
|
def parse_chatml_message(generated_text: str) -> Message:
    """Wrap raw generated text verbatim as an assistant-role message.

    Args:
        generated_text (str): The generated text in ChatML format.

    Returns:
        Message: An assistant message with a single text content part.
    """
    content = [{"type": "text", "value": generated_text}]
    return Message(role="assistant", content=content)
| |
|
| |
|
class Renderer:
    """Template-dispatching renderer.

    The built-in "chatml" template is handled locally; any other template name
    is delegated to a ``RenderingPlugin``, imported lazily so the common chatml
    path has no dependency on the plugin package.
    """

    def __init__(self, template: str, processor: Processor):
        self.template = template
        self.processor = processor

    def _rendering_plugin(self):
        """Build a plugin instance for non-chatml templates (lazy import)."""
        from ...plugins.model_plugins.rendering import RenderingPlugin

        return RenderingPlugin(self.template)

    def render_messages(
        self, messages: list[Message], tools: str | None = None, is_generate: bool = False
    ) -> ModelInput:
        """Apply template to messages and convert them to model input.

        Args:
            messages (list[Message]): The messages to render.
            tools (str | None, optional): The tools to use. Defaults to None.
            is_generate (bool, optional): Whether to render for generation. Defaults to False.

        Returns:
            ModelInput: The rendered model input.
        """
        if self.template == "chatml":
            return render_chatml_messages(self.processor, messages, tools, is_generate)
        return self._rendering_plugin().render_messages(self.processor, messages, tools, is_generate)

    def parse_message(self, generated_text: str) -> Message:
        """Parse a message in the template format.

        Args:
            generated_text (str): The generated text in the template format.

        Returns:
            Message: The parsed message.
        """
        if self.template == "chatml":
            return parse_chatml_message(generated_text)
        return self._rendering_plugin().parse_message(generated_text)

    def _render_preference_pair(self, sample: Sample) -> ModelInput:
        """Render a chosen/rejected pair and concatenate them into one input.

        Args:
            sample (Sample): A sample with "chosen_messages" and "rejected_messages".

        Returns:
            ModelInput: Concatenated input; token_type_ids mark which half each
            token belongs to (1 = chosen, 2 = rejected).

        Raises:
            ValueError: If only one of the two rendered inputs has position_ids.
        """
        chosen = self.render_messages(sample["chosen_messages"], sample.get("tools"))
        rejected = self.render_messages(sample["rejected_messages"], sample.get("tools"))
        chosen["token_type_ids"] = [1] * len(chosen["input_ids"])
        rejected["token_type_ids"] = [2] * len(rejected["input_ids"])
        model_input = ModelInput(
            input_ids=chosen["input_ids"] + rejected["input_ids"],
            attention_mask=chosen["attention_mask"] + rejected["attention_mask"],
            labels=chosen["labels"] + rejected["labels"],
            loss_weights=chosen["loss_weights"] + rejected["loss_weights"],
            token_type_ids=chosen["token_type_ids"] + rejected["token_type_ids"],
        )
        chosen_has_pos = "position_ids" in chosen
        if chosen_has_pos != ("position_ids" in rejected):
            # Previously an asymmetric pair crashed with a bare KeyError on the
            # rejected lookup; fail with an explicit message instead.
            raise ValueError("position_ids must be present in both or neither of chosen/rejected inputs.")
        if chosen_has_pos:
            model_input["position_ids"] = np.concatenate(
                [chosen["position_ids"], rejected["position_ids"]], axis=-1
            )
        return model_input

    def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
        """Process samples to model input.

        Args:
            samples (list[Sample]): The samples to process.

        Returns:
            list[ModelInput]: The processed model inputs.

        Raises:
            ValueError: If a sample has neither "messages" nor a
                chosen/rejected message pair.
        """
        model_inputs = []
        for sample in samples:
            if "messages" in sample:
                model_input = self.render_messages(sample["messages"], sample.get("tools"))
            elif "chosen_messages" in sample and "rejected_messages" in sample:
                model_input = self._render_preference_pair(sample)
            else:
                raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")

            # Pass through per-sample metadata used downstream.
            if "extra_info" in sample:
                model_input["extra_info"] = sample["extra_info"]
            if "_dataset_name" in sample:
                model_input["_dataset_name"] = sample["_dataset_name"]

            model_inputs.append(model_input)

        return model_inputs
| |
|