File size: 2,366 Bytes
aa24fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import numpy as np
# 加载 .pt 文件
pt_file_path = "/home/lab/LJ/wampee/ultralytics/yolov8n.pt"  # 替换为你的 pt 文件路径
model_data = torch.load(pt_file_path)

# 检查 state_dict 是否存在
if "state_dict" in model_data:
    print("state_dict found in the file.")
    state_dict = model_data["state_dict"]
else:
    print("state_dict not found. Assuming the file is directly a state_dict.")
    state_dict = model_data

i=0
# 遍历参数
#for name, param in state_dict['train_args'].items():
        #i+=1
        #print(f"Keys: {name}, Value: {param}")
        

train_args=state_dict['train_args'].items()


# 映射字典
task_mapping = {
    'detect': 0,
    'classify': 1,
    'segment': 2,
    'pose': 3,
    'track': 4,
    'unet': 5
}

mode_mapping = {
    'train': 0,
    'val': 1,
    'test': 2,
    'export': 3,
    'infer': 4,
    'predict': 5,
    'deploy': 6
}

# 函数:将布尔值转换为 0/1,将 None 转换为 -1,其他保持不变
def convert_values(key, value):
    if key == 'task':
        return task_mapping.get(value, value)  # 映射 task
    elif key == 'mode':
        return mode_mapping.get(value, value)  # 映射 mode
    elif isinstance(value, bool):
        return 0 if value else 1  # 布尔值映射
    elif value is None:
        return -1  # None 映射
    elif isinstance(value,str):
        return -2
    else:
        return value  # 其他值保持不变


# 将 `train_args` 转换为矩阵格式(字典转为列表)
# 直接使用 dict_items 遍历,避免 .keys() 错误
items = dict(train_args)



#print(items)

converted_items = {key: convert_values(key, value) for key, value in items.items()}
converted_items=dict(converted_items)



# 转换 train_args 中的所有值,并确保转换后为列表
converted_values = [value for key, value in converted_items.items()]
# print(type(converted_values))
# print(converted_values)

#生成矩阵
# 填充矩阵
matrix_size = 10 * 11  # 10x11 矩阵的总大小
if len(converted_values) < matrix_size:
    converted_values.extend([None] * (matrix_size - len(converted_values)))  # 填充 None
else:
    converted_values = converted_values[:matrix_size]  # 截取前 matrix_size 个值

# 创建 10x11 矩阵
matrix = np.array(converted_values).reshape(10, 11)

# 输出矩阵
print("\nConverted matrix (10x11):")
print(matrix)