import torch import numpy as np # 加载 .pt 文件 pt_file_path = "/home/lab/LJ/wampee/ultralytics/yolov8n.pt" # 替换为你的 pt 文件路径 model_data = torch.load(pt_file_path) # 检查 state_dict 是否存在 if "state_dict" in model_data: print("state_dict found in the file.") state_dict = model_data["state_dict"] else: print("state_dict not found. Assuming the file is directly a state_dict.") state_dict = model_data i=0 # 遍历参数 #for name, param in state_dict['train_args'].items(): #i+=1 #print(f"Keys: {name}, Value: {param}") train_args=state_dict['train_args'].items() # 映射字典 task_mapping = { 'detect': 0, 'classify': 1, 'segment': 2, 'pose': 3, 'track': 4, 'unet': 5 } mode_mapping = { 'train': 0, 'val': 1, 'test': 2, 'export': 3, 'infer': 4, 'predict': 5, 'deploy': 6 } # 函数:将布尔值转换为 0/1,将 None 转换为 -1,其他保持不变 def convert_values(key, value): if key == 'task': return task_mapping.get(value, value) # 映射 task elif key == 'mode': return mode_mapping.get(value, value) # 映射 mode elif isinstance(value, bool): return 0 if value else 1 # 布尔值映射 elif value is None: return -1 # None 映射 elif isinstance(value,str): return -2 else: return value # 其他值保持不变 # 将 `train_args` 转换为矩阵格式(字典转为列表) # 直接使用 dict_items 遍历,避免 .keys() 错误 items = dict(train_args) #print(items) converted_items = {key: convert_values(key, value) for key, value in items.items()} converted_items=dict(converted_items) # 转换 train_args 中的所有值,并确保转换后为列表 converted_values = [value for key, value in converted_items.items()] # print(type(converted_values)) # print(converted_values) #生成矩阵 # 填充矩阵 matrix_size = 10 * 11 # 10x11 矩阵的总大小 if len(converted_values) < matrix_size: converted_values.extend([None] * (matrix_size - len(converted_values))) # 填充 None else: converted_values = converted_values[:matrix_size] # 截取前 matrix_size 个值 # 创建 10x11 矩阵 matrix = np.array(converted_values).reshape(10, 11) # 输出矩阵 print("\nConverted matrix (10x11):") print(matrix)