File size: 3,371 Bytes
93345e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import os
import sys
from safetensors.torch import load_file, save_file

class C:
    """ANSI escape sequences for colored terminal output."""
    HEADER = '\033[95m'  # bright magenta — section headers
    BLUE = '\033[94m'
    CYAN = '\033[96m'
    GREEN = '\033[92m'   # success / converted keys
    YELLOW = '\033[93m'  # warnings / skipped keys
    RED = '\033[91m'     # errors / original keys
    GREY = '\033[90m'    # dim informational text
    RESET = '\033[0m'    # restore default terminal style
    BOLD = '\033[1m'

def convert_key(old_key):
    """Translate an attention-processor LoRA key to the diffusers/PEFT scheme.

    Expected input shape (dot-separated), e.g.::

        <module path>.attn.processor.q_loras.<idx>.down.weight

    which becomes::

        transformer.<module path>.attn.to_q.lora_A.weight

    Args:
        old_key: Original tensor key from the source checkpoint.

    Returns:
        The converted key string, or ``None`` when ``old_key`` is not a
        convertible attention-processor LoRA weight key.
    """
    parts = old_key.split('.')

    # Only attention-processor weight keys are convertible.
    if "attn" not in parts or "processor" not in parts or "weight" not in parts:
        return None

    # Locate the '<type>_loras' component (e.g. 'q_loras', 'proj_loras').
    lora_type_idx = next(
        (i for i, part in enumerate(parts) if part.endswith('_loras')), -1
    )

    # Require at least two components before '<type>_loras' (the module
    # path plus 'processor'); otherwise the prefix slice below would use a
    # negative/empty bound and fabricate a corrupt key.
    if lora_type_idx < 2:
        return None

    raw_type = parts[lora_type_idx].replace('_loras', '')

    # The up/down direction sits two components past '<type>_loras'
    # (the component in between is the LoRA layer index).
    try:
        direction = parts[lora_type_idx + 2]
    except IndexError:
        return None

    type_map = {
        "q": "to_q",
        "k": "to_k",
        "v": "to_v",
        "proj": "to_out.0"
    }

    lora_map = {
        "down": "lora_A",  # down-projection → LoRA "A" matrix
        "up": "lora_B"     # up-projection   → LoRA "B" matrix
    }

    if raw_type not in type_map or direction not in lora_map:
        return None

    # Drop the 'processor' component (immediately before '<type>_loras').
    prefix_str = ".".join(parts[:lora_type_idx - 1])

    return f"transformer.{prefix_str}.{type_map[raw_type]}.{lora_map[direction]}.weight"

def main():
    """CLI entry point: convert LoRA keys in a safetensors checkpoint.

    Loads ``input_file``, renames every convertible attention-processor
    LoRA key via :func:`convert_key`, saves the result to ``output_file``,
    and prints a colored per-key report plus summary counts.

    Exits with status 1 when the input file is missing or fails to load,
    so shell pipelines can detect failure (previously these paths
    returned with status 0).
    """
    parser = argparse.ArgumentParser(description="Colorful LoRA Key Converter")
    parser.add_argument("input_file", type=str, help="Input safetensors file")
    parser.add_argument("output_file", type=str, help="Output safetensors file")

    args = parser.parse_args()

    if not os.path.exists(args.input_file):
        print(f"{C.RED}Error: Input file '{args.input_file}' not found.{C.RESET}")
        sys.exit(1)

    print(f"{C.HEADER}Loading weights from: {C.RESET}{C.BOLD}{args.input_file}{C.RESET}")
    try:
        tensors = load_file(args.input_file)
    except Exception as e:
        # safetensors raises various error types for corrupt/invalid files;
        # report and exit nonzero rather than continuing with no data.
        print(f"{C.RED}Failed to load file: {e}{C.RESET}")
        sys.exit(1)

    new_tensors = {}   # converted key -> tensor
    skipped_keys = []  # keys convert_key() rejected (kept out of the output)

    print(f"{C.HEADER}{'-'*60}{C.RESET}")
    print(f"{C.BOLD}Processing Keys:{C.RESET}")

    for key, tensor in tensors.items():
        new_key = convert_key(key)

        if new_key:
            # Show old -> new mapping for each converted key.
            print(f"{C.RED}{key}{C.RESET}")
            print(f"   ↓ {C.GREY}converted to{C.RESET}")
            print(f"{C.GREEN}{new_key}{C.RESET}")
            print(f"{C.GREY}-{C.RESET}" * 20)

            new_tensors[new_key] = tensor
        else:
            skipped_keys.append(key)

    print(f"\n{C.HEADER}Saving...{C.RESET}")
    save_file(new_tensors, args.output_file)
    print(f"{C.BOLD}Saved converted model to:{C.RESET} {C.CYAN}{args.output_file}{C.RESET}")

    print(f"\n{C.HEADER}{'-'*60}{C.RESET}")
    print(f"{C.YELLOW}Skipped Keys (Not Converted): {len(skipped_keys)}{C.RESET}")
    if skipped_keys:
        for sk in skipped_keys:
            print(f"  {C.GREY}{sk}{C.RESET}")
    else:
        print(f"  {C.GREEN}(None - All keys were converted){C.RESET}")

    print(f"{C.HEADER}{'-'*60}{C.RESET}")
    print(f"Total Original Keys: {len(tensors)}")
    print(f"Converted Keys:      {C.GREEN}{len(new_tensors)}{C.RESET}")
    print(f"Skipped Keys:        {C.YELLOW}{len(skipped_keys)}{C.RESET}")

if __name__ == "__main__":
    main()