Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,8 @@ from transformers import AutoTokenizer
|
|
| 8 |
from safetensors.torch import load_file as load_sft
|
| 9 |
from huggingface_hub import snapshot_download
|
| 10 |
|
|
|
|
|
|
|
| 11 |
# ===============================================
|
| 12 |
# Default config (from your training notes)
|
| 13 |
# ===============================================
|
|
@@ -60,6 +62,9 @@ class AttnBlock(nn.Module):
|
|
| 60 |
return Qh2, Kh2
|
| 61 |
|
| 62 |
def forward(self, x, rope, radius):
|
|
|
|
|
|
|
|
|
|
| 63 |
h = self.norm1(x)
|
| 64 |
B, S, E = h.shape
|
| 65 |
cos, sin = rope
|
|
@@ -382,11 +387,15 @@ def load_model(source: str):
|
|
| 382 |
else:
|
| 383 |
model.load_state_dict(state, strict=True)
|
| 384 |
|
| 385 |
-
#
|
| 386 |
model = model.to(torch.float32)
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
model.eval()
|
| 392 |
return model, tokenizer, conf["radius"]
|
|
|
|
| 8 |
from safetensors.torch import load_file as load_sft
|
| 9 |
from huggingface_hub import snapshot_download
|
| 10 |
|
| 11 |
+
torch.set_default_dtype(torch.float32)
|
| 12 |
+
|
| 13 |
# ===============================================
|
| 14 |
# Default config (from your training notes)
|
| 15 |
# ===============================================
|
|
|
|
| 62 |
return Qh2, Kh2
|
| 63 |
|
| 64 |
def forward(self, x, rope, radius):
|
| 65 |
+
if x.dtype != self.norm1.weight.dtype:
|
| 66 |
+
x = x.to(self.norm1.weight.dtype)
|
| 67 |
+
|
| 68 |
h = self.norm1(x)
|
| 69 |
B, S, E = h.shape
|
| 70 |
cos, sin = rope
|
|
|
|
| 387 |
else:
|
| 388 |
model.load_state_dict(state, strict=True)
|
| 389 |
|
| 390 |
+
# ✅ hard-cast ALL params & buffers to float32 (handles weights-only .pt that saved as float64)
|
| 391 |
model = model.to(torch.float32)
|
| 392 |
+
with torch.no_grad():
|
| 393 |
+
for p in model.parameters():
|
| 394 |
+
if p.dtype.is_floating_point:
|
| 395 |
+
p.data = p.data.float()
|
| 396 |
+
for _, buf in model.named_buffers():
|
| 397 |
+
if buf.dtype.is_floating_point:
|
| 398 |
+
buf.data = buf.data.float()
|
| 399 |
|
| 400 |
model.eval()
|
| 401 |
return model, tokenizer, conf["radius"]
|