nathanael-fijalkow commited on
Commit
ea5cabf
·
1 Parent(s): 5bcafc3

model_load fix again

Browse files
Files changed (1) hide show
  1. src/evaluate.py +22 -21
src/evaluate.py CHANGED
@@ -479,39 +479,40 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
479
  """
480
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
481
 
482
- # Import and register custom classes
483
  try:
484
  from src.model import ChessConfig, ChessForCausalLM
 
485
  except ImportError:
486
  from .model import ChessConfig, ChessForCausalLM
 
487
 
488
- # Explicitly register in case it wasn't done yet
489
  try:
490
  AutoConfig.register("chess_transformer", ChessConfig)
 
 
 
 
491
  AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
492
  except ValueError:
493
- # Already registered
494
- pass
495
 
496
- # Load tokenizer
497
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
 
498
 
499
- # Load model - try with trust_remote_code first, then without
500
  try:
501
- model = AutoModelForCausalLM.from_pretrained(
502
- model_id,
503
- trust_remote_code=True,
504
- device_map=device,
505
- )
506
- except Exception as e:
507
- print(f"Loading with trust_remote_code failed ({e}), trying local classes...")
508
- # Load config and model using our local classes
509
- config = ChessConfig.from_pretrained(model_id)
510
- model = ChessForCausalLM.from_pretrained(
511
- model_id,
512
- config=config,
513
- device_map=device,
514
- )
515
 
516
  return model, tokenizer
517
 
 
479
  """
480
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
481
 
482
+ # Import custom classes - this also triggers registration at module load
483
  try:
484
  from src.model import ChessConfig, ChessForCausalLM
485
+ from src.tokenizer import ChessTokenizer
486
  except ImportError:
487
  from .model import ChessConfig, ChessForCausalLM
488
+ from .tokenizer import ChessTokenizer
489
 
490
+ # Explicitly register to ensure it's done before loading
491
  try:
492
  AutoConfig.register("chess_transformer", ChessConfig)
493
+ except ValueError:
494
+ pass # Already registered
495
+
496
+ try:
497
  AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
498
  except ValueError:
499
+ pass # Already registered
 
500
 
501
+ # Load using our local classes directly (most reliable)
502
+ print(f"Loading model {model_id}...")
503
+ config = ChessConfig.from_pretrained(model_id, trust_remote_code=True)
504
+ model = ChessForCausalLM.from_pretrained(
505
+ model_id,
506
+ config=config,
507
+ device_map=device,
508
+ trust_remote_code=True,
509
+ )
510
 
511
+ # Load tokenizer - try custom class first, then generic
512
  try:
513
+ tokenizer = ChessTokenizer.from_pretrained(model_id)
514
+ except Exception:
515
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
516
 
517
  return model, tokenizer
518