from typing import Any, Dict, List, Optional, Sequence

import torch
import transformers

from .constants import IGNORE_INDEX, SENTINEL_TOKEN
from .conversation import SeparatorStyle, default_conversation
from .mm_utils import tokenizer_image_token

# Dummy multi-turn conversation used to probe the chat template in infer_stop_tokens().
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10

def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
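    """Tokenize a conversation using the legacy separator-style templates.

    Builds the prompt via `default_conversation` and delegates tokenization to
    `tokenizer_image_token`, which handles any image placeholder tokens.
    """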
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip a leading message that is not from the human side.
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Append an empty assistant turn so the prompt ends with the generation prefix.
    # Copy instead of mutating: `messages` is typed as a Sequence and may be shared.
    if add_generation_prompt:
        messages = list(messages) + [{"from": "gpt", "value": None}]

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2], "conversation roles must alternate human/gpt"
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")

def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
    return_ids_only: bool = True,
) -> torch.Tensor:
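    """Tokenize a conversation, preferring the tokenizer's own chat template.

    Falls back to `tokenize_conversation_legacy` when the active conversation
    template does not use `SeparatorStyle.AUTO`.
    """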
    # Normalize the conversation in place before tokenization.
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Convert the "from"/"value" schema to the "role"/"content" schema expected
    # by apply_chat_template.
    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        elif m["from"] == "system":
            message["role"] = "system"
            if no_system_prompt:
                raise ValueError("message[role]=system is not allowed when no_system_prompt is set to True.")
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    # Prepend an empty system message to suppress the tokenizer's default system prompt.
    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt", return_ids=return_ids_only)
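
# A minimal usage sketch (illustrative only; the checkpoint name is a placeholder,
# and the tokenizer is assumed to ship a chat template):
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("<checkpoint>")
#   input_ids = tokenize_conversation(
#       [{"from": "human", "value": "Hello!"}],
#       tokenizer,
#       add_generation_prompt=True,
#   )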

def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    # Register the sentinel as a special token once per tokenizer instance.
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)

def preprocess_conversation(
    conversation: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    no_system_prompt: bool = False,
    retried: bool = False,
    **kwargs: Any,  # accepted for call-site compatibility; not used below
) -> Dict[str, Any]:
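    """Tokenize a conversation and build training labels.

    Assistant-response tokens keep their ids in `labels`; everything the
    template contributes (system prompt, user turns, separators) is masked
    with IGNORE_INDEX, so the loss is computed only on assistant responses.
    """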
    inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
    labels = torch.ones_like(inputs) * IGNORE_INDEX

    # Tokenize the conversation again, replacing every assistant response with a sentinel token.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(
        conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
    )

    # Drop each sentinel and the token that follows it (the assistant's stop token) from the template.
    mask = torch.ones_like(template, dtype=torch.bool)
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            mask[k : k + 2] = False
            # On retry, also drop the token preceding the sentinel, in case the
            # tokenizer merged it with the response during the first pass.
            if k > 0 and retried:
                mask[k - 1] = False
    template = template[mask]

    # Match the tokenized conversation against the response-free template.
    # Every token that does not match belongs to an assistant response and is kept in the labels.
    p = 0
    for k in range(inputs.size(0)):
        if p < template.size(0) and inputs[k] == template[p]:
            p += 1
        else:
            labels[k] = inputs[k]

    # If the template was not fully consumed, matching failed: retry once, then mask everything.
    if p < template.size(0):
        if not retried:
            return preprocess_conversation(
                conversation,
                tokenizer,
                no_system_prompt=no_system_prompt,
                retried=True,
            )
        print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
        labels[:] = IGNORE_INDEX

    return {"input_ids": inputs, "labels": labels}
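
# A minimal usage sketch (illustrative only; assumes the training loop treats
# IGNORE_INDEX positions as masked):
#
#   batch = preprocess_conversation(
#       [
#           {"from": "human", "value": "What is in the image?"},
#           {"from": "gpt", "value": "A cat."},
#       ],
#       tokenizer,
#   )
#   # batch["input_ids"] and batch["labels"] are 1-D tensors of equal length.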

def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
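    """Infer the stop tokens of the active chat template.

    Tokenizes a dummy conversation with each assistant response replaced by a
    sentinel; whatever token follows a sentinel is a stop token of the template.
    """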
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
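
# Example (illustrative): the inferred stop strings could be collected as
#   stop_strs = infer_stop_tokens(tokenizer)
# and passed to the generation loop's stopping criteria.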