apolinario committed on
Commit
9014add
·
1 Parent(s): b0fc9e3

Patch Gemma2Model.forward so that torch.tensor calls made inside it default to the embedding's device (ZeroGPU rejects the cross-device multiply with the CPU-resident normalizer)

Browse files
Files changed (1) hide show
  1. app.py +23 -0
app.py CHANGED
@@ -74,6 +74,29 @@ def _broadcasting_vmap_for_bhqkv(mask_function, bh_indices: bool = True):
74
 
75
  _mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
78
  pipeline.to("cuda")
79
 
 
74
 
75
  _mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv
76
 
77
+ # Gemma2's forward does `normalizer = torch.tensor(hidden_size**0.5, dtype=...)`
78
+ # without a device kwarg, so it lands on CPU while hidden_states is on cuda.
79
+ # Vanilla CUDA tolerates the cross-device scalar op; ZeroGPU's __torch_function__
80
+ # hijack rejects it. Force torch.tensor calls inside Gemma2.forward onto the
81
+ # embedding's device.
82
+ import transformers.models.gemma2.modeling_gemma2 as _gm
83
+
84
+ _orig_gemma2_forward = _gm.Gemma2Model.forward
85
+
86
def _patched_gemma2_forward(self, *args, **kwargs):
    """Run the saved Gemma2Model.forward with ``torch.tensor`` defaulting to the model's device.

    While the wrapped forward executes, ``torch.tensor`` is temporarily replaced by a
    shim that fills in ``device=self.embed_tokens.weight.device`` whenever the caller
    did not pass a ``device`` keyword. The previous ``torch.tensor`` is put back when
    the forward returns or raises.

    All positional and keyword arguments are forwarded untouched, and the wrapped
    forward's return value is returned as-is.
    """
    # NOTE(review): the swap below rebinds the module attribute ``torch.tensor``, so any
    # other code calling ``torch.tensor`` while this forward is running sees the shim
    # too -- confirm forwards never overlap (e.g. across threads).
    embed_device = self.embed_tokens.weight.device
    saved_tensor_fn = torch.tensor

    def tensor_on_embed_device(data, *pos, **named):
        # Only supply a device when the caller left it unspecified.
        if "device" not in named:
            named["device"] = embed_device
        return saved_tensor_fn(data, *pos, **named)

    torch.tensor = tensor_on_embed_device
    try:
        result = _orig_gemma2_forward(self, *args, **kwargs)
    finally:
        # Always restore the saved callable, even if the forward raised.
        torch.tensor = saved_tensor_fn
    return result
97
+
98
+ _gm.Gemma2Model.forward = _patched_gemma2_forward
99
+
100
  pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
101
  pipeline.to("cuda")
102