change cuda
Browse files
minigpt4/models/minigpt4.py
CHANGED
|
@@ -451,11 +451,12 @@ class MiniGPT4(MiniGPTBase):
|
|
| 451 |
state_dict_a = ckpt['model']
|
| 452 |
state_dict_b = model.state_dict()
|
| 453 |
for name_b, param_b in state_dict_b.items():
|
| 454 |
-
print(name_b,param_b.shape)
|
| 455 |
if name_b in state_dict_a:
|
| 456 |
param_a = state_dict_a[name_b]
|
| 457 |
# 检查形状是否匹配,以避免错误
|
| 458 |
if param_a.shape == param_b.shape:
|
|
|
|
|
|
|
| 459 |
# print(f"Transferring weights for layer: {name_b}")
|
| 460 |
state_dict_b[name_b].copy_(param_a)
|
| 461 |
else:
|
|
|
|
| 451 |
state_dict_a = ckpt['model']
|
| 452 |
state_dict_b = model.state_dict()
|
| 453 |
for name_b, param_b in state_dict_b.items():
|
|
|
|
| 454 |
if name_b in state_dict_a:
|
| 455 |
param_a = state_dict_a[name_b]
|
| 456 |
# 检查形状是否匹配,以避免错误
|
| 457 |
if param_a.shape == param_b.shape:
|
| 458 |
+
print(name_b,param_b.shape,param_a.shape)
|
| 459 |
+
|
| 460 |
# print(f"Transferring weights for layer: {name_b}")
|
| 461 |
state_dict_b[name_b].copy_(param_a)
|
| 462 |
else:
|