mineself2016 committed on
Commit
ea25230
·
verified ·
1 Parent(s): 3d0c815

Sync latest GeneMamba docs and next-token training updates

Browse files
Files changed (1) hide show
  1. scripts/push_to_hub.py +53 -6
scripts/push_to_hub.py CHANGED
@@ -12,6 +12,7 @@ Requirements:
12
  import os
13
  import shutil
14
  import argparse
 
15
  from pathlib import Path
16
  from huggingface_hub import HfApi
17
 
@@ -29,6 +30,47 @@ def collect_local_files(root: Path):
29
  return files
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def main():
33
  project_root = Path(__file__).resolve().parent.parent
34
 
@@ -118,18 +160,23 @@ def main():
118
  "modeling_genemamba.py",
119
  "configuration_genemamba.py",
120
  "modeling_outputs.py",
 
121
  ]
122
 
123
- print(" - Copying model definition files...")
124
  for file in model_files:
125
  src = script_dir / file
126
  dst = model_path / file
127
- if src.exists() and not dst.exists():
128
- shutil.copy(src, dst)
129
- print(f" ✓ Copied {file}")
130
- elif dst.exists():
131
- print(f" ✓ {file} already exists")
132
 
 
 
 
 
133
  print("✓ Model files prepared")
134
 
135
  except Exception as e:
 
12
  import os
13
  import shutil
14
  import argparse
15
+ import json
16
  from pathlib import Path
17
  from huggingface_hub import HfApi
18
 
 
30
  return files
31
 
32
 
33
def normalize_config_for_hf(config_path: Path):
    """Rewrite ``config.json`` in place so transformers' Auto* classes can load it.

    Mirrors legacy GeneMamba checkpoint keys (``d_model``, ``mamba_layer``)
    onto the HF-standard names, fills in missing defaults without clobbering
    existing values, and pins ``auto_map`` to the custom modeling /
    configuration modules shipped alongside the weights.
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))

    # Legacy checkpoints use their own field names; mirror them onto the
    # HF-standard keys only when the standard key is absent.
    for legacy_key, hf_key in (
        ("d_model", "hidden_size"),
        ("mamba_layer", "num_hidden_layers"),
    ):
        if legacy_key in config and hf_key not in config:
            config[hf_key] = config[legacy_key]

    is_legacy = "d_model" in config or "mamba_layer" in config

    # model_type is always forced; everything below is a fill-in default.
    config["model_type"] = "genemamba"
    for key, value in {
        "architectures": ["GeneMambaModel"],
        "max_position_embeddings": 2048,
        "intermediate_size": 2048,
        "hidden_dropout_prob": 0.1,
        "initializer_range": 0.02,
    }.items():
        config.setdefault(key, value)

    # Legacy "gate" pooling is coerced back to "mean"; a fresh config keeps
    # whatever mode it declares and only receives "mean" as a default.
    if is_legacy and config.get("mamba_mode") == "gate":
        config["mamba_mode"] = "mean"
    else:
        config.setdefault("mamba_mode", "mean")

    for key, value in {
        "embedding_pooling": "mean",
        "num_labels": 2,
        "pad_token_id": 1,
        "bos_token_id": 0,
        "eos_token_id": 2,
        "use_cache": True,
        "torch_dtype": "float32",
        "transformers_version": "4.40.2",
    }.items():
        config.setdefault(key, value)

    # Always overwritten so stale module paths never survive a re-push.
    config["auto_map"] = {
        "AutoConfig": "configuration_genemamba.GeneMambaConfig",
        "AutoModel": "modeling_genemamba.GeneMambaModel",
        "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
        "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification",
    }

    config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
72
+
73
+
74
  def main():
75
  project_root = Path(__file__).resolve().parent.parent
76
 
 
160
  "modeling_genemamba.py",
161
  "configuration_genemamba.py",
162
  "modeling_outputs.py",
163
+ "README.md",
164
  ]
165
 
166
+ print(" - Syncing model definition files...")
167
  for file in model_files:
168
  src = script_dir / file
169
  dst = model_path / file
170
+ if not src.exists():
171
+ print(f" ✗ Missing source file: {file}")
172
+ return 1
173
+ shutil.copy(src, dst)
174
+ print(f" ✓ Synced {file}")
175
 
176
+ config_path = model_path / "config.json"
177
+ normalize_config_for_hf(config_path)
178
+ print(" - Normalized config.json for custom AutoModel loading")
179
+
180
  print("✓ Model files prepared")
181
 
182
  except Exception as e: