natmin322 commited on
Commit
f332851
·
1 Parent(s): 2325456

fix bug fb16

Browse files
discusstion.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3453c142bfaf3afda3e18718267d871d49f5f1ebb22ac43b3dfe5b7e069da467
3
+ size 98496
gainlora_baseline_origin/src/run_t5.py CHANGED
@@ -501,6 +501,22 @@ def main():
501
  use_auth_token=True if model_args.use_auth_token else None,
502
  )
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  model.persent = training_args.persent
505
  model.resize_token_embeddings(len(tokenizer))
506
 
@@ -880,142 +896,32 @@ def main():
880
  trainer.is_deepspeed_enabled = False
881
  print("is_deepspeed_enabled", trainer.is_deepspeed_enabled)
882
 
883
- # ============ DEEP GRADIENT DIAGNOSTIC ============
884
- if training_args.model_name == 'specroute' and training_args.do_train:
885
  print("=" * 60)
886
- print("[DIAG] Deep gradient diagnostic")
887
- model.train()
888
  model.to(device)
889
-
890
- # ---- TEST 1: Isolated LoRA layer test ----
891
- print("\n--- TEST 1: Isolated LoRA layer ---")
892
- _lora = model.encoder.block[0].layer[0].SelfAttention.lora_q
893
- print(f" lora_A: shape={_lora.lora_A.shape}, requires_grad={_lora.lora_A.requires_grad}, "
894
- f"norm={_lora.lora_A.data.norm().item():.6f}, all_zero={_lora.lora_A.data.eq(0).all().item()}")
895
- print(f" lora_B: shape={_lora.lora_B.shape}, requires_grad={_lora.lora_B.requires_grad}, "
896
- f"norm={_lora.lora_B.data.norm().item():.6f}, all_zero={_lora.lora_B.data.eq(0).all().item()}")
897
- _test_x = torch.randn(1, 3, _lora.lora_A.shape[1], device=device)
898
- _lora.lora_B.grad = None
899
- _y = _lora(_test_x)
900
- print(f" LoRA output: norm={_y.norm().item():.6f}, requires_grad={_y.requires_grad}")
901
- _simple_loss = _y.sum()
902
- print(f" simple_loss: {_simple_loss.item():.6f}, requires_grad={_simple_loss.requires_grad}, grad_fn={_simple_loss.grad_fn}")
903
- _simple_loss.backward()
904
- print(f" lora_B.grad: {'None' if _lora.lora_B.grad is None else f'norm={_lora.lora_B.grad.norm().item():.6e}'}")
905
- # Also test with x that doesn't require grad (like in the real model when base is frozen)
906
- _lora.lora_B.grad = None
907
- _test_x2 = torch.randn(1, 3, _lora.lora_A.shape[1], device=device, requires_grad=False)
908
- _y2 = _lora(_test_x2)
909
- print(f" LoRA output (x.requires_grad=False): requires_grad={_y2.requires_grad}")
910
- _y2.sum().backward()
911
- print(f" lora_B.grad (x no grad): {'None' if _lora.lora_B.grad is None else f'norm={_lora.lora_B.grad.norm().item():.6e}'}")
912
- model.zero_grad()
913
-
914
- # ---- TEST 2: Check requires_grad propagation through the model ----
915
- print("\n--- TEST 2: requires_grad propagation ---")
916
- # Disable GC for this test to simplify
917
  for _m in model.modules():
918
- if hasattr(_m, 'gradient_checkpointing'):
919
- _m.gradient_checkpointing = False
920
- from torch.utils.data import DataLoader as _DL
921
- _test_loader = _DL(train_dataset, batch_size=2, collate_fn=data_collator)
922
- _test_batch = next(iter(_test_loader))
923
- _test_input = {}
924
- for k, v in _test_batch.items():
925
- _test_input[k] = v.to(device) if isinstance(v, torch.Tensor) else v
926
- # Hook into first encoder attention to check input
927
- _attn_module = model.encoder.block[0].layer[0].SelfAttention
928
- _hook_data = {}
929
- def _fwd_hook(module, inp, out):
930
- if isinstance(inp, tuple):
931
- h = inp[0]
932
- else:
933
- h = inp
934
- _hook_data['input_requires_grad'] = h.requires_grad if isinstance(h, torch.Tensor) else 'N/A'
935
- _hook_data['input_norm'] = h.norm().item() if isinstance(h, torch.Tensor) else 'N/A'
936
- _h = _attn_module.register_forward_hook(_fwd_hook)
937
- # Also hook LoRA output
938
- _lora_hook_data = {}
939
- def _lora_fwd_hook(module, inp, out):
940
- _lora_hook_data['output_requires_grad'] = out.requires_grad
941
- _lora_hook_data['output_norm'] = out.norm().item()
942
- if isinstance(inp, tuple):
943
- _lora_hook_data['input_requires_grad'] = inp[0].requires_grad if isinstance(inp[0], torch.Tensor) else 'N/A'
944
- else:
945
- _lora_hook_data['input_requires_grad'] = inp.requires_grad if isinstance(inp, torch.Tensor) else 'N/A'
946
- _lh = _attn_module.lora_q.register_forward_hook(_lora_fwd_hook)
947
- model.zero_grad()
948
- _outputs = model(**_test_input)
949
- _loss = _outputs.loss
950
- print(f" T5Attention input: requires_grad={_hook_data.get('input_requires_grad', 'N/A')}, norm={_hook_data.get('input_norm', 'N/A')}")
951
- print(f" LoRA_q input: requires_grad={_lora_hook_data.get('input_requires_grad', 'N/A')}")
952
- print(f" LoRA_q output: requires_grad={_lora_hook_data.get('output_requires_grad', 'N/A')}, norm={_lora_hook_data.get('output_norm', 'N/A')}")
953
- _h.remove()
954
- _lh.remove()
955
-
956
- # ---- TEST 3: Backward with grad hooks on lora_B ----
957
- print("\n--- TEST 3: Backward with hooks ---")
958
- _grad_hooks = {}
959
- def _make_grad_hook(name):
960
- def _hook(grad):
961
- _grad_hooks[name] = grad.norm().item()
962
- return _hook
963
- # Register hooks on a few lora_B params
964
- _hook_handles = []
965
- _hook_targets = [
966
- 'encoder.block.0.layer.0.SelfAttention.lora_q.lora_B',
967
- 'encoder.block.0.layer.0.SelfAttention.lora_v.lora_B',
968
- 'decoder.block.0.layer.0.SelfAttention.lora_q.lora_B',
969
- 'decoder.block.0.layer.1.EncDecAttention.lora_q.lora_B',
970
- ]
971
- for name, p in model.named_parameters():
972
- if name in _hook_targets:
973
- _hook_handles.append(p.register_hook(_make_grad_hook(name)))
974
- _loss.backward()
975
- print(f" Grad hooks captured ({len(_grad_hooks)} hooks fired):")
976
- for name, norm in _grad_hooks.items():
977
- print(f" {name}: grad_norm={norm:.6e}")
978
- if not _grad_hooks:
979
- print(" WARNING: No grad hooks fired! Backward didn't reach lora_B.")
980
- # Also check final gradient state
981
- _n_ok, _n_zero, _n_none = 0, 0, 0
982
- for name, p in model.named_parameters():
983
- if p.requires_grad:
984
- if p.grad is None:
985
- _n_none += 1
986
- elif p.grad.norm().item() > 0:
987
- _n_ok += 1
988
- else:
989
- _n_zero += 1
990
- print(f" Final grad counts: grad>0={_n_ok}, grad==0={_n_zero}, grad=None={_n_none}")
991
- for _hh in _hook_handles:
992
- _hh.remove()
993
-
994
- # ---- TEST 4: Manual single-layer backward ----
995
- print("\n--- TEST 4: Single T5Block forward+backward ---")
996
- model.zero_grad()
997
- _block = model.encoder.block[0]
998
- _hidden = torch.randn(2, 10, model.config.d_model, device=device, requires_grad=True)
999
- _mask = torch.zeros(2, 1, 1, 10, device=device)
1000
- _block_out = _block(_hidden, attention_mask=_mask, key_attention_weights=None)
1001
- _block_loss = _block_out[0].sum()
1002
- print(f" block output requires_grad={_block_out[0].requires_grad}")
1003
- _block_loss.backward()
1004
- _bq = _block.layer[0].SelfAttention.lora_q.lora_B
1005
- _bv = _block.layer[0].SelfAttention.lora_v.lora_B
1006
- print(f" encoder.block[0] lora_q.B grad: {'None' if _bq.grad is None else f'norm={_bq.grad.norm().item():.6e}'}")
1007
- print(f" encoder.block[0] lora_v.B grad: {'None' if _bv.grad is None else f'norm={_bv.grad.norm().item():.6e}'}")
1008
- model.zero_grad()
1009
-
1010
- # ---- Restore GC and cleanup ----
1011
- _use_gc = training_args.gradient_checkpointing
1012
- for _m in model.modules():
1013
- if hasattr(_m, 'gradient_checkpointing'):
1014
- _m.gradient_checkpointing = _use_gc
1015
- del _test_loader, _test_batch, _test_input
1016
- torch.cuda.empty_cache()
1017
- print("\n" + "=" * 60)
1018
- # ============ END DEEP GRADIENT DIAGNOSTIC ============
1019
 
1020
  all_metrics = {"run_name": training_args.run_name}
1021
 
 
501
  use_auth_token=True if model_args.use_auth_token else None,
502
  )
503
 
504
+ # FIX: from_pretrained wraps model construction in no_init_weights() context,
505
+ # which replaces nn.init.kaiming_uniform_ with a no-op. This leaves lora_A
506
+ # as all zeros (from torch.zeros in constructor), making LoRA output = 0
507
+ # and all lora_B gradients = 0. Re-initialize lora_A here.
508
+ _n_reinit = 0
509
+ for _module in model.modules():
510
+ if hasattr(_module, 'lora_A') and hasattr(_module, 'lora_B') and hasattr(_module, 'reset_parameters'):
511
+ nn.init.kaiming_uniform_(_module.lora_A, a=math.sqrt(5))
512
+ _n_reinit += 1
513
+ print(f"[FIX] Re-initialized lora_A in {_n_reinit} LoRA layers with kaiming_uniform_")
514
+ # Verify fix
515
+ for _module in model.modules():
516
+ if hasattr(_module, 'lora_A'):
517
+ print(f" lora_A: norm={_module.lora_A.data.norm().item():.6f}, all_zero={(_module.lora_A.data == 0).all().item()}")
518
+ break
519
+
520
  model.persent = training_args.persent
521
  model.resize_token_embeddings(len(tokenizer))
522
 
 
896
  trainer.is_deepspeed_enabled = False
897
  print("is_deepspeed_enabled", trainer.is_deepspeed_enabled)
898
 
899
+ # ============ QUICK SANITY CHECK: LoRA weights ============
900
+ if training_args.do_train:
901
  print("=" * 60)
902
+ print("[SANITY] Checking LoRA layer initialization...")
 
903
  model.to(device)
904
+ _lora = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
905
  for _m in model.modules():
906
+ if hasattr(_m, 'lora_A') and hasattr(_m, 'lora_B'):
907
+ _lora = _m
908
+ break
909
+ if _lora is not None:
910
+ _a_ok = _lora.lora_A.data.norm().item() > 0
911
+ print(f" lora_A norm={_lora.lora_A.data.norm().item():.6f}, requires_grad={_lora.lora_A.requires_grad} {'OK' if _a_ok else 'ZERO - BUG!'}")
912
+ print(f" lora_B norm={_lora.lora_B.data.norm().item():.6f}, requires_grad={_lora.lora_B.requires_grad}")
913
+ # Quick forward+backward test
914
+ _test_x = torch.randn(1, 3, _lora.lora_A.shape[1], device=device)
915
+ _lora.lora_B.grad = None
916
+ _y = _lora(_test_x)
917
+ _y.sum().backward()
918
+ _b_grad = _lora.lora_B.grad.norm().item() if _lora.lora_B.grad is not None else 0
919
+ print(f" lora_B.grad norm={_b_grad:.6e} {'OK' if _b_grad > 0 else 'ZERO - BUG!'}")
920
+ model.zero_grad()
921
+ if not _a_ok:
922
+ raise RuntimeError("lora_A is all zeros! from_pretrained no_init_weights fix failed.")
923
+ print("=" * 60)
924
+ # ============ END SANITY CHECK ============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925
 
926
  all_metrics = {"run_name": training_args.run_name}
927