Upload ChessBot Chess model
Browse files- modeling_chessbot.py +0 -71
modeling_chessbot.py
CHANGED
|
@@ -527,77 +527,6 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 527 |
# Initialize weights
|
| 528 |
self.post_init()
|
| 529 |
|
| 530 |
-
@classmethod
|
| 531 |
-
def from_pretrained(cls, model_path, **kwargs):
|
| 532 |
-
"""
|
| 533 |
-
Load a pretrained model from a directory (HuggingFace compatible)
|
| 534 |
-
"""
|
| 535 |
-
import os
|
| 536 |
-
|
| 537 |
-
# Load config
|
| 538 |
-
config_path = os.path.join(model_path, "config.json")
|
| 539 |
-
if os.path.exists(config_path):
|
| 540 |
-
config = ChessBotConfig.from_pretrained(model_path)
|
| 541 |
-
else:
|
| 542 |
-
config = ChessBotConfig()
|
| 543 |
-
|
| 544 |
-
# Create model instance
|
| 545 |
-
model = cls(config)
|
| 546 |
-
|
| 547 |
-
# Load weights
|
| 548 |
-
model_file = None
|
| 549 |
-
for filename in ["pytorch_model.bin", "model.safetensors"]:
|
| 550 |
-
full_path = os.path.join(model_path, filename)
|
| 551 |
-
if os.path.exists(full_path):
|
| 552 |
-
model_file = full_path
|
| 553 |
-
break
|
| 554 |
-
|
| 555 |
-
if model_file is None:
|
| 556 |
-
raise FileNotFoundError(f"No model file found in {model_path}")
|
| 557 |
-
|
| 558 |
-
if model_file.endswith('.safetensors'):
|
| 559 |
-
# Handle safetensors format
|
| 560 |
-
try:
|
| 561 |
-
from safetensors import safe_open
|
| 562 |
-
state_dict = {}
|
| 563 |
-
with safe_open(model_file, framework="pt", device="cpu") as f:
|
| 564 |
-
for key in f.keys():
|
| 565 |
-
state_dict[key] = f.get_tensor(key)
|
| 566 |
-
except ImportError:
|
| 567 |
-
raise ImportError("safetensors library is required to load .safetensors files. Install with: pip install safetensors")
|
| 568 |
-
else:
|
| 569 |
-
# Handle pytorch format
|
| 570 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
| 571 |
-
|
| 572 |
-
# Load state dict into model
|
| 573 |
-
model.load_state_dict(state_dict, strict=False)
|
| 574 |
-
|
| 575 |
-
return model
|
| 576 |
-
|
| 577 |
-
def save_pretrained(self, save_directory, safe_serialization=False):
|
| 578 |
-
"""
|
| 579 |
-
Save the model to a directory (HuggingFace compatible)
|
| 580 |
-
"""
|
| 581 |
-
import os
|
| 582 |
-
os.makedirs(save_directory, exist_ok=True)
|
| 583 |
-
|
| 584 |
-
# Save config
|
| 585 |
-
self.config.save_pretrained(save_directory)
|
| 586 |
-
|
| 587 |
-
# Save model weights
|
| 588 |
-
if safe_serialization:
|
| 589 |
-
try:
|
| 590 |
-
from safetensors.torch import save_file
|
| 591 |
-
model_path = os.path.join(save_directory, "model.safetensors")
|
| 592 |
-
save_file(self.state_dict(), model_path)
|
| 593 |
-
except ImportError:
|
| 594 |
-
print("⚠ Warning: safetensors not available, falling back to pytorch_model.bin")
|
| 595 |
-
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
| 596 |
-
torch.save(self.state_dict(), model_path)
|
| 597 |
-
else:
|
| 598 |
-
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
| 599 |
-
torch.save(self.state_dict(), model_path)
|
| 600 |
-
|
| 601 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
| 602 |
"""
|
| 603 |
Forward pass compatible with both HuggingFace interface and original interface
|
|
|
|
| 527 |
# Initialize weights
|
| 528 |
self.post_init()
|
| 529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
def forward(self, input_ids, attention_mask=None, compute_loss=False):
|
| 531 |
"""
|
| 532 |
Forward pass compatible with both HuggingFace interface and original interface
|