File size: 5,609 Bytes
4bd136e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Model and tokenizer initialization
"""
import torch
from typing import List, Optional, Set, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
from unsloth import FastLanguageModel
from config import (
    MODEL_NAME, MAX_SEQ_LEN, DTYPE,
    LORA_R, LORA_ALPHA, LORA_DROPOUT,
    LORA_TARGET_MODULES, LORA_MODULES_TO_SAVE,
    PAD_TOKEN, M_START, M_END
)

# ======================================================================================
# Logic from test_overfit.py (Standard Transformers)
# ======================================================================================

def setup_model_and_tokenizer_raw(model_name: str, motion_tokens: List[str]) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a base causal LM and tokenizer via standard Transformers.

    Registers the pad token, the motion boundary markers, and the full
    motion vocabulary, then resizes the model embeddings to match.

    Returns:
        (model, tokenizer) with embeddings grown to the new vocab size and
        the model's pad_token_id synced to the tokenizer's.
    """
    print(f"\n---> Loading base model and tokenizer: {model_name}")
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Register pad + motion-boundary markers (matches test_overfit.py).
    special_map = {"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]}
    tok.add_special_tokens(special_map)

    print(f"Adding {len(motion_tokens)} motion tokens to the tokenizer.")
    tok.add_tokens(motion_tokens, special_tokens=True)

    # Grow the embedding matrix to cover the enlarged vocabulary and keep
    # the model config's pad id in sync with the tokenizer.
    lm.resize_token_embeddings(len(tok))
    lm.config.pad_token_id = tok.pad_token_id

    return lm, tok

def ensure_tokenizer_has_motion_tokens(tokenizer: AutoTokenizer, motion_tokens: List[str]) -> int:
    """
    Ensure the tokenizer knows the pad/boundary specials and every motion token.

    Returns the count reported by ``add_tokens`` — i.e. how many of the
    *motion* tokens were newly added (specials registered through
    ``add_special_tokens`` are not included in this count).
    """
    specials = {"pad_token": PAD_TOKEN, "additional_special_tokens": [M_START, M_END]}
    tokenizer.add_special_tokens(specials)
    return tokenizer.add_tokens(motion_tokens, special_tokens=True)

# ======================================================================================
# Existing Logic (Unsloth / LoRA)
# ======================================================================================

def build_special_tokens(codebook_size: int, unique_pids: Optional[List[str]] = None) -> List[str]:
    """
    Build all special tokens for the motion vocabulary.

    Args:
        codebook_size: Number of VQ codebook entries; yields one
            ``<motion_i>`` token per entry.
        unique_pids: Optional participant IDs; when given (and non-empty),
            a ``<PID_NULL>`` token plus one ``<PID_...>`` token per ID are
            appended.

    Returns:
        Tokens in a fixed order: boundary, motion, task, then PID tokens.
        Order matters — callers rely on it being deterministic across runs.
    """
    # Motion tokens — one per codebook entry.
    motion_tokens = [f"<motion_{i}>" for i in range(codebook_size)]

    # Boundary tokens delimiting a motion span.
    boundary_tokens = ["<MOT_BEGIN>", "<MOT_END>"]

    # Task-conditioning tokens.
    task_tokens = ["<T2M>", "<M2T>", "<DENOISE>", "<MOTION_MASK>"]

    # Participant ID tokens (only when PIDs were provided).
    pid_tokens: List[str] = []
    if unique_pids:
        pid_tokens = ["<PID_NULL>"] + [f"<PID_{pid}>" for pid in unique_pids]

    return boundary_tokens + motion_tokens + task_tokens + pid_tokens


def setup_model_and_tokenizer(codebook_size: int, unique_pids: List[str] = None):
    """
    Initialize model and tokenizer with custom tokens (Unsloth LoRA).

    Order is deliberate: tokens are added and embeddings resized *before*
    the LoRA wrap, and gradient masking is applied last so only the new
    token rows train.

    Returns: (model, tokenizer, new_token_ids)
    """
    special_tokens = build_special_tokens(codebook_size, unique_pids)

    # Load the base model through Unsloth's fast loader.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LEN,
        dtype=DTYPE,
        load_in_4bit=False,
        trust_remote_code=True,
    )

    tokenizer.padding_side = "right"

    # Register only the tokens the tokenizer does not already know about.
    already_known = set(tokenizer.special_tokens_map_extended.get("additional_special_tokens", []))
    missing = [tok for tok in special_tokens if tok not in already_known]
    if missing:
        tokenizer.add_special_tokens({"additional_special_tokens": missing})

    # Fall back to EOS as the padding token when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Grow the embedding matrix to cover the enlarged vocabulary.
    model.resize_token_embeddings(len(tokenizer))

    # Wrap with LoRA adapters.
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        target_modules=LORA_TARGET_MODULES,
        modules_to_save=LORA_MODULES_TO_SAVE,
        use_gradient_checkpointing="unsloth",
    )

    # IDs of every custom token — used to mask gradients so the base
    # vocabulary embeddings do not drift during training.
    new_token_ids = set(tokenizer.convert_tokens_to_ids(special_tokens))
    apply_gradient_mask(model, new_token_ids)

    return model, tokenizer, new_token_ids


def apply_gradient_mask(model, new_token_ids: Set[int]):
    """
    Register gradient hooks so that only the embedding rows of the newly
    added tokens are updated; gradients on all pre-existing vocabulary rows
    are zeroed to prevent base-vocab drift.

    Args:
        model: HF-style module exposing get_input_embeddings() /
            get_output_embeddings(). The output head may be None or
            weight-tied to the input embeddings.
        new_token_ids: Embedding row indices whose gradients pass through.
    """
    def _mask_rows_hook(param, rows: Set[int]):
        # Build the 0/1 row mask once; the hook closure reuses it on
        # every backward pass.
        mask = torch.zeros(param.size(0), device=param.device, dtype=param.dtype)
        idxs = sorted(rows)
        if idxs:
            mask[idxs] = 1.0
        param.register_hook(lambda g: g * mask.unsqueeze(1))

    # Hook registration does not modify parameters, so no torch.no_grad()
    # guard is needed (the original wrapped this in no_grad unnecessarily).
    emb_weight = model.get_input_embeddings().weight
    _mask_rows_hook(emb_weight, new_token_ids)

    # The original unconditionally hooked the output head, which raised
    # AttributeError for models where get_output_embeddings() returns None,
    # and double-registered the hook when the head is weight-tied to the
    # input embeddings (same tensor). Guard against both.
    head = model.get_output_embeddings()
    if head is not None and head.weight is not emb_weight:
        _mask_rows_hook(head.weight, new_token_ids)


def get_motion_token_info(tokenizer, codebook_size: int):
    """
    Look up the IDs of the motion vocabulary and its boundary markers.

    Returns: (motion_token_ids, mot_begin_id, mot_end_id)
    """
    token_strings = [f"<motion_{i}>" for i in range(codebook_size)]
    ids = tokenizer.convert_tokens_to_ids(token_strings)
    begin_id = tokenizer.convert_tokens_to_ids("<MOT_BEGIN>")
    end_id = tokenizer.convert_tokens_to_ids("<MOT_END>")
    return ids, begin_id, end_id