nathanael-fijalkow commited on
Commit
df21660
·
1 Parent(s): be9adf7

add model to submission script

Browse files
Files changed (1) hide show
  1. submit.py +15 -1
submit.py CHANGED
@@ -82,17 +82,31 @@ def main():
82
  # This adds the 'auto_map' field to tokenizer_config.json
83
  tokenizer.register_for_auto_class("AutoTokenizer")
84
 
 
 
 
 
 
 
 
85
  # Save model and tokenizer
86
  model.save_pretrained(tmp_path)
87
  tokenizer.save_pretrained(tmp_path)
88
 
89
  # Copy tokenizer.py to allow loading with trust_remote_code=True
90
  # This ensures the custom ChessTokenizer can be loaded from the Hub
 
91
  tokenizer_src = Path(__file__).parent / "src" / "tokenizer.py"
92
  if tokenizer_src.exists():
93
- import shutil
94
  shutil.copy(tokenizer_src, tmp_path / "tokenizer.py")
95
  print(" Included tokenizer.py for remote loading")
 
 
 
 
 
 
 
96
 
97
  # Create model card with submitter info
98
  model_card = f"""---
 
82
  # This adds the 'auto_map' field to tokenizer_config.json
83
  tokenizer.register_for_auto_class("AutoTokenizer")
84
 
85
+ # Register model for AutoModelForCausalLM so custom architectures load correctly
86
+ # This adds the 'auto_map' field to config.json
87
+ model.config.auto_map = {
88
+ "AutoConfig": "model.ChessConfig",
89
+ "AutoModelForCausalLM": "model.ChessForCausalLM",
90
+ }
91
+
92
  # Save model and tokenizer
93
  model.save_pretrained(tmp_path)
94
  tokenizer.save_pretrained(tmp_path)
95
 
96
  # Copy tokenizer.py to allow loading with trust_remote_code=True
97
  # This ensures the custom ChessTokenizer can be loaded from the Hub
98
+ import shutil
99
  tokenizer_src = Path(__file__).parent / "src" / "tokenizer.py"
100
  if tokenizer_src.exists():
 
101
  shutil.copy(tokenizer_src, tmp_path / "tokenizer.py")
102
  print(" Included tokenizer.py for remote loading")
103
+
104
+ # Copy model.py to allow loading custom model architectures with trust_remote_code=True
105
+ # This ensures students who modify the model architecture can load their models from the Hub
106
+ model_src = Path(__file__).parent / "src" / "model.py"
107
+ if model_src.exists():
108
+ shutil.copy(model_src, tmp_path / "model.py")
109
+ print(" Included model.py for remote loading")
110
 
111
  # Create model card with submitter info
112
  model_card = f"""---