fix bug
Browse files
improve_gainlora/src/t5_specroute.py
CHANGED
|
@@ -23,6 +23,12 @@ from torch import nn
|
|
| 23 |
from torch.nn import CrossEntropyLoss
|
| 24 |
from torch.utils.checkpoint import checkpoint
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from transformers.modeling_outputs import (
|
| 27 |
BaseModelOutput,
|
| 28 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
@@ -183,7 +189,15 @@ class T5Stack(T5PreTrainedModel):
|
|
| 183 |
A = lora.lora_A.data.float() # (r, d_model)
|
| 184 |
B = lora.lora_B.data.float() # (inner_dim, r)
|
| 185 |
delta_W = B @ A # (inner_dim, d_model)
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
r = lora.r
|
| 188 |
V = Vt[:r].to(h.device, dtype=h.dtype) # (r, d_model)
|
| 189 |
sigma = S[:r].to(h.device, dtype=h.dtype) # (r,)
|
|
|
|
| 23 |
from torch.nn import CrossEntropyLoss
|
| 24 |
from torch.utils.checkpoint import checkpoint
|
| 25 |
|
| 26 |
+
# Prefer magma backend for linalg ops (cusolver can crash on some GPU configs)
|
| 27 |
+
try:
|
| 28 |
+
torch.backends.cuda.preferred_linalg_library("magma")
|
| 29 |
+
except Exception:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
from transformers.modeling_outputs import (
|
| 33 |
BaseModelOutput,
|
| 34 |
BaseModelOutputWithPastAndCrossAttentions,
|
|
|
|
| 189 |
A = lora.lora_A.data.float() # (r, d_model)
|
| 190 |
B = lora.lora_B.data.float() # (inner_dim, r)
|
| 191 |
delta_W = B @ A # (inner_dim, d_model)
|
| 192 |
+
# Clamp NaN/Inf to avoid cusolver crash
|
| 193 |
+
if not torch.isfinite(delta_W).all():
|
| 194 |
+
delta_W = torch.nan_to_num(delta_W, nan=0.0, posinf=1e6, neginf=-1e6)
|
| 195 |
+
try:
|
| 196 |
+
_, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
|
| 197 |
+
except RuntimeError:
|
| 198 |
+
# cusolver can fail on certain GPU configs; fall back to CPU
|
| 199 |
+
_, S, Vt = torch.linalg.svd(delta_W.cpu(), full_matrices=False)
|
| 200 |
+
S, Vt = S.to(delta_W.device), Vt.to(delta_W.device)
|
| 201 |
r = lora.r
|
| 202 |
V = Vt[:r].to(h.device, dtype=h.dtype) # (r, d_model)
|
| 203 |
sigma = S[:r].to(h.device, dtype=h.dtype) # (r,)
|