fix bug fb16
Browse files- discusstion.txt +3 -0
- gainlora_baseline_origin/src/run_t5.py +39 -133
discusstion.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3453c142bfaf3afda3e18718267d871d49f5f1ebb22ac43b3dfe5b7e069da467
|
| 3 |
+
size 98496
|
gainlora_baseline_origin/src/run_t5.py
CHANGED
|
@@ -501,6 +501,22 @@ def main():
|
|
| 501 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 502 |
)
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
model.persent = training_args.persent
|
| 505 |
model.resize_token_embeddings(len(tokenizer))
|
| 506 |
|
|
@@ -880,142 +896,32 @@ def main():
|
|
| 880 |
trainer.is_deepspeed_enabled = False
|
| 881 |
print("is_deepspeed_enabled", trainer.is_deepspeed_enabled)
|
| 882 |
|
| 883 |
-
# ============
|
| 884 |
-
if training_args.
|
| 885 |
print("=" * 60)
|
| 886 |
-
print("[
|
| 887 |
-
model.train()
|
| 888 |
model.to(device)
|
| 889 |
-
|
| 890 |
-
# ---- TEST 1: Isolated LoRA layer test ----
|
| 891 |
-
print("\n--- TEST 1: Isolated LoRA layer ---")
|
| 892 |
-
_lora = model.encoder.block[0].layer[0].SelfAttention.lora_q
|
| 893 |
-
print(f" lora_A: shape={_lora.lora_A.shape}, requires_grad={_lora.lora_A.requires_grad}, "
|
| 894 |
-
f"norm={_lora.lora_A.data.norm().item():.6f}, all_zero={_lora.lora_A.data.eq(0).all().item()}")
|
| 895 |
-
print(f" lora_B: shape={_lora.lora_B.shape}, requires_grad={_lora.lora_B.requires_grad}, "
|
| 896 |
-
f"norm={_lora.lora_B.data.norm().item():.6f}, all_zero={_lora.lora_B.data.eq(0).all().item()}")
|
| 897 |
-
_test_x = torch.randn(1, 3, _lora.lora_A.shape[1], device=device)
|
| 898 |
-
_lora.lora_B.grad = None
|
| 899 |
-
_y = _lora(_test_x)
|
| 900 |
-
print(f" LoRA output: norm={_y.norm().item():.6f}, requires_grad={_y.requires_grad}")
|
| 901 |
-
_simple_loss = _y.sum()
|
| 902 |
-
print(f" simple_loss: {_simple_loss.item():.6f}, requires_grad={_simple_loss.requires_grad}, grad_fn={_simple_loss.grad_fn}")
|
| 903 |
-
_simple_loss.backward()
|
| 904 |
-
print(f" lora_B.grad: {'None' if _lora.lora_B.grad is None else f'norm={_lora.lora_B.grad.norm().item():.6e}'}")
|
| 905 |
-
# Also test with x that doesn't require grad (like in the real model when base is frozen)
|
| 906 |
-
_lora.lora_B.grad = None
|
| 907 |
-
_test_x2 = torch.randn(1, 3, _lora.lora_A.shape[1], device=device, requires_grad=False)
|
| 908 |
-
_y2 = _lora(_test_x2)
|
| 909 |
-
print(f" LoRA output (x.requires_grad=False): requires_grad={_y2.requires_grad}")
|
| 910 |
-
_y2.sum().backward()
|
| 911 |
-
print(f" lora_B.grad (x no grad): {'None' if _lora.lora_B.grad is None else f'norm={_lora.lora_B.grad.norm().item():.6e}'}")
|
| 912 |
-
model.zero_grad()
|
| 913 |
-
|
| 914 |
-
# ---- TEST 2: Check requires_grad propagation through the model ----
|
| 915 |
-
print("\n--- TEST 2: requires_grad propagation ---")
|
| 916 |
-
# Disable GC for this test to simplify
|
| 917 |
for _m in model.modules():
|
| 918 |
-
if hasattr(_m, '
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
# Also hook LoRA output
|
| 938 |
-
_lora_hook_data = {}
|
| 939 |
-
def _lora_fwd_hook(module, inp, out):
|
| 940 |
-
_lora_hook_data['output_requires_grad'] = out.requires_grad
|
| 941 |
-
_lora_hook_data['output_norm'] = out.norm().item()
|
| 942 |
-
if isinstance(inp, tuple):
|
| 943 |
-
_lora_hook_data['input_requires_grad'] = inp[0].requires_grad if isinstance(inp[0], torch.Tensor) else 'N/A'
|
| 944 |
-
else:
|
| 945 |
-
_lora_hook_data['input_requires_grad'] = inp.requires_grad if isinstance(inp, torch.Tensor) else 'N/A'
|
| 946 |
-
_lh = _attn_module.lora_q.register_forward_hook(_lora_fwd_hook)
|
| 947 |
-
model.zero_grad()
|
| 948 |
-
_outputs = model(**_test_input)
|
| 949 |
-
_loss = _outputs.loss
|
| 950 |
-
print(f" T5Attention input: requires_grad={_hook_data.get('input_requires_grad', 'N/A')}, norm={_hook_data.get('input_norm', 'N/A')}")
|
| 951 |
-
print(f" LoRA_q input: requires_grad={_lora_hook_data.get('input_requires_grad', 'N/A')}")
|
| 952 |
-
print(f" LoRA_q output: requires_grad={_lora_hook_data.get('output_requires_grad', 'N/A')}, norm={_lora_hook_data.get('output_norm', 'N/A')}")
|
| 953 |
-
_h.remove()
|
| 954 |
-
_lh.remove()
|
| 955 |
-
|
| 956 |
-
# ---- TEST 3: Backward with grad hooks on lora_B ----
|
| 957 |
-
print("\n--- TEST 3: Backward with hooks ---")
|
| 958 |
-
_grad_hooks = {}
|
| 959 |
-
def _make_grad_hook(name):
|
| 960 |
-
def _hook(grad):
|
| 961 |
-
_grad_hooks[name] = grad.norm().item()
|
| 962 |
-
return _hook
|
| 963 |
-
# Register hooks on a few lora_B params
|
| 964 |
-
_hook_handles = []
|
| 965 |
-
_hook_targets = [
|
| 966 |
-
'encoder.block.0.layer.0.SelfAttention.lora_q.lora_B',
|
| 967 |
-
'encoder.block.0.layer.0.SelfAttention.lora_v.lora_B',
|
| 968 |
-
'decoder.block.0.layer.0.SelfAttention.lora_q.lora_B',
|
| 969 |
-
'decoder.block.0.layer.1.EncDecAttention.lora_q.lora_B',
|
| 970 |
-
]
|
| 971 |
-
for name, p in model.named_parameters():
|
| 972 |
-
if name in _hook_targets:
|
| 973 |
-
_hook_handles.append(p.register_hook(_make_grad_hook(name)))
|
| 974 |
-
_loss.backward()
|
| 975 |
-
print(f" Grad hooks captured ({len(_grad_hooks)} hooks fired):")
|
| 976 |
-
for name, norm in _grad_hooks.items():
|
| 977 |
-
print(f" {name}: grad_norm={norm:.6e}")
|
| 978 |
-
if not _grad_hooks:
|
| 979 |
-
print(" WARNING: No grad hooks fired! Backward didn't reach lora_B.")
|
| 980 |
-
# Also check final gradient state
|
| 981 |
-
_n_ok, _n_zero, _n_none = 0, 0, 0
|
| 982 |
-
for name, p in model.named_parameters():
|
| 983 |
-
if p.requires_grad:
|
| 984 |
-
if p.grad is None:
|
| 985 |
-
_n_none += 1
|
| 986 |
-
elif p.grad.norm().item() > 0:
|
| 987 |
-
_n_ok += 1
|
| 988 |
-
else:
|
| 989 |
-
_n_zero += 1
|
| 990 |
-
print(f" Final grad counts: grad>0={_n_ok}, grad==0={_n_zero}, grad=None={_n_none}")
|
| 991 |
-
for _hh in _hook_handles:
|
| 992 |
-
_hh.remove()
|
| 993 |
-
|
| 994 |
-
# ---- TEST 4: Manual single-layer backward ----
|
| 995 |
-
print("\n--- TEST 4: Single T5Block forward+backward ---")
|
| 996 |
-
model.zero_grad()
|
| 997 |
-
_block = model.encoder.block[0]
|
| 998 |
-
_hidden = torch.randn(2, 10, model.config.d_model, device=device, requires_grad=True)
|
| 999 |
-
_mask = torch.zeros(2, 1, 1, 10, device=device)
|
| 1000 |
-
_block_out = _block(_hidden, attention_mask=_mask, key_attention_weights=None)
|
| 1001 |
-
_block_loss = _block_out[0].sum()
|
| 1002 |
-
print(f" block output requires_grad={_block_out[0].requires_grad}")
|
| 1003 |
-
_block_loss.backward()
|
| 1004 |
-
_bq = _block.layer[0].SelfAttention.lora_q.lora_B
|
| 1005 |
-
_bv = _block.layer[0].SelfAttention.lora_v.lora_B
|
| 1006 |
-
print(f" encoder.block[0] lora_q.B grad: {'None' if _bq.grad is None else f'norm={_bq.grad.norm().item():.6e}'}")
|
| 1007 |
-
print(f" encoder.block[0] lora_v.B grad: {'None' if _bv.grad is None else f'norm={_bv.grad.norm().item():.6e}'}")
|
| 1008 |
-
model.zero_grad()
|
| 1009 |
-
|
| 1010 |
-
# ---- Restore GC and cleanup ----
|
| 1011 |
-
_use_gc = training_args.gradient_checkpointing
|
| 1012 |
-
for _m in model.modules():
|
| 1013 |
-
if hasattr(_m, 'gradient_checkpointing'):
|
| 1014 |
-
_m.gradient_checkpointing = _use_gc
|
| 1015 |
-
del _test_loader, _test_batch, _test_input
|
| 1016 |
-
torch.cuda.empty_cache()
|
| 1017 |
-
print("\n" + "=" * 60)
|
| 1018 |
-
# ============ END DEEP GRADIENT DIAGNOSTIC ============
|
| 1019 |
|
| 1020 |
all_metrics = {"run_name": training_args.run_name}
|
| 1021 |
|
|
|
|
| 501 |
use_auth_token=True if model_args.use_auth_token else None,
|
| 502 |
)
|
| 503 |
|
| 504 |
+
# FIX: from_pretrained wraps model construction in no_init_weights() context,
|
| 505 |
+
# which replaces nn.init.kaiming_uniform_ with a no-op. This leaves lora_A
|
| 506 |
+
# as all zeros (from torch.zeros in constructor), making LoRA output = 0
|
| 507 |
+
# and all lora_B gradients = 0. Re-initialize lora_A here.
|
| 508 |
+
_n_reinit = 0
|
| 509 |
+
for _module in model.modules():
|
| 510 |
+
if hasattr(_module, 'lora_A') and hasattr(_module, 'lora_B') and hasattr(_module, 'reset_parameters'):
|
| 511 |
+
nn.init.kaiming_uniform_(_module.lora_A, a=math.sqrt(5))
|
| 512 |
+
_n_reinit += 1
|
| 513 |
+
print(f"[FIX] Re-initialized lora_A in {_n_reinit} LoRA layers with kaiming_uniform_")
|
| 514 |
+
# Verify fix
|
| 515 |
+
for _module in model.modules():
|
| 516 |
+
if hasattr(_module, 'lora_A'):
|
| 517 |
+
print(f" lora_A: norm={_module.lora_A.data.norm().item():.6f}, all_zero={(_module.lora_A.data == 0).all().item()}")
|
| 518 |
+
break
|
| 519 |
+
|
| 520 |
model.persent = training_args.persent
|
| 521 |
model.resize_token_embeddings(len(tokenizer))
|
| 522 |
|
|
|
|
| 896 |
trainer.is_deepspeed_enabled = False
|
| 897 |
print("is_deepspeed_enabled", trainer.is_deepspeed_enabled)
|
| 898 |
|
| 899 |
+
# ============ QUICK SANITY CHECK: LoRA weights ============
|
| 900 |
+
if training_args.do_train:
|
| 901 |
print("=" * 60)
|
| 902 |
+
print("[SANITY] Checking LoRA layer initialization...")
|
|
|
|
| 903 |
model.to(device)
|
| 904 |
+
_lora = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 905 |
for _m in model.modules():
|
| 906 |
+
if hasattr(_m, 'lora_A') and hasattr(_m, 'lora_B'):
|
| 907 |
+
_lora = _m
|
| 908 |
+
break
|
| 909 |
+
if _lora is not None:
|
| 910 |
+
_a_ok = _lora.lora_A.data.norm().item() > 0
|
| 911 |
+
print(f" lora_A norm={_lora.lora_A.data.norm().item():.6f}, requires_grad={_lora.lora_A.requires_grad} {'OK' if _a_ok else 'ZERO - BUG!'}")
|
| 912 |
+
print(f" lora_B norm={_lora.lora_B.data.norm().item():.6f}, requires_grad={_lora.lora_B.requires_grad}")
|
| 913 |
+
# Quick forward+backward test
|
| 914 |
+
_test_x = torch.randn(1, 3, _lora.lora_A.shape[1], device=device)
|
| 915 |
+
_lora.lora_B.grad = None
|
| 916 |
+
_y = _lora(_test_x)
|
| 917 |
+
_y.sum().backward()
|
| 918 |
+
_b_grad = _lora.lora_B.grad.norm().item() if _lora.lora_B.grad is not None else 0
|
| 919 |
+
print(f" lora_B.grad norm={_b_grad:.6e} {'OK' if _b_grad > 0 else 'ZERO - BUG!'}")
|
| 920 |
+
model.zero_grad()
|
| 921 |
+
if not _a_ok:
|
| 922 |
+
raise RuntimeError("lora_A is all zeros! from_pretrained no_init_weights fix failed.")
|
| 923 |
+
print("=" * 60)
|
| 924 |
+
# ============ END SANITY CHECK ============
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 925 |
|
| 926 |
all_metrics = {"run_name": training_args.run_name}
|
| 927 |
|