nathanael-fijalkow committed on
Commit
5bcafc3
·
1 Parent(s): 36f1a65

fix load_model

Browse files
Files changed (1) hide show
  1. src/evaluate.py +30 -8
src/evaluate.py CHANGED
@@ -155,11 +155,12 @@ class ChessEvaluator:
155
  input_text = self.tokenizer.bos_token + " " + moves_str
156
 
157
  # Tokenize
 
158
  inputs = self.tokenizer(
159
  input_text,
160
  return_tensors="pt",
161
  truncation=True,
162
- max_length=self.model.config.n_ctx - 1,
163
  ).to(self.device)
164
 
165
  # Try to generate a legal move
@@ -476,20 +477,41 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
476
  Returns:
477
  Tuple of (model, tokenizer).
478
  """
479
- from transformers import AutoModelForCausalLM, AutoTokenizer
480
 
481
- # Import to register custom classes (use relative import or handle both cases)
482
  try:
483
  from src.model import ChessConfig, ChessForCausalLM
484
  except ImportError:
485
  from .model import ChessConfig, ChessForCausalLM
486
 
 
 
 
 
 
 
 
 
 
487
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
488
- model = AutoModelForCausalLM.from_pretrained(
489
- model_id,
490
- trust_remote_code=True,
491
- device_map=device,
492
- )
 
 
 
 
 
 
 
 
 
 
 
 
493
 
494
  return model, tokenizer
495
 
 
155
  input_text = self.tokenizer.bos_token + " " + moves_str
156
 
157
  # Tokenize
158
+ max_len = getattr(self.model.config, 'n_ctx', None) or getattr(self.model.config, 'max_position_embeddings', 256)
159
  inputs = self.tokenizer(
160
  input_text,
161
  return_tensors="pt",
162
  truncation=True,
163
+ max_length=max_len - 1,
164
  ).to(self.device)
165
 
166
  # Try to generate a legal move
 
477
  Returns:
478
  Tuple of (model, tokenizer).
479
  """
480
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
481
 
482
+ # Import and register custom classes
483
  try:
484
  from src.model import ChessConfig, ChessForCausalLM
485
  except ImportError:
486
  from .model import ChessConfig, ChessForCausalLM
487
 
488
+ # Explicitly register in case it wasn't done yet
489
+ try:
490
+ AutoConfig.register("chess_transformer", ChessConfig)
491
+ AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
492
+ except ValueError:
493
+ # Already registered
494
+ pass
495
+
496
+ # Load tokenizer
497
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
498
+
499
+ # Load model - try with trust_remote_code first, then without
500
+ try:
501
+ model = AutoModelForCausalLM.from_pretrained(
502
+ model_id,
503
+ trust_remote_code=True,
504
+ device_map=device,
505
+ )
506
+ except Exception as e:
507
+ print(f"Loading with trust_remote_code failed ({e}), trying local classes...")
508
+ # Load config and model using our local classes
509
+ config = ChessConfig.from_pretrained(model_id)
510
+ model = ChessForCausalLM.from_pretrained(
511
+ model_id,
512
+ config=config,
513
+ device_map=device,
514
+ )
515
 
516
  return model, tokenizer
517