Spaces:
Running
Running
Commit
·
5bcafc3
1
Parent(s):
36f1a65
fix load_model
Browse files
- src/evaluate.py +30 -8
src/evaluate.py
CHANGED
|
@@ -155,11 +155,12 @@ class ChessEvaluator:
|
|
| 155 |
input_text = self.tokenizer.bos_token + " " + moves_str
|
| 156 |
|
| 157 |
# Tokenize
|
|
|
|
| 158 |
inputs = self.tokenizer(
|
| 159 |
input_text,
|
| 160 |
return_tensors="pt",
|
| 161 |
truncation=True,
|
| 162 |
-
max_length=
|
| 163 |
).to(self.device)
|
| 164 |
|
| 165 |
# Try to generate a legal move
|
|
@@ -476,20 +477,41 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
|
|
| 476 |
Returns:
|
| 477 |
Tuple of (model, tokenizer).
|
| 478 |
"""
|
| 479 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 480 |
|
| 481 |
-
# Import
|
| 482 |
try:
|
| 483 |
from src.model import ChessConfig, ChessForCausalLM
|
| 484 |
except ImportError:
|
| 485 |
from .model import ChessConfig, ChessForCausalLM
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
return model, tokenizer
|
| 495 |
|
|
|
|
| 155 |
input_text = self.tokenizer.bos_token + " " + moves_str
|
| 156 |
|
| 157 |
# Tokenize
|
| 158 |
+
max_len = getattr(self.model.config, 'n_ctx', None) or getattr(self.model.config, 'max_position_embeddings', 256)
|
| 159 |
inputs = self.tokenizer(
|
| 160 |
input_text,
|
| 161 |
return_tensors="pt",
|
| 162 |
truncation=True,
|
| 163 |
+
max_length=max_len - 1,
|
| 164 |
).to(self.device)
|
| 165 |
|
| 166 |
# Try to generate a legal move
|
|
|
|
| 477 |
Returns:
|
| 478 |
Tuple of (model, tokenizer).
|
| 479 |
"""
|
| 480 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
| 481 |
|
| 482 |
+
# Import and register custom classes
|
| 483 |
try:
|
| 484 |
from src.model import ChessConfig, ChessForCausalLM
|
| 485 |
except ImportError:
|
| 486 |
from .model import ChessConfig, ChessForCausalLM
|
| 487 |
|
| 488 |
+
# Explicitly register in case it wasn't done yet
|
| 489 |
+
try:
|
| 490 |
+
AutoConfig.register("chess_transformer", ChessConfig)
|
| 491 |
+
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
|
| 492 |
+
except ValueError:
|
| 493 |
+
# Already registered
|
| 494 |
+
pass
|
| 495 |
+
|
| 496 |
+
# Load tokenizer
|
| 497 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 498 |
+
|
| 499 |
+
# Load model - try with trust_remote_code first, then without
|
| 500 |
+
try:
|
| 501 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 502 |
+
model_id,
|
| 503 |
+
trust_remote_code=True,
|
| 504 |
+
device_map=device,
|
| 505 |
+
)
|
| 506 |
+
except Exception as e:
|
| 507 |
+
print(f"Loading with trust_remote_code failed ({e}), trying local classes...")
|
| 508 |
+
# Load config and model using our local classes
|
| 509 |
+
config = ChessConfig.from_pretrained(model_id)
|
| 510 |
+
model = ChessForCausalLM.from_pretrained(
|
| 511 |
+
model_id,
|
| 512 |
+
config=config,
|
| 513 |
+
device_map=device,
|
| 514 |
+
)
|
| 515 |
|
| 516 |
return model, tokenizer
|
| 517 |
|