# router_backend.py
"""
Plug your real model routing function here.

Implement the function:
    get_expert_routing(model_id: str, prompt: str) -> list[float] | dict[str, float] | tuple[float, float, float, float]

It must return 4 values (percentages) corresponding to the experts:
["Language", "Logic", "Social", "World"]

Example return formats:
- [12.5, 45.0, 22.5, 20.0]
- {"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}
- (12.5, 45.0, 22.5, 20.0)
"""
import torch
import pathlib
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from typing import Union, Dict, List, Tuple

from models.micro_olmo import MiCRoOLMo
from models.micro_llama import MiCRoLlama
from models.micro_moe_llama import MiCRoLlamaMoE

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]], ablations: List[str] = None) -> Tuple[Dict[str, float], np.ndarray, Union[str, None]]:

    model, tokenizer = build_model(model_id, hf_token, ablations=ablations)

    if isinstance(prompt, str):
        generation, routing_weights = generate_continuation(model, tokenizer, prompt)
        generation = generation[0] if isinstance(generation, list) else generation
    elif isinstance(prompt, list):
        generation = None
        routing_weights = get_routing_weights(model, tokenizer, [prompt])

    # aggregate_routing_weights returns raw per-expert token counts; normalize
    # the model-level counts into percentages to match the documented contract.
    model_routing_counts, layer_token_routing = aggregate_routing_weights(routing_weights)
    model_routing_percentages = model_routing_counts / model_routing_counts.sum() * 100

    layer_token_routing = np.array(layer_token_routing)
    num_experts, num_layers = layer_token_routing.shape

    print(f"Model-level routing percentages: {model_routing_percentages}")

    # Reorder the expert axis from the internal order
    # ["Logic", "Social", "World", "Language"] to the display order
    # ["Language", "Logic", "Social", "World"]; rolling by 1 moves the last
    # expert (Language) to the front.
    layer_token_routing = np.roll(layer_token_routing, shift=1, axis=0)

    # Convert per-layer token counts into per-layer percentages, giving a
    # (num_layers, num_experts) array in the display expert order.
    all_layer_routing_percentages = []
    for layer_idx in range(num_layers):
        layer_token_percentages = []
        for expert_idx in range(num_experts):
            percentage = (layer_token_routing[expert_idx][layer_idx] / sum(layer_token_routing[:, layer_idx])) * 100
            layer_token_percentages.append(percentage)
        all_layer_routing_percentages.append(layer_token_percentages)

    layer_routing_percentages = np.array(all_layer_routing_percentages)

    if generation is not None:
        print(f"Generation:\n{generation}")
    
    # Index into the unrolled internal expert order:
    # ["Logic", "Social", "World", "Language"].
    return {
        "Language": float(model_routing_percentages[3]),
        "Logic": float(model_routing_percentages[0]),
        "Social": float(model_routing_percentages[1]),
        "World": float(model_routing_percentages[2]),
    }, layer_routing_percentages, generation

def get_model_path(model_name: str) -> Tuple[str, str, type]:
    """Map a short model id to (checkpoint path, base model id, model class)."""
    return {
        # MiCRo-Llama
        "micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
        "micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
        "micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),

        # MiCRo-MoE-Llama
        "micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
        
        # MiCRo-OLMo
        "micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),

        # MiCRo-SmolLM2
        "micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
        "micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),

        # MiCRo-MoE-SmolLM2
        "micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
        "micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
    }.get(model_name, (model_name, model_name, AutoModelForCausalLM))

def aggregate_routing_weights(routing_weights):
    """
    Count, per expert, how many (layer, token) positions route to that expert
    (argmax over the expert dimension).

    Args:
        routing_weights: array of shape (num_layers, num_tokens, num_experts).
    Returns:
        expert_token_model: (num_experts,) total token counts per expert.
        expert_layer_token: (num_experts, num_layers) per-layer token counts.
    """
    experts = ["Logic", "Social", "World", "Language"]
    expert_token_model = np.zeros((len(experts)), dtype=int)
    expert_layer_token = np.zeros((len(experts), routing_weights.shape[0]), dtype=int)
    num_layers = routing_weights.shape[0]

    for layer_idx in range(num_layers):
        for token_idx in range(len(routing_weights[layer_idx])):
            # Hard assignment: the expert with the highest routing weight
            # "wins" this token at this layer.
            expert_idx = routing_weights[layer_idx][token_idx].argmax()
            expert_token_model[expert_idx] += 1
            expert_layer_token[expert_idx][layer_idx] += 1
    return expert_token_model, expert_layer_token
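
# Shape example (hypothetical sizes): for routing_weights of shape
# (num_layers=16, num_tokens=10, num_experts=4), this returns a (4,) count
# vector summing to 16 * 10 = 160 and a (4, 16) per-layer matrix whose
# columns each sum to 10.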

def generate_continuation(
    model,
    tokenizer,
    prompts,
    max_tokens=128,
    use_cache=True,
    return_routing_weights=True,
):

    if isinstance(prompts, str):
        prompts = [{"role": "user", "content": prompts}]

    tokenizer.padding_side = "left"
    inputs = tokenizer.apply_chat_template(
        prompts, return_tensors="pt", padding=True, add_generation_prompt=True
    ).to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    # Greedy decoding; temperature/top_p are omitted because do_sample=False
    # makes them inert.
    outputs = model.generate(
        input_ids=inputs,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        use_cache=use_cache,
        stop_strings=["</s>", "<|eot_id|>", "<|im_start|>user", "user"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
    )
    
    if return_routing_weights:
        # Re-run a forward pass over prompt + generation to collect routing
        # weights, keeping only the generated-token positions.
        attention_mask = torch.ones_like(outputs)
        attention_mask[outputs == tokenizer.pad_token_id] = 0
        with torch.no_grad():
            model_output = model(input_ids=outputs, attention_mask=attention_mask)
        torch.cuda.empty_cache()

        routing_weights = model_output.routing_weights
        routing_weights = np.concatenate([
            F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy()
            for rw in routing_weights
        ])
    else:
        routing_weights = None

    inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)

    generations = []
    for i, output in enumerate(outputs):
        # Strip the echoed prompt plus padding and end-of-turn markers from
        # the decoded output, leaving only the model's continuation.
        decoded_output = tokenizer.decode(output, skip_special_tokens=False)
        decoded_output = decoded_output.replace(inputs_text[i], "")
        decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
        decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
        decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
        decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
        decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
        generations.append(decoded_output)

    return (generations, routing_weights) if return_routing_weights else generations

def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
    """
    Get routing weights for the given prompts using the model.
    Args:
        model: The MiCRoLlama or MiCRoOLMo model.
        tokenizer: The tokenizer for the model.
        prompts: A list of conversations (each a list of {"role", "content"}
            dicts) whose last message is the response to analyze, or a plain
            string, which is wrapped into a single-message conversation.
    Returns:
        routing_weights: Per-layer routing weights over the response tokens,
            shape (num_layers, num_response_tokens, num_experts).
    """

    tokenizer.padding_side = "left"
    if apply_chat_template:
        if isinstance(prompts, str):
            # Wrap a bare string into a batch of one single-message
            # conversation so the slicing below stays well-formed.
            prompts = [[{"role": "user", "content": prompts}]]

        inputs = tokenizer.apply_chat_template(
            prompts, return_tensors="pt", padding=True
        ).to(DEVICE)

        # Tokenize the same conversations without their final message to
        # locate where the response span begins.
        input_without_response = tokenizer.apply_chat_template(
            [prompt[:-1] for prompt in prompts], return_tensors="pt", padding=True,
        ).to(DEVICE)
    else:
        inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
        input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    with torch.no_grad():
        model_output = model(input_ids=inputs, attention_mask=attention_mask)

    routing_weights = model_output.routing_weights
    routing_weights = np.stack([F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0).squeeze()

    # Keep only the positions aligned with the response tokens (the offset is
    # shifted by one to account for next-token prediction).
    offset = len(input_without_response[0]) - 1
    routing_weights = routing_weights[:, offset:-1]

    return routing_weights
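
# Example call (hypothetical conversation): a batch of one conversation whose
# last message is the response being analyzed, e.g.
#   get_routing_weights(model, tokenizer,
#       [[{"role": "user", "content": "2 + 2 = ?"},
#         {"role": "assistant", "content": "4"}]])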

def build_model(model_id: str, hf_token: str, ablations: List[str] = None, use_cache: bool = True):

    model_path, base_model, model_class = get_model_path(model_id)

    model_config = AutoConfig.from_pretrained(base_model, token=hf_token)

    parent_path = pathlib.Path(__file__).parent

    model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"

    model_config.torch_dtype = torch.float16
    model_config.use_bfloat16 = False
    model_config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
    model_config.use_cache = use_cache
    model_config.ablate = ablations

    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
    tokenizer.padding_side = "left"

    if "llama" in model_id:
        tokenizer.pad_token_id = 128004
    if "olmo" in model_id:
        tokenizer.pad_token_id = 100277
        tokenizer.add_special_tokens({'additional_special_tokens': ['<|assistant|>']})
    elif "smollm2" in model_id:
        tokenizer.pad_token_id = 2
    else:
        tokenizer.pad_token_id = 128004

    if "olmo" in model_id:
        model_config.vocab_size = len(tokenizer)

    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)

    model.to(DEVICE)
    model = model.half()
    model.eval()
    return model, tokenizer
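
if __name__ == "__main__":
    # Minimal smoke-test sketch. The model id comes from the table in
    # get_model_path; the HF_TOKEN environment variable and the prompt are
    # illustrative assumptions, not part of the original module.
    import os

    percentages, layer_percentages, generation = get_expert_routing(
        "micro-smollm2-135m",
        os.environ.get("HF_TOKEN", ""),
        "What is the capital of France?",
    )
    print(percentages)              # {"Language": ..., "Logic": ..., ...}
    print(layer_percentages.shape)  # (num_layers, num_experts)
    if generation is not None:
        print(f"Generation:\n{generation}")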