cherrydata / ultralytics /check_pt.py
Voidljc
Your commit message
aa24fe8
Raw
History Blame Contribute Delete
2.37 kB
import torch
import numpy as np
# Load the .pt checkpoint file.
pt_file_path = "/home/lab/LJ/wampee/ultralytics/yolov8n.pt"  # replace with the path to your .pt file
# SECURITY: torch.load unpickles arbitrary Python objects -- only open files you trust.
# map_location="cpu" lets tensors saved on a GPU be loaded on a CPU-only machine.
# weights_only=False is needed on PyTorch >= 2.6, where the default became True and
# pickled non-tensor objects are rejected (Ultralytics .pt files presumably pickle the
# model object itself -- confirm for your checkpoints).
try:
    model_data = torch.load(pt_file_path, map_location="cpu", weights_only=False)
except TypeError:
    # PyTorch < 1.13 does not accept the weights_only argument.
    model_data = torch.load(pt_file_path, map_location="cpu")

# Some checkpoints nest their contents under a "state_dict" key; otherwise treat the
# loaded object itself as the container. The isinstance guard avoids a TypeError from
# `in` when the file holds something other than a dict.
if isinstance(model_data, dict) and "state_dict" in model_data:
    print("state_dict found in the file.")
    state_dict = model_data["state_dict"]
else:
    print("state_dict not found. Assuming the file is directly a state_dict.")
    state_dict = model_data

# Training arguments stored in the checkpoint, as a plain key -> value dict.
train_args = dict(state_dict["train_args"])
# Integer codes for the "task" entry of train_args.
task_mapping = {
    'detect': 0,
    'classify': 1,
    'segment': 2,
    'pose': 3,
    'track': 4,
    'unet': 5,
}
# Integer codes for the "mode" entry of train_args.
mode_mapping = {
    'train': 0,
    'val': 1,
    'test': 2,
    'export': 3,
    'infer': 4,
    'predict': 5,
    'deploy': 6,
}


def convert_values(key, value):
    """Encode one train_args entry as a number so it can sit in a numeric matrix.

    Encoding rules, applied in order:
      * key ``'task'`` with a known name -> its code in ``task_mapping``
      * key ``'mode'`` with a known name -> its code in ``mode_mapping``
      * bool -> 1 for True, 0 for False
      * None -> -1
      * any other str (including unknown task/mode names) -> -2
      * anything else -> returned unchanged

    Args:
        key: The train_args key; only ``'task'`` and ``'mode'`` get special handling.
        value: The value stored under ``key``.

    Returns:
        The encoded value.
    """
    # Unknown task/mode names fall through to the generic string code below instead of
    # leaking a raw string into the matrix (which would make NumPy coerce every cell
    # to a string).
    if key == 'task' and value in task_mapping:
        return task_mapping[value]
    if key == 'mode' and value in mode_mapping:
        return mode_mapping[value]
    # bool must be tested before any int handling, since bool is a subclass of int.
    if isinstance(value, bool):
        return int(value)
    if value is None:
        return -1
    if isinstance(value, str):
        return -2
    return value
# Encode every train_args entry numerically, keeping the checkpoint's key order,
# then flatten the encoded values into a list.
converted_items = {key: convert_values(key, value) for key, value in dict(train_args).items()}
converted_values = list(converted_items.values())

# Lay the values out as a fixed-size MATRIX_ROWS x MATRIX_COLS matrix: truncate when
# there are too many values, pad with None when there are too few.
MATRIX_ROWS, MATRIX_COLS = 10, 11
matrix_size = MATRIX_ROWS * MATRIX_COLS
if len(converted_values) > matrix_size:
    print(
        f"Warning: {len(converted_values)} train_args entries found; "
        f"keeping only the first {matrix_size}."
    )
converted_values = converted_values[:matrix_size]
converted_values.extend([None] * (matrix_size - len(converted_values)))

# np.array infers the dtype from the contents; any None padding yields an object array.
matrix = np.array(converted_values).reshape(MATRIX_ROWS, MATRIX_COLS)

print(f"\nConverted matrix ({MATRIX_ROWS}x{MATRIX_COLS}):")
print(matrix)