"""
Model setup utilities for RosettaModel training/evaluation
"""

import torch
from typing import Dict, Any, List
from transformers import AutoModelForCausalLM, AutoTokenizer

from rosetta.model.wrapper import RosettaModel
from rosetta.model.projector import create_projector

"""
Mapping strategies
"""
def k_nearest_sources(num_target_layers: int, num_source_layers: int, k: int) -> Dict[int, List[int]]:
    """
    Compute a per-target mapping to K nearest source layers.

    Returns: Dict[target_idx, List[source_idx]] only for targets we map.
    Distances are computed by placing target and source layers uniformly in [0, 1]
    and sorting by absolute distance.
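
    Example:
        k_nearest_sources(4, 8, 2) -> {0: [0, 1], 1: [2, 3], 2: [5, 4], 3: [7, 6]}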
    """
    if num_target_layers <= 1:
        target_positions = [0.0]
    else:
        target_positions = [i / (num_target_layers - 1) for i in range(num_target_layers)]
    if num_source_layers <= 1:
        source_positions = [0.0]
    else:
        source_positions = [j / (num_source_layers - 1) for j in range(num_source_layers)]

    mapping: Dict[int, List[int]] = {}
    for t_idx, t_pos in enumerate(target_positions):
        sorted_src = sorted(range(num_source_layers), key=lambda j: abs(source_positions[j] - t_pos))
        chosen = sorted_src[:max(0, k)]
        if len(chosen) > 0:
            mapping[t_idx] = chosen
    return mapping


def last_aligned_sources(num_target_layers: int, num_source_layers: int, k: int = 1) -> Dict[int, List[int]]:
    """
    Return a per-target mapping that aligns the last target layer to the last
    source layer and walks toward the front.

    Returns: Dict[target_idx, List[source_idx]] only for targets we map. For each
    target t, we choose up to K sources anchored at the aligned index, preferring
    backward indices first then forward to satisfy K.

    Example (T=11, S=33): target 10 -> [32, 31, ...], target 9 -> [31, 30, ...]
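    Edge case (S < T): last_aligned_sources(4, 2, 1) -> {0: [0], 1: [0], 2: [0], 3: [1]}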
    """
    mapping: Dict[int, List[int]] = {}
    if num_target_layers <= 0 or num_source_layers <= 0:
        return mapping

    # Align ends; offset >= 0 means extra source layers at the front
    offset = num_source_layers - num_target_layers

    def take_k_from(s0: int) -> List[int]:
        result: List[int] = []
        # Prefer moving backward from the anchor (last-to-front)
        for back in range(k):
            idx = s0 - back
            if 0 <= idx < num_source_layers:
                result.append(idx)
        # If not enough due to boundary, extend forward
        next_idx = s0 + 1
        while len(result) < k and next_idx < num_source_layers:
            result.append(next_idx)
            next_idx += 1
        return result

    for t in range(num_target_layers):
        s0 = offset + t
        # Clamp to valid range for edge cases (e.g., fewer source layers)
        if s0 < 0:
            s0 = 0
        elif s0 > num_source_layers - 1:
            s0 = num_source_layers - 1
        chosen = take_k_from(s0)
        if len(chosen) > 0:
            mapping[t] = chosen

    return mapping


def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Set up a RosettaModel from a base (target) model, a teacher (source) model, and a projector.

    Expected model_config keys:
        "base_model":    checkpoint name/path of the base model
        "teacher_model": checkpoint name/path of the teacher model
        "projector":     {"type": ..., "params": {...}} forwarded to create_projector

    Returns (rosetta_model, tokenizer).
    """
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Load models
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config["base_model"],
        torch_dtype=dtype,
        device_map=device
    )
    
    teacher_model = AutoModelForCausalLM.from_pretrained(
        model_config["teacher_model"],
        torch_dtype=dtype,
        device_map=device
    )
    
    # Create projector; source/target dims are the per-head dimensions of the two models
    projector_config = model_config["projector"]
    projector_params = projector_config["params"].copy()
    projector_params["dtype"] = dtype

    def _head_dim(cfg) -> int:
        # Some configs do not expose head_dim; derive it from hidden_size / num_attention_heads
        return getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads

    projector = create_projector(
        projector_config["type"],
        source_dim=_head_dim(teacher_model.config),
        target_dim=_head_dim(base_model.config),
        **projector_params
    )

    # Setup RosettaModel
    rosetta_model = RosettaModel(
        model_list=[base_model, teacher_model],
        base_model_idx=0,
        projector_list=[projector]
    ).to(device)
    
    # Configure projector mappings: identity layer-to-layer over the overlapping prefix
    # (see k_nearest_sources / last_aligned_sources above for alternative strategies)
    num_layers_to_map = min(
        base_model.config.num_hidden_layers, 
        teacher_model.config.num_hidden_layers
    )
    
    for layer_idx in range(num_layers_to_map):
        rosetta_model.set_projector_config(
            source_model_idx=1,  # Teacher
            source_model_layer_idx=layer_idx,
            target_model_idx=0,  # Base
            target_model_layer_idx=layer_idx,
            projector_idx=0
        )
    
    return rosetta_model, tokenizer
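

if __name__ == "__main__":
    # Quick sanity check of the mapping strategies (no model downloads required).
    print(k_nearest_sources(num_target_layers=4, num_source_layers=8, k=2))
    print(last_aligned_sources(num_target_layers=11, num_source_layers=33, k=2))

    # Sketch of a model_config accepted by setup_models(). The checkpoint names and
    # projector type/params below are placeholders, not values used by the project;
    # substitute the ones from your own configuration.
    example_config = {
        "base_model": "Qwen/Qwen2.5-0.5B-Instruct",   # placeholder checkpoint
        "teacher_model": "Qwen/Qwen2.5-7B-Instruct",  # placeholder checkpoint
        "projector": {
            "type": "linear",   # placeholder projector type
            "params": {},       # projector-specific kwargs
        },
    }
    # rosetta_model, tokenizer = setup_models(example_config, device="cuda")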