basilboy commited on
Commit
d8c125e
·
verified ·
1 Parent(s): b5f874f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -381,7 +381,13 @@ def load_model(source: str):
381
  nn.init.zeros_(model.proj.bias)
382
  else:
383
  model.load_state_dict(state, strict=True)
384
-
 
 
 
 
 
 
385
  model.eval()
386
  return model, tokenizer, conf["radius"]
387
 
 
381
  nn.init.zeros_(model.proj.bias)
382
  else:
383
  model.load_state_dict(state, strict=True)
384
+
385
+ # 🔧 ensure everything is float32 (fixes mixed-dtype LayerNorm on CPU)
386
+ model = model.to(torch.float32)
387
+ for bname, buf in model.named_buffers():
388
+ if buf.dtype.is_floating_point:
389
+ setattr(model, bname, buf.float())
390
+
391
  model.eval()
392
  return model, tokenizer, conf["radius"]
393