natmin322 committed on
Commit
0aeac35
·
1 Parent(s): 9dc2b5d
improve_gainlora/src/t5_specroute.py CHANGED
@@ -23,6 +23,12 @@ from torch import nn
23
  from torch.nn import CrossEntropyLoss
24
  from torch.utils.checkpoint import checkpoint
25
 
 
 
 
 
 
 
26
  from transformers.modeling_outputs import (
27
  BaseModelOutput,
28
  BaseModelOutputWithPastAndCrossAttentions,
@@ -183,7 +189,15 @@ class T5Stack(T5PreTrainedModel):
183
  A = lora.lora_A.data.float() # (r, d_model)
184
  B = lora.lora_B.data.float() # (inner_dim, r)
185
  delta_W = B @ A # (inner_dim, d_model)
186
- _, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
 
 
 
 
 
 
 
 
187
  r = lora.r
188
  V = Vt[:r].to(h.device, dtype=h.dtype) # (r, d_model)
189
  sigma = S[:r].to(h.device, dtype=h.dtype) # (r,)
 
23
  from torch.nn import CrossEntropyLoss
24
  from torch.utils.checkpoint import checkpoint
25
 
26
+ # Prefer magma backend for linalg ops (cusolver can crash on some GPU configs)
27
+ try:
28
+ torch.backends.cuda.preferred_linalg_library("magma")
29
+ except Exception:
30
+ pass
31
+
32
  from transformers.modeling_outputs import (
33
  BaseModelOutput,
34
  BaseModelOutputWithPastAndCrossAttentions,
 
189
  A = lora.lora_A.data.float() # (r, d_model)
190
  B = lora.lora_B.data.float() # (inner_dim, r)
191
  delta_W = B @ A # (inner_dim, d_model)
192
+ # Clamp NaN/Inf to avoid cusolver crash
193
+ if not torch.isfinite(delta_W).all():
194
+ delta_W = torch.nan_to_num(delta_W, nan=0.0, posinf=1e6, neginf=-1e6)
195
+ try:
196
+ _, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
197
+ except RuntimeError:
198
+ # cusolver can fail on certain GPU configs; fall back to CPU
199
+ _, S, Vt = torch.linalg.svd(delta_W.cpu(), full_matrices=False)
200
+ S, Vt = S.to(delta_W.device), Vt.to(delta_W.device)
201
  r = lora.r
202
  V = Vt[:r].to(h.device, dtype=h.dtype) # (r, d_model)
203
  sigma = S[:r].to(h.device, dtype=h.dtype) # (r,)