# Uploaded via upload-large-folder tool (commit 1c8c60e).
import logging
import sys
import regex
import torch
from transformers import AutoTokenizer
from mllm.training.training_data_utils import TrainingChatTurn, TrajectoryBatch
# Module-level logger; records are mirrored to stdout via an explicit handler.
logger = logging.getLogger(__name__)
_stdout_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(_stdout_handler)
# def get_chat_dicts(chat: list[TrainingChatTurn]) -> list[dict]:
# chat_dicts = [chat_turn.dict() for chat_turn in chat]
# return chat_dicts
def process_training_chat(
tokenizer: AutoTokenizer,
chat_history: list[TrainingChatTurn],
entropy_mask_regex: str | None = None,
exploration_prompts_to_remove: list[str] = [],
use_engine_out_token_ids: bool = False,
) -> tuple[torch.IntTensor, torch.BoolTensor, torch.IntTensor, torch.BoolTensor]:
"""Tokenize a single training chat and build aligned per-token masks.
Given an ordered list of `TrainingChatTurn`, this function tokenizes each
turn independently using the tokenizer's chat template, then concatenates
all resulting token sequences. It also constructs three parallel 1D masks
that align with the concatenated tokens:
- input_ids: token ids for the entire chat, turn by turn
- action_mask: True for tokens that belong to assistant turns (i.e., model
actions), False for tokens from other roles
- timesteps: per-token time step copied from the originating turn's
`time_step`
- state_ends_mask: True for the last token of any turn where
`is_state_end` is True, otherwise False
Important details:
- Each turn is passed as a single-message list to
`tokenizer.apply_chat_template` and flattened; the per-turn outputs are
then concatenated in the original order.
- Turn boundaries are not explicitly encoded beyond what the chat template
inserts; masks provide alignment for learning signals and state endings.
- No truncation or padding is performed here; downstream code should handle
batching/padding as needed.
- Note on dtypes: `input_ids` will be a LongTensor (int64). `action_mask`
and `state_ends_mask` are BoolTensors. `timesteps` is currently created
as a float tensor; adjust the implementation if integer dtype is
required downstream.
Args:
tokenizer: A Hugging Face tokenizer supporting `apply_chat_template`.
chat_history: Ordered list of `TrainingChatTurn` forming one dialogue.
Returns:
A tuple of four 1D tensors, all of equal length N (the total number of
tokens across all turns), in the following order:
- input_ids (LongTensor)
- action_mask (BoolTensor)
- timesteps (FloatTensor as implemented; see note above)
- state_ends_mask (BoolTensor)
"""
state_ends_mask = []
input_ids = []
action_mask = []
timesteps = []
entropy_mask = []
engine_log_probs = []
for train_chat_turn in chat_history:
is_state_end = train_chat_turn.is_state_end
time_step = train_chat_turn.time_step
is_action = train_chat_turn.role == "assistant"
# Remove exploration prompts from training data
for exploration_prompt in exploration_prompts_to_remove:
if exploration_prompt in train_chat_turn.content:
train_chat_turn.content = train_chat_turn.content.replace(
exploration_prompt, ""
)
chat_turn = {
"role": train_chat_turn.role,
"content": train_chat_turn.content,
}
if entropy_mask_regex is not None:
is_entropy_mask_true = (
regex.search(entropy_mask_regex, train_chat_turn.content) is not None
)
else:
is_entropy_mask_true = True
if is_action:
chat_turn_ids = train_chat_turn.out_token_ids
nb_chat_turns_ids = chat_turn_ids.numel()
action_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
engine_log_probs.append(train_chat_turn.log_probs)
else:
chat_turn_ids = train_chat_turn.chat_template_token_ids
nb_chat_turns_ids = chat_turn_ids.numel()
action_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
engine_log_probs.append(torch.zeros(nb_chat_turns_ids, dtype=torch.float))
nb_chat_turns_ids = chat_turn_ids.numel()
state_ends_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
if is_state_end:
state_ends_mask[-1][-1] = True # last token is state end
input_ids.append(chat_turn_ids)
entropy_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
if not is_entropy_mask_true:
entropy_mask[-1] = entropy_mask[-1] * False
timesteps.append(torch.ones(nb_chat_turns_ids) * time_step)
input_ids = torch.cat(input_ids)
action_mask = torch.cat(action_mask)
entropy_mask = torch.cat(entropy_mask)
timesteps = torch.cat(timesteps)
timesteps = timesteps.to(torch.long)
state_ends_mask = torch.cat(state_ends_mask)
engine_log_probs = torch.cat(engine_log_probs)
return (
input_ids,
action_mask,
entropy_mask,
timesteps,
state_ends_mask,
engine_log_probs,
)