basilboy committed on
Commit
db8b5ba
·
verified ·
1 Parent(s): d8c125e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -8,6 +8,8 @@ from transformers import AutoTokenizer
8
  from safetensors.torch import load_file as load_sft
9
  from huggingface_hub import snapshot_download
10
 
 
 
11
  # ===============================================
12
  # Default config (from your training notes)
13
  # ===============================================
@@ -60,6 +62,9 @@ class AttnBlock(nn.Module):
60
  return Qh2, Kh2
61
 
62
  def forward(self, x, rope, radius):
 
 
 
63
  h = self.norm1(x)
64
  B, S, E = h.shape
65
  cos, sin = rope
@@ -382,11 +387,15 @@ def load_model(source: str):
382
  else:
383
  model.load_state_dict(state, strict=True)
384
 
385
- # 🔧 ensure everything is float32 (fixes mixed-dtype LayerNorm on CPU)
386
  model = model.to(torch.float32)
387
- for bname, buf in model.named_buffers():
388
- if buf.dtype.is_floating_point:
389
- setattr(model, bname, buf.float())
 
 
 
 
390
 
391
  model.eval()
392
  return model, tokenizer, conf["radius"]
 
8
  from safetensors.torch import load_file as load_sft
9
  from huggingface_hub import snapshot_download
10
 
11
+ torch.set_default_dtype(torch.float32)
12
+
13
  # ===============================================
14
  # Default config (from your training notes)
15
  # ===============================================
 
62
  return Qh2, Kh2
63
 
64
  def forward(self, x, rope, radius):
65
+ if x.dtype != self.norm1.weight.dtype:
66
+ x = x.to(self.norm1.weight.dtype)
67
+
68
  h = self.norm1(x)
69
  B, S, E = h.shape
70
  cos, sin = rope
 
387
  else:
388
  model.load_state_dict(state, strict=True)
389
 
390
+ # hard-cast ALL params & buffers to float32 (handles weights-only .pt that saved as float64)
391
  model = model.to(torch.float32)
392
+ with torch.no_grad():
393
+ for p in model.parameters():
394
+ if p.dtype.is_floating_point:
395
+ p.data = p.data.float()
396
+ for _, buf in model.named_buffers():
397
+ if buf.dtype.is_floating_point:
398
+ buf.data = buf.data.float()
399
 
400
  model.eval()
401
  return model, tokenizer, conf["radius"]