""" Model and tokenizer initialization """ import torch from typing import List, Set, Tuple from transformers import AutoModelForCausalLM, AutoTokenizer from unsloth import FastLanguageModel from config import ( MODEL_NAME, MAX_SEQ_LEN, DTYPE, LORA_R, LORA_ALPHA, LORA_DROPOUT, LORA_TARGET_MODULES, LORA_MODULES_TO_SAVE, PAD_TOKEN, M_START, M_END ) # ====================================================================================== # Logic from test_overfit.py (Standard Transformers) # ====================================================================================== def setup_model_and_tokenizer_raw(model_name: str, motion_tokens: List[str]) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: """Loads the model and tokenizer, adding special and motion tokens (Standard Transformers).""" print(f"\n---> Loading base model and tokenizer: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) # Add special tokens (matches test_overfit.py) tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]}) print(f"Adding {len(motion_tokens)} motion tokens to the tokenizer.") tokenizer.add_tokens(motion_tokens, special_tokens=True) model.resize_token_embeddings(len(tokenizer)) model.config.pad_token_id = tokenizer.pad_token_id return model, tokenizer def ensure_tokenizer_has_motion_tokens(tokenizer: AutoTokenizer, motion_tokens: List[str]) -> int: """ Adds any missing motion tokens to the tokenizer. Returns number of tokens added. """ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]}) added = tokenizer.add_tokens(motion_tokens, special_tokens=True) return added # ====================================================================================== # Existing Logic (Unsloth / LoRA) # ====================================================================================== def build_special_tokens(codebook_size: int, unique_pids: List[str] = None) -> List[str]: """ Build all special tokens for motion vocabulary """ # Motion tokens motion_tokens = [f"" for i in range(codebook_size)] # Boundary tokens boundary_tokens = ["", ""] # Task tokens task_tokens = ["", "", "", ""] # Participant ID tokens pid_tokens = [] if unique_pids: pid_tokens = [""] + [f"" for pid in unique_pids] return boundary_tokens + motion_tokens + task_tokens + pid_tokens def setup_model_and_tokenizer(codebook_size: int, unique_pids: List[str] = None): """ Initialize model and tokenizer with custom tokens (Unsloth LoRA) Returns: (model, tokenizer, new_token_ids) """ # Build special tokens additional_special_tokens = build_special_tokens(codebook_size, unique_pids) # Load base model model, tokenizer = FastLanguageModel.from_pretrained( model_name=MODEL_NAME, max_seq_length=MAX_SEQ_LEN, dtype=DTYPE, load_in_4bit=False, trust_remote_code=True, ) # Configure tokenizer tokenizer.padding_side = "right" # Add special tokens existing = set(tokenizer.special_tokens_map_extended.get("additional_special_tokens", [])) to_add = [t for t in additional_special_tokens if t not in existing] if to_add: tokenizer.add_special_tokens({"additional_special_tokens": to_add}) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Resize embeddings model.resize_token_embeddings(len(tokenizer)) # Apply LoRA model = FastLanguageModel.get_peft_model( model, r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT, bias="none", 
        target_modules=LORA_TARGET_MODULES,
        modules_to_save=LORA_MODULES_TO_SAVE,
        use_gradient_checkpointing="unsloth",
    )

    # Get new token IDs for gradient masking
    new_token_ids = set(tokenizer.convert_tokens_to_ids(additional_special_tokens))

    # Apply gradient mask to prevent base-vocabulary drift
    apply_gradient_mask(model, new_token_ids)

    return model, tokenizer, new_token_ids


def apply_gradient_mask(model, new_token_ids: Set[int]):
    """Apply a gradient mask so that only the new token embedding rows are updated."""
    def mask_rows_hook(param, rows: set):
        # Build a 0/1 row mask and multiply it into the gradient on every backward pass
        mask = torch.zeros(param.size(0), device=param.device, dtype=param.dtype)
        idxs = sorted(rows)
        if len(idxs) > 0:
            mask[idxs] = 1.0
        param.register_hook(lambda g: g * mask.unsqueeze(1))

    with torch.no_grad():
        emb = model.get_input_embeddings().weight
        head = model.get_output_embeddings().weight
        mask_rows_hook(emb, new_token_ids)
        mask_rows_hook(head, new_token_ids)


def get_motion_token_info(tokenizer, codebook_size: int):
    """
    Get motion token IDs and boundary token IDs.

    Returns:
        (motion_token_ids, mot_begin_id, mot_end_id)
    """
    # Token strings must match those produced by build_special_tokens (assumed format)
    motion_token_strs = [f"<motion_{i}>" for i in range(codebook_size)]
    motion_token_ids = tokenizer.convert_tokens_to_ids(motion_token_strs)
    mot_begin_id = tokenizer.convert_tokens_to_ids(M_START)
    mot_end_id = tokenizer.convert_tokens_to_ids(M_END)

    return motion_token_ids, mot_begin_id, mot_end_id
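
# ======================================================================================
# Minimal usage sketch (illustrative only). The codebook size and participant IDs below
# are hypothetical placeholder values, not taken from the project config.
# ======================================================================================

if __name__ == "__main__":
    example_codebook_size = 512        # hypothetical codebook size
    example_pids = ["P01", "P02"]      # hypothetical participant IDs

    model, tokenizer, new_token_ids = setup_model_and_tokenizer(
        codebook_size=example_codebook_size,
        unique_pids=example_pids,
    )
    motion_token_ids, mot_begin_id, mot_end_id = get_motion_token_info(
        tokenizer, example_codebook_size
    )
    print(f"New tokens added: {len(new_token_ids)}")
    print(f"Motion span boundary IDs: {mot_begin_id} .. {mot_end_id}")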