File size: 1,360 Bytes
b59223f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# ztrain/model.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

from collections import defaultdict
import re

def generate_merge_group(group_data: list, parents: "list[int] | None" = None):
    """Recursively walk a nested list and yield every non-list leaf with its index path.

    Args:
        group_data: A possibly nested list. Any element that is not a ``list``
            is treated as a leaf (tuples, strings, etc. are NOT descended into).
        parents: Index path of ``group_data`` inside an outer structure; it is
            prepended to every yielded path. Defaults to the empty path. Any
            sequence of ints is accepted; it is copied, never mutated.

    Yields:
        ``(leaf, path)`` pairs in depth-first order, where ``path`` is a fresh
        ``list[int]`` of indices leading from the root to ``leaf``.
    """
    # None instead of a mutable ``[]`` default, so no list object is shared across calls.
    prefix = [] if parents is None else list(parents)
    for i, g in enumerate(group_data):
        path = prefix + [i]  # new list per element; safe to hand to callers
        if isinstance(g, list):
            yield from generate_merge_group(g, path)
        else:
            yield g, path

def merge_groups(group_data : list):
    """Bucket the leaves of a nested list by the path of their containing list.

    Each leaf produced by ``generate_merge_group`` carries its full index path.
    Dropping the final index gives the path of the list that directly holds
    the leaf. Leaves that share that parent path are collected, in traversal
    order, under the same tuple key.

    Returns:
        A ``defaultdict(list)`` mapping parent-path tuples to lists of leaves.
    """
    buckets = defaultdict(list)
    for leaf, path in generate_merge_group(group_data):
        *parent_path, _ = path  # every yielded path has at least one index
        buckets[tuple(parent_path)].append(leaf)
    return buckets

# ``model.layers.<n>.<module>.<submodule>.<param>``; compiled once at import.
_LAYER_KEY_WITH_SUBMODULE = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)\.(.+)")
# ``model.layers.<n>.<module>.<param>`` (no submodule segment).
_LAYER_KEY = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)")
# Non-layer keys recognised by exact name; layer index is reported as -1.
_GLOBAL_KEYS = {
    "model.norm.weight": (-1, "norm", "", "weight"),
    "model.embed_tokens.weight": (-1, "embed_tokens", "", "weight"),
    "lm_head.weight": (-1, "lm_head", "", "weight"),
}

def get_layer_type(k : str) -> tuple[int, str, str, str]:
    """Split a parameter key into ``(layer_index, module, submodule, param)``.

    Recognised shapes:
      * ``model.layers.<n>.<module>.<submodule>.<param>`` -> ``(n, module, submodule, param)``.
        The match is greedy, so with more than three trailing segments the
        extra leading ones are folded into ``module``.
      * ``model.layers.<n>.<module>.<param>`` -> ``(n, module, "", param)``.
      * ``model.norm.weight``, ``model.embed_tokens.weight``, ``lm_head.weight``
        -> ``(-1, <name>, "", "weight")``.

    Any other key is reported on stdout and mapped to
    ``(-1, "unknown", "unknown", "unknown")``.
    """
    m = _LAYER_KEY_WITH_SUBMODULE.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), m.group(3), m.group(4)
    # Previously this pattern was compiled but never applied, so keys without
    # a submodule segment always fell through to "unknown".
    m = _LAYER_KEY.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), "", m.group(3)

    if k in _GLOBAL_KEYS:
        return _GLOBAL_KEYS[k]
    print(f"Unknown key {k}")
    return -1, "unknown", "unknown", "unknown"