"""
Model setup utilities for RosettaModel training/evaluation
"""
import torch
from typing import Dict, Any, List
from transformers import AutoModelForCausalLM, AutoTokenizer
from rosetta.model.wrapper import RosettaModel
from rosetta.model.projector import create_projector
"""
Mapping strategies
"""


def k_nearest_sources(num_target_layers: int, num_source_layers: int, k: int) -> Dict[int, List[int]]:
    """
    Compute, for each target layer, the K nearest source layers.

    Target and source layers are placed uniformly in [0, 1]; for each target,
    source layers are sorted by absolute distance to the target's position and
    the closest K are kept.

    Returns: Dict[target_idx, List[source_idx]], containing only the targets
    that receive at least one source.
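
    Illustrative doctest (small sizes chosen only to show the shape of the
    result; the values follow directly from the uniform-spacing rule above):

    >>> k_nearest_sources(num_target_layers=4, num_source_layers=8, k=2)
    {0: [0, 1], 1: [2, 3], 2: [5, 4], 3: [7, 6]}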
"""
if num_target_layers <= 1:
target_positions = [0.0]
else:
target_positions = [i / (num_target_layers - 1) for i in range(num_target_layers)]
if num_source_layers <= 1:
source_positions = [0.0]
else:
source_positions = [j / (num_source_layers - 1) for j in range(num_source_layers)]
mapping: Dict[int, List[int]] = {}
for t_idx, t_pos in enumerate(target_positions):
sorted_src = sorted(range(num_source_layers), key=lambda j: abs(source_positions[j] - t_pos))
chosen = sorted_src[:max(0, k)]
if len(chosen) > 0:
mapping[t_idx] = chosen
return mapping


def last_aligned_sources(num_target_layers: int, num_source_layers: int, k: int = 1) -> Dict[int, List[int]]:
    """
    Return a per-target mapping that aligns the last target layer with the
    last source layer and walks toward the front.

    For each target t, up to K sources are chosen, anchored at the aligned
    source index: backward indices are preferred first, then forward indices
    are used to reach K when the anchor sits near the front boundary.

    Returns: Dict[target_idx, List[source_idx]], containing only the targets
    that receive at least one source.

    Example (T=11, S=33): target 10 -> [32, 31, ...], target 9 -> [31, 30, ...]
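
    The same example as a doctest (shown with K=2):

    >>> mapping = last_aligned_sources(num_target_layers=11, num_source_layers=33, k=2)
    >>> mapping[10]
    [32, 31]
    >>> mapping[0]
    [22, 21]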
"""
mapping: Dict[int, List[int]] = {}
if num_target_layers <= 0 or num_source_layers <= 0:
return mapping
# Align ends; offset >= 0 means extra source layers at the front
offset = num_source_layers - num_target_layers
def take_k_from(s0: int) -> List[int]:
result: List[int] = []
# Prefer moving backward from the anchor (last-to-front)
for back in range(k):
idx = s0 - back
if 0 <= idx < num_source_layers:
result.append(idx)
# If not enough due to boundary, extend forward
next_idx = s0 + 1
while len(result) < k and next_idx < num_source_layers:
result.append(next_idx)
next_idx += 1
return result
for t in range(num_target_layers):
s0 = offset + t
# Clamp to valid range for edge cases (e.g., fewer source layers)
if s0 < 0:
s0 = 0
elif s0 > num_source_layers - 1:
s0 = num_source_layers - 1
chosen = take_k_from(s0)
if len(chosen) > 0:
mapping[t] = chosen
return mapping


def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """Set up a RosettaModel from a base model, a teacher model, and projectors."""
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load models
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config["base_model"],
        torch_dtype=dtype,
        device_map=device
    )
    teacher_model = AutoModelForCausalLM.from_pretrained(
        model_config["teacher_model"],
        torch_dtype=dtype,
        device_map=device
    )

    # Create projector
    projector_config = model_config["projector"]
    projector_params = projector_config["params"].copy()
    projector_params["dtype"] = dtype
    projector = create_projector(
        projector_config["type"],
        source_dim=teacher_model.config.head_dim,
        target_dim=base_model.config.head_dim,
        **projector_params
    )

    # Setup RosettaModel
    rosetta_model = RosettaModel(
        model_list=[base_model, teacher_model],
        base_model_idx=0,
        projector_list=[projector]
    ).to(device)

    # Configure projector mappings (identity mapping over the shared layer range)
    num_layers_to_map = min(
        base_model.config.num_hidden_layers,
        teacher_model.config.num_hidden_layers
    )
    for layer_idx in range(num_layers_to_map):
        rosetta_model.set_projector_config(
            source_model_idx=1,  # Teacher
            source_model_layer_idx=layer_idx,
            target_model_idx=0,  # Base
            target_model_layer_idx=layer_idx,
            projector_idx=0
        )

    return rosetta_model, tokenizer
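

# Usage sketch (only runs when this file is executed directly, never on import).
# The checkpoint names and the projector type/params below are placeholders,
# not values prescribed by this module; substitute whatever your RosettaModel
# setup actually uses.
if __name__ == "__main__":
    example_config = {
        # Hypothetical checkpoints -- any causal LM pair whose configs expose
        # head_dim should work here.
        "base_model": "Qwen/Qwen2.5-0.5B-Instruct",
        "teacher_model": "Qwen/Qwen2.5-7B-Instruct",
        "projector": {
            # Assumed projector type/params; see rosetta.model.projector for
            # the types actually registered with create_projector.
            "type": "linear",
            "params": {},
        },
    }
    model, tok = setup_models(example_config, device="cuda", dtype=torch.bfloat16)
    print(type(model).__name__, "ready; tokenizer vocab size:", len(tok))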