import math
import os
import os.path as osp
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import transformers
from huggingface_hub import file_exists, repo_exists
from huggingface_hub.utils import HFValidationError
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from .conversation import SeparatorStyle, default_conversation

SENTINEL_TOKEN = "<vila/sentinel>"
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}

# A dummy multi-turn conversation used to probe the chat template
# (see infer_stop_tokens below).
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    # Media tokens such as "<image>" are registered as special tokens, so the
    # prompt (including media placeholders) can be tokenized directly.
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]


def has_tokenizer(repo_id_or_path: str) -> bool:
    # Check whether a tokenizer config exists in a local directory.
    if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
        return True

    # Otherwise, check whether one exists on the Hugging Face Hub.
    try:
        return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False


def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)


def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Skip the first message if it is not from a human.
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Add an empty assistant turn if a generation prompt is requested
    # (copy first so the caller's list is not mutated).
    if add_generation_prompt:
        messages = list(messages) + [{"from": "gpt", "value": None}]

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")


def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    # Normalize whitespace so tokenization is consistent across data sources.
    for message in messages:
        message["value"] = message["value"].strip()

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Convert the "human"/"gpt" message format to the chat-template format.
    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
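
# Hypothetical usage sketch (not part of the original code; the checkpoint path
# is an example): tokenizing a single-turn conversation for generation.
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/llm", use_fast=True)
#   input_ids = tokenize_conversation(
#       [{"from": "human", "value": "Describe this <image>."}],
#       tokenizer,
#       add_generation_prompt=True,
#   )
#   # input_ids is a 1-D torch.Tensor of token ids for the rendered prompt.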


def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    _maybe_add_sentinel_token(tokenizer)
    # Tokenize a dummy conversation in which every assistant reply is replaced
    # by the sentinel token; whatever token follows the sentinel in the rendered
    # template must be a stop token.
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
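
# Illustration (values depend on the tokenizer's chat template): with a
# ChatML-style template, the inferred stop tokens would typically be something
# like [tokenizer.eos_token, "<|im_end|>"].
#
#   stop_tokens = infer_stop_tokens(tokenizer)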


def context_length_extension(config):
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
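
# Worked example (illustrative numbers): if the base model was trained with
# max_position_embeddings = 4096 and model_max_length is set to 10000, then
# scaling_factor = ceil(10000 / 4096) = 3.0 and the config gains
# rope_scaling = {"type": "linear", "factor": 3.0}.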


def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # Quantization-aware restoration from an FP8 checkpoint is currently disabled.
    quantization_restore_from_checkpoint = False

    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)

        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        if is_deepspeed_zero3_enabled():
            # DeepSpeed ZeRO-3 manages parameter placement itself, so drop any
            # user-supplied device map.
            kwargs.pop("device_map", None)

        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )

    # Locate the tokenizer: either alongside the model or in an "llm" subfolder.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Load a custom chat template if one is specified in the config.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        # Strip the indentation and newlines that exist only for readability in
        # the .jinja file so they do not leak into rendered prompts.
        tokenizer.chat_template = chat_template.replace("    ", "").replace("\n", "")

    # Infer and record the stop tokens used during generation.
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Register media tokens (image/video placeholders) as special tokens.
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
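
# Minimal usage sketch (hypothetical paths; the dtype and attention backend are
# assumptions, adjust to your setup):
#
#   config = AutoConfig.from_pretrained("path/to/vila_checkpoint")
#   config.model_dtype = "torch.bfloat16"  # parsed with eval() above
#   llm, tokenizer = build_llm_and_tokenizer(
#       "path/to/vila_checkpoint/llm",
#       config,
#       attn_implementation="flash_attention_2",
#       model_max_length=4096,
#   )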