dongxx1104's picture
Upload folder using huggingface_hub
db704cb verified
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rendering utils.
How to use:
renderer = Renderer(template, processor)
renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs
renderer.parse_message(text: str) -> Message
renderer.process_samples(samples: list[Sample]) -> list[ModelInput]
"""
import numpy as np
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor, Sample
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
"""
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], []
for message in messages:
temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
if is_generate:
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ModelInput(
input_ids=input_ids,
attention_mask=[1] * len(input_ids),
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
"""Apply template to messages and convert them to model input.
Args:
messages (list[Message]): The messages to render.
tools (str | None, optional): The tools to use. Defaults to None.
is_generate (bool, optional): Whether to render for generation. Defaults to False.
Returns:
ModelInput: The rendered model input.
"""
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
"""Parse a message in the template format.
Args:
generated_text (str): The generated text in the template format.
Returns:
Message: The parsed message.
"""
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)
def process_samples(self, samples: list[Sample]) -> list[ModelInput]:
"""Process samples to model input.
Args:
samples (list[Sample]): The samples to process.
Returns:
list[ModelInput]: The processed model inputs.
"""
model_inputs = []
for sample in samples:
if "messages" in sample:
model_input = self.render_messages(sample["messages"], sample.get("tools"))
elif "chosen_messages" in sample and "rejected_messages" in sample:
chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
chosen_input["token_type_ids"] = [1] * len(chosen_input["input_ids"])
rejected_input["token_type_ids"] = [2] * len(rejected_input["input_ids"])
model_input = ModelInput(
input_ids=chosen_input["input_ids"] + rejected_input["input_ids"],
attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"],
labels=chosen_input["labels"] + rejected_input["labels"],
loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"],
token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"],
)
if "position_ids" in chosen_input:
model_input["position_ids"] = np.concatenate(
[chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1
)
else:
raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.")
if "extra_info" in sample:
model_input["extra_info"] = sample["extra_info"]
if "_dataset_name" in sample:
model_input["_dataset_name"] = sample["_dataset_name"]
model_inputs.append(model_input)
return model_inputs