| | import logging |
| | from dataclasses import dataclass |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | from mistral_common.protocol.instruct.messages import ( |
| | FinetuningAssistantMessage, |
| | Roles, |
| | SystemMessage, |
| | ToolMessage, |
| | UserMessage, |
| | ) |
| | from mistral_common.protocol.instruct.tool_calls import ( |
| | Function, |
| | FunctionCall, |
| | Tool, |
| | ToolCall, |
| | ) |
| | from mistral_common.protocol.instruct.validator import ( |
| | MistralRequestValidatorV3, |
| | ValidationMode, |
| | ) |
| | from mistral_common.tokens.instruct.request import InstructRequest |
| | from mistral_common.tokens.tokenizers.base import Tokenizer |
| | from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase |
| |
|
| | from .exceptions import ( |
| | ConversationFormatError, |
| | FunctionFormatError, |
| | MessageFormatError, |
| | ToolCallFormatError, |
| | UnrecognizedRoleError, |
| | ) |
| |
|
# Module-level logger registered under the fixed name "tokenize".
logger = logging.getLogger("tokenize")

# Aliases used by TokenSample: a list of token ids and the parallel list of
# per-token booleans that goes with it.
Sequence = List[int]
Mask = List[bool]
| |
|
| |
|
class TrainingInstructSample(InstructRequest):
    """Instruct request carrying the extra fields read by `tokenize_instruct`."""

    # Tool definitions forwarded to the tokenizer when user messages are encoded.
    available_tools: Optional[List[Tool]] = None
    # When True, `tokenize_instruct` masks every assistant message except one
    # sitting in the very last position of the conversation.
    only_last: bool = False
| |
|
| |
|
@dataclass
class TokenSample:
    """Tokenized sample: token ids plus one boolean mask entry per token."""

    # Token ids of the sample, in order.
    tokens: Sequence
    # Parallel to `tokens`; the tokenize_* helpers set True on the tokens they
    # mark as relevant (all of them for pretrain text, selected assistant
    # messages for instruct samples).
    masks: Mask
| |
|
| |
|
class SampleType(str, Enum):
    """Kinds of training sample that `encode` can build from a raw dict."""

    # Handled by `get_pretrain_sample`: a single raw text string.
    PRETRAIN = "pretrain"
    # Handled by `build_instruct_sample`: a conversation made of messages.
    INSTRUCT = "instruct"
| |
|
| |
|
def encode(
    data: Dict[str, Any],
    instruct_tokenizer: InstructTokenizerBase,
    as_type: SampleType,
) -> TokenSample:
    """Build a sample of the requested type from ``data`` and tokenize it.

    Args:
        data: Raw sample dict, in the layout expected by
            ``get_pretrain_sample`` or ``build_instruct_sample``.
        instruct_tokenizer: Tokenizer handed over to ``tokenize``.
        as_type: Which kind of sample ``data`` should be interpreted as.

    Returns:
        The tokens and masks produced by ``tokenize``.

    Raises:
        ValueError: If ``as_type`` is not a known ``SampleType``.
    """
    sample: Union[str, TrainingInstructSample]
    if as_type == SampleType.PRETRAIN:
        sample = get_pretrain_sample(data)
    elif as_type == SampleType.INSTRUCT:
        sample = build_instruct_sample(data)
    else:
        # Without this branch `sample` would stay unbound and the call below
        # would fail with an opaque UnboundLocalError.
        raise ValueError(
            f"`as_type` has to be one of {[t.value for t in SampleType]}, not {as_type!r}."
        )

    return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer)
| |
|
| |
|
def get_pretrain_sample(data: Dict[str, Any]) -> str:
    """Extract the raw training text from a pretrain data dict.

    The text must be stored under exactly one of the keys ``"text"`` or
    ``"content"`` and must be a ``str``.

    Args:
        data: Raw sample dict.

    Returns:
        The string stored under whichever content key is present.

    Raises:
        AssertionError: If both keys are present, if neither key holds a
            non-None value, or if the stored value is not a string.
    """
    content_keys = ["text", "content"]
    present_keys = [key for key in content_keys if key in data]

    # Key presence (even with a None value) is what counts for the "both" check.
    assert len(present_keys) < len(
        content_keys
    ), "Make sure to have either 'text' or 'content' in your data. Not both."
    assert any(
        data.get(key) is not None for key in content_keys
    ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"

    # The two checks above leave exactly one present key.
    sample = data[present_keys[-1]]
    assert isinstance(sample, str), sample

    return sample
| |
|
| |
|
def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:
    """Turn a raw conversation dict into a validated ``TrainingInstructSample``.

    Expected layout of ``data``, as enforced below:

    * exactly one of ``"messages"`` / ``"interactions"``: the list of message
      dicts;
    * at most one of ``"tools"`` / ``"available_tools"``: tool definitions,
      parsed with ``_parse_available_tools`` when non-empty;
    * optional ``"system_prompt"`` and ``"only_last"``.

    Every message dict needs a ``"role"`` and, unless it carries
    ``"tool_calls"``, exactly one of ``"content"`` / ``"text"``. A message with
    role ``system`` is not appended to the message list; its content becomes
    the system prompt, and it is an error if a system prompt is already set.

    Args:
        data: Raw conversation dict.

    Returns:
        The sample built from the parsed messages. ``only_last`` is the value
        from ``data`` (default ``False``), or ``True`` whenever
        ``available_tools`` is not None.

    Raises:
        ConversationFormatError: Missing or duplicated message / tool keys.
        MessageFormatError: A message lacks its role or content, has both
            content keys, or a second system prompt is encountered.
        UnrecognizedRoleError: A message role is not a value of ``Roles``.
        AssertionError: An internal sanity check on the data fails.
    """
    messages: List[
        SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage
    ] = []

    messages_keys = ["messages", "interactions"]
    content_keys = ["content", "text"]
    allowed_roles = [role.value for role in Roles]

    system_prompt = data.get("system_prompt")
    available_tools: Optional[List[Tool]] = data.get("available_tools")

    joined_messages_keys = ", ".join(messages_keys)
    present_keys = [key for key in messages_keys if key in data]

    if not present_keys:
        err = f"The conversation does not contain one of '{joined_messages_keys}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{joined_messages_keys}'."
        raise ConversationFormatError(err, str(data))

    if len(present_keys) == len(messages_keys):
        err = f"The conversation cannot contain both of '{joined_messages_keys}' key, but only one of the two."
        raise ConversationFormatError(err, str(data))

    # Exactly one of the message keys is present at this point.
    data_messages: Optional[List[Dict[str, Any]]] = data[present_keys[-1]]
    assert data_messages is not None, "data_messages can't be None"

    if "available_tools" in data and "tools" in data:
        err = "The conversation contains both an `available_tools` and `tools` key. You can only have one."
        raise ConversationFormatError(err, str(data))

    # Only non-empty tool lists are parsed; an empty `available_tools` list is
    # kept as fetched above.
    tools_entry = data.get("tools")
    if tools_entry is not None and len(tools_entry) > 0:
        available_tools = _parse_available_tools(tools_entry)
    elif available_tools is not None and len(available_tools) > 0:
        available_tools = _parse_available_tools(available_tools)

    for raw_message in data_messages:
        has_tool_calls = raw_message.get("tool_calls") is not None

        if "role" not in raw_message:
            err = f"A message does not contain a 'role' key, but only {', '.join(raw_message.keys())}. Make sure that the message includes the key 'role'."
            raise MessageFormatError(err, str(data))

        role = raw_message["role"]

        if all(key in raw_message for key in content_keys):
            err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
            raise MessageFormatError(err, str(data))

        # First non-None value among the content keys, if any.
        content: Optional[str] = next(
            (
                value
                for value in map(raw_message.get, content_keys)
                if value is not None
            ),
            None,
        )

        # A message may omit its content only when it carries tool calls.
        if content is None and not has_tool_calls:
            err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(raw_message.keys())}. Make sure that the message includes one of '{content_keys}' keys."
            raise MessageFormatError(err, str(data))

        if role not in allowed_roles:
            raise UnrecognizedRoleError(role, allowed_roles)

        if role == "user":
            assert content is not None
            messages.append(UserMessage(content=content))
        elif role == "assistant":
            tool_calls: Optional[List[ToolCall]] = (
                _parse_tool_calls(raw_message["tool_calls"])
                if has_tool_calls
                else None
            )
            weight = raw_message.get("weight")
            assistant_message = FinetuningAssistantMessage(
                content=content, tool_calls=tool_calls, weight=weight
            )
            messages.append(assistant_message)
        elif role == "system":
            if system_prompt is not None:
                err = "Multiple messages with role 'system' encountered. Only one is allowed."
                raise MessageFormatError(err, str(data))

            system_prompt = content
        elif role == "tool":
            assert content is not None
            messages.append(_parse_tool_message(content, raw_message))

    # Validate the parsed messages and tools in fine-tuning mode.
    validator = MistralRequestValidatorV3(ValidationMode.finetuning)
    validator.validate_messages(messages)
    validator._validate_tools(available_tools or [])

    # Any non-None tool list switches the sample to "only last message" mode.
    only_last = data.get("only_last", False) or available_tools is not None

    return TrainingInstructSample(
        messages=messages,
        system_prompt=system_prompt,
        available_tools=available_tools,
        only_last=only_last,
    )
| |
|
| |
|
| | def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]: |
| | available_tools = [] |
| | for tool in tools: |
| | if "function" not in tool: |
| | raise FunctionFormatError( |
| | "A tool dict does not have a 'function' key.", str(tool) |
| | ) |
| |
|
| | func_data = tool["function"] |
| |
|
| | for key in ["name", "description", "parameters"]: |
| | if key not in func_data: |
| | raise FunctionFormatError( |
| | f"A function dict does not have a {key} key.", str(func_data) |
| | ) |
| |
|
| | if not isinstance(func_data["parameters"], dict): |
| | raise FunctionFormatError( |
| | f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters pass an empyt dict ", str(func_data) |
| | ) |
| |
|
| | description = func_data["description"] |
| | function = Function( |
| | name=func_data["name"], |
| | description=description, |
| | parameters=func_data["parameters"], |
| | ) |
| |
|
| | available_tools.append(Tool(function=function)) |
| | return available_tools |
| |
|
| |
|
| | def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]: |
| | for key in ["id", "function"]: |
| | if not all(key in call for call in calls): |
| | err = f"A tool call of an assistant message does not have a {key} key" |
| | raise ToolCallFormatError(err, str(calls)) |
| |
|
| | for key in ["name", "arguments"]: |
| | if not all(key in call["function"] for call in calls): |
| | err = ( |
| | f"A tool call function of an assistant message does not have a {key} key" |
| | ) |
| | raise ToolCallFormatError(err, str(calls)) |
| |
|
| | if not all(isinstance(call["function"]["arguments"], str) for call in calls): |
| | err = "A tool call function of an assistant message does not have a 'arguments' key of type str" |
| | raise ToolCallFormatError(err, str(calls)) |
| |
|
| | tool_calls = [ |
| | ToolCall( |
| | id=call["id"], |
| | function=FunctionCall( |
| | name=call["function"]["name"], |
| | arguments=call["function"]["arguments"], |
| | ), |
| | ) |
| | for call in calls |
| | ] |
| | return tool_calls |
| |
|
| |
|
def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage:
    """Build a ``ToolMessage`` from a raw message dict with role ``tool``.

    Args:
        content: Already-extracted content of the message.
        data_message: Raw message dict; must provide ``"tool_call_id"`` and
            may provide ``"name"``.

    Returns:
        The tool message; ``name`` is None when the dict has no ``"name"`` key.

    Raises:
        MessageFormatError: If ``"tool_call_id"`` is missing.
    """
    if "tool_call_id" in data_message:
        return ToolMessage(
            content=content,
            tool_call_id=data_message["tool_call_id"],
            name=data_message.get("name"),
        )

    err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'."
    raise MessageFormatError(err, str(data_message))
| |
|
| |
|
def tokenize(
    sample: Union[str, TrainingInstructSample],
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    """Tokenize a sample, dispatching on its type.

    Args:
        sample: A raw string (pretrain) or a ``TrainingInstructSample``.
        instruct_tokenizer: Used directly for instruct samples; its
            ``tokenizer`` attribute is used for raw strings.

    Returns:
        Tokens and masks from ``tokenize_pretrain`` or ``tokenize_instruct``.

    Raises:
        ValueError: If ``sample`` is of neither supported type.
    """
    if isinstance(sample, str):
        return tokenize_pretrain(sample, instruct_tokenizer.tokenizer)

    if isinstance(sample, TrainingInstructSample):
        return tokenize_instruct(sample, instruct_tokenizer)

    raise ValueError(
        f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
    )
| |
|
| |
|
def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample:
    """Encode raw text with ``bos=True`` and ``eos=True``; every token is unmasked.

    Args:
        sample: Text to encode.
        tokenizer: Tokenizer whose ``encode`` method is called.

    Returns:
        The encoded tokens with an all-True mask of the same length.
    """
    token_ids = tokenizer.encode(sample, bos=True, eos=True)
    return TokenSample(tokens=token_ids, masks=[True] * len(token_ids))
| |
|
| |
|
def tokenize_instruct(
    sample: TrainingInstructSample,
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    """Tokenize a conversation and mark which tokens are unmasked.

    The sequence opens with the tokens returned by
    ``instruct_tokenizer.start()``, followed by each message encoded in order.
    Start tokens, user messages and tool messages are always masked out. An
    assistant message is unmasked only when its ``weight`` is None or 1 and,
    if ``sample.only_last`` is set, it is the very last message.

    Args:
        sample: Conversation to tokenize. An assistant message in last
            position that carries tool calls is modified in place by
            ``maybe_remove_call_id``.
        instruct_tokenizer: Tokenizer providing ``start`` and the
            ``encode_*_message`` methods.

    Returns:
        Tokens and a mask of identical length.

    Raises:
        ValueError: If a message is not a user, tool or fine-tuning assistant
            message.
    """
    tokens: List[int] = instruct_tokenizer.start()
    # One mask entry per start token, so tokens and masks stay aligned whatever
    # number of tokens `start()` returns.
    masks: List[bool] = [False] * len(tokens)

    mask_all_but_last = sample.only_last

    # Positions of the first / last user message; -1 when there is none.
    user_messages = [
        i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage)
    ]
    first_user_idx = user_messages[0] if user_messages else -1
    last_user_idx = user_messages[-1] if user_messages else -1
    last_msg_idx = len(sample.messages) - 1

    for msg_idx, message in enumerate(sample.messages):
        if isinstance(message, UserMessage):
            curr_tokens = instruct_tokenizer.encode_user_message(
                message,
                available_tools=sample.available_tools,
                is_last=msg_idx == last_user_idx,
                is_first=msg_idx == first_user_idx,
                system_prompt=sample.system_prompt,
            )
            curr_masks = [False] * len(curr_tokens)
        elif isinstance(message, ToolMessage):
            curr_tokens = instruct_tokenizer.encode_tool_message(
                message, is_before_last_user_message=msg_idx < last_user_idx
            )
            curr_masks = [False] * len(curr_tokens)
        elif isinstance(message, FinetuningAssistantMessage):
            is_last_message = msg_idx == last_msg_idx

            # Tool-call ids are dropped from an assistant message in last position.
            message = maybe_remove_call_id(message, is_last_message=is_last_message)

            curr_tokens = instruct_tokenizer.encode_assistant_message(
                message, is_before_last_user_message=False
            )

            is_weighted = message.weight is None or message.weight == 1
            is_relevant = (not mask_all_but_last) or is_last_message
            curr_masks = [is_weighted and is_relevant] * len(curr_tokens)
        else:
            # Falling through would either hit unbound locals or silently
            # re-append the previous message's tokens.
            raise ValueError(
                f"Cannot tokenize message of type {type(message)} at position {msg_idx}."
            )

        tokens.extend(curr_tokens)
        masks.extend(curr_masks)

    return TokenSample(tokens, masks)
| |
|
| |
|
def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool):
    """Drop tool-call ids from an assistant message that ends the conversation.

    When ``message`` has tool calls and ``is_last_message`` is true, each call
    is rebuilt from its ``function`` alone, so the original ``id`` is not
    carried over. The message is modified in place.

    Args:
        message: Assistant message to inspect.
        is_last_message: Whether the message is the last of its conversation.

    Returns:
        The same ``message`` object, modified or not.
    """
    if message.tool_calls is not None and is_last_message:
        stripped_calls = []
        for call in message.tool_calls:
            stripped_calls.append(ToolCall(function=call.function))
        message.tool_calls = stripped_calls

    return message
| |
|