# model_tools / gen_id_patcher.py
# (Residue from the hosting page, kept for provenance: "Naphula's picture",
#  "Upload 5 files", commit 7080631 verified, raw | history | blame, 3.27 kB.)
import yaml
import json
import os
import shutil
import argparse
from colorama import init, Fore, Style
# Initialise colorama once, before any of the Fore/Style colour codes below
# are printed.
init()
def load_json(path):
    """Load a JSON object from *path*, returning ``{}`` when it is unusable.

    Deliberately best-effort: callers treat an empty dict as "no config yet"
    and may create a fresh one.

    Args:
        path: Filesystem path to a JSON file.

    Returns:
        The parsed top-level JSON object as a dict. An empty dict is returned
        in these cases:

        - the file does not exist or cannot be opened/read (``OSError``);
        - it is not valid UTF-8 JSON (``ValueError``, which covers
          ``json.JSONDecodeError`` and ``UnicodeDecodeError``);
        - its top-level value is not a JSON object (a list, number, or null).

        The result is therefore always safe to call ``.get`` on.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file or malformed JSON: treat as absent.
        return {}
    # Callers use dict methods on the result; a top-level list or null would
    # otherwise surface later as an AttributeError.
    return data if isinstance(data, dict) else {}
def save_json(path, data):
    """Write *data* to *path* as UTF-8 JSON with 2-space indentation.

    The data is serialized completely before anything on disk is touched.
    The text is then written to a sibling ``<path>.tmp`` file, which replaces
    *path* via ``os.replace``. As a result, a serialization error or an
    interrupted write leaves any existing file intact rather than truncated
    or half-written.

    Args:
        path: Destination file path (str or path-like); overwritten if it
            exists.
        data: A JSON-serializable object.

    Raises:
        TypeError / ValueError: if *data* cannot be serialized. Nothing is
            written in that case.
        OSError: if the temporary file cannot be written or moved into place.
    """
    text = json.dumps(data, indent=2)
    tmp_path = os.fspath(path) + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(text)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a stray temp file behind, then propagate the error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def main():
    """Patch missing ``eos_token_id`` values in a merge's generation configs.

    Command-line entry point. Steps:

    1. Read the mergekit YAML config given as the single positional argument.
    2. Take ``eos_token_id`` (and ``bos_token_id``, if present) from
       ``<base_model>/generation_config.json``.
    3. Add that EOS id to every ``models[*].model`` directory whose
       ``generation_config.json`` lacks one, creating the file if it is
       absent.

    Models whose EOS id exists but differs from the base model's are reported
    and left untouched. The function prints a message and returns early when
    the config has no ``base_model`` or the base model has no
    ``eos_token_id``.
    """
    parser = argparse.ArgumentParser(description="Patch missing EOS IDs in generation_config.json")
    parser.add_argument("config", help="Path to the mergekit yaml config file")
    args = parser.parse_args()
    print(f"{Fore.CYAN}--- GENERATION CONFIG PATCHER ---{Style.RESET_ALL}")

    # 1. Load Config
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # An empty YAML file loads as None, and a non-mapping document has no keys
    # we can use. Treat both as "no base_model" instead of crashing on .get.
    if not isinstance(config, dict):
        config = {}
    base_model_path = config.get('base_model')
    if not base_model_path:
        print("No base_model found.")
        return

    # 2. Get Target EOS ID from Base Model
    # normpath strips a trailing separator so basename doesn't come back empty.
    print(f"Reading Base Model: {os.path.basename(os.path.normpath(base_model_path))}")
    base_gen = load_json(os.path.join(base_model_path, "generation_config.json"))
    target_eos_id = base_gen.get("eos_token_id")
    if target_eos_id is None:
        print(f"{Fore.RED}CRITICAL: Base model lacks eos_token_id. Cannot patch.{Style.RESET_ALL}")
        return
    # When filling in a missing bos_token_id, prefer the base model's own
    # value. 1 (the standard Mistral assumption) is only the fallback.
    target_bos_id = base_gen.get("bos_token_id")
    if target_bos_id is None:
        target_bos_id = 1
    print(f"Target EOS ID is: {Fore.GREEN}{target_eos_id}{Style.RESET_ALL}")
    print("-" * 60)

    # 3. Iterate and Patch
    # `models:` may be present but null. Entries without a 'model' key are
    # skipped rather than raising KeyError.
    models = [
        m['model']
        for m in (config.get('models') or [])
        if isinstance(m, dict) and m.get('model')
    ]
    patched_count = 0
    for model_path in models:
        if _patch_model(model_path, target_eos_id, target_bos_id):
            patched_count += 1
    print("-" * 60)
    print(f"Operation Complete. Patched {patched_count} models.")
    print("Run eos_scanner.py again to verify results.")


def _patch_model(model_path, target_eos_id, target_bos_id):
    """Add *target_eos_id* to one model's generation_config.json if missing.

    Args:
        model_path: Directory of the model listed in the merge config.
        target_eos_id: EOS id taken from the base model.
        target_bos_id: BOS id to insert only if the file has none.

    Returns:
        True when the file was written. False when the model was skipped:
        it is not a local directory, or its EOS id is already present
        (matching or mismatching).
    """
    model_name = os.path.basename(os.path.normpath(model_path)).replace("!models--", "")
    if not os.path.isdir(model_path):
        # There is nothing on disk to patch (presumably a Hub repo id or a
        # typo). Writing generation_config.json would fail because the
        # directory does not exist.
        print(f"Skipping {model_name}: Not a local directory")
        return False
    gen_path = os.path.join(model_path, "generation_config.json")
    # Load or create empty dict
    data = load_json(gen_path)
    current_id = data.get("eos_token_id")
    # Only patch if MISSING. An existing but different id (e.g. 999) is a
    # real mismatch: report it, don't overwrite it.
    if current_id is not None:
        if str(current_id) != str(target_eos_id):
            print(f"Skipping {model_name}: Has ID {current_id} (Mismatch, not missing)")
        return False
    print(f"Patching {model_name}...")
    # Backup first
    if os.path.exists(gen_path):
        shutil.copy(gen_path, gen_path + ".bak")
    data["eos_token_id"] = target_eos_id
    # Ensure a bos_token_id exists too (e.g. when the file was empty or new).
    data.setdefault("bos_token_id", target_bos_id)
    save_json(gen_path, data)
    print(f" {Fore.GREEN}-> Fixed: Added eos_token_id: {target_eos_id}{Style.RESET_ALL}")
    return True
# Script entry point: run the patcher only when executed directly, not on import.
if __name__ == "__main__":
    main()