cherrydata / ultralytics /Else /weights2matrix.py
Voidljc
Your commit message
aa24fe8
Raw
History Blame Contribute Delete
1.22 kB
from ultralytics import YOLO
import torch
import math
# 加载预训练模型
model = YOLO('yolov8n.pt')
# 获取模型的状态字典(即权重)
state_dict = model.state_dict()
# 将所有参数的张量拉平并合并成一个大的向量
flattened_params = [torch.flatten(param) for param in state_dict.values()]
# 拼接所有拉平后的张量
merged_params = torch.cat(flattened_params)
# 输出合并后的大矩阵的形状
total_elements = merged_params.numel()
print(f"Total elements: {total_elements}")
# 计算合适的正方形矩阵的大小
square_size = math.ceil(total_elements ** 0.5) # 向上取整,以确保可以填充
total_required_elements = square_size ** 2 # 计算正方形矩阵需要的元素数量
# 如果需要的元素数量大于现有元素数量,用 -1 填充
if total_required_elements > total_elements:
padding = total_required_elements - total_elements
# 用 -1 填充
merged_params = torch.cat([merged_params, torch.full((padding,), -1)])
# 将填充后的张量转换为正方形矩阵
square_matrix = merged_params.view(square_size, square_size)
# 打印结果
print(f"Reshaped into square matrix: {square_matrix.shape}")
print(square_matrix)