Spaces:
Running
Running
| import yaml | |
| import json | |
| import os | |
| import shutil | |
| import argparse | |
| from colorama import init, Fore, Style | |
# Initialise colorama (imported above) once at import time, before any of the
# Fore/Style-coloured messages below are printed.
init()
def load_json(path):
    """Read the JSON object stored at *path* and return it as a dict.

    Returns ``{}`` when the file does not exist, cannot be read or decoded, or
    does not hold a JSON object at the top level, so callers can always call
    ``.get`` on the result. Failures other than a missing file print a warning
    instead of being swallowed silently (a corrupt file would otherwise look
    exactly like a missing key to the caller).
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # A missing file is an expected case: treat it as an empty config.
        return {}
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers permission problems, directories, etc.
        print(f"Warning: could not read {path}: {err}")
        return {}
    if not isinstance(data, dict):
        print(f"Warning: {path} does not contain a JSON object; ignoring it.")
        return {}
    return data
def save_json(path, data):
    """Write *data* to *path* as UTF-8 JSON pretty-printed with a 2-space indent.

    The target file is created, or truncated if it already exists.
    """
    with open(path, mode='w', encoding='utf-8') as out_file:
        json.dump(obj=data, fp=out_file, indent=2)
def _read_merge_config(config_path):
    """Load the mergekit YAML at *config_path* and return it as a dict.

    ``yaml.safe_load`` returns ``None`` for an empty document (and a list or
    scalar for non-mapping documents); those cases yield ``{}`` so the caller
    can use ``.get`` without an ``AttributeError``.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config if isinstance(config, dict) else {}


def _model_paths(config):
    """Return the ``model`` value of each dict entry in ``config['models']``.

    Non-dict entries and entries without a non-empty ``model`` key are skipped
    instead of raising ``KeyError``.
    """
    entries = config.get('models') or []
    return [m['model'] for m in entries if isinstance(m, dict) and m.get('model')]


def _patch_model(model_path, target_eos_id, default_bos_id):
    """Add ``eos_token_id`` to *model_path*/generation_config.json if it is missing.

    An existing file is copied to ``generation_config.json.bak`` before being
    rewritten. An ``eos_token_id`` that is present but different is reported
    and left untouched. Returns ``True`` only when the file was written.
    """
    # normpath drops a trailing separator; otherwise basename() would be "".
    model_name = os.path.basename(os.path.normpath(model_path)).replace("!models--", "")
    if not os.path.isdir(model_path):
        # e.g. a Hub repo id rather than a local checkout: there is nothing to
        # patch, and writing into it would raise FileNotFoundError mid-run.
        print(f"Skipping {model_name}: not a local directory ({model_path})")
        return False

    gen_path = os.path.join(model_path, "generation_config.json")
    data = load_json(gen_path)  # {} when the file is absent or unreadable
    current_id = data.get("eos_token_id")

    # Only patch if MISSING. If it exists but is different (e.g. 999) that is a
    # real mismatch and is deliberately not touched.
    if current_id is not None:
        if str(current_id) != str(target_eos_id):
            print(f"Skipping {model_name}: Has ID {current_id} (Mismatch, not missing)")
        return False

    print(f"Patching {model_name}...")
    # Backup first
    if os.path.exists(gen_path):
        shutil.copy(gen_path, gen_path + ".bak")
    data["eos_token_id"] = target_eos_id
    # Ensure a BOS id exists too (relevant when the file was missing or empty).
    data.setdefault("bos_token_id", default_bos_id)
    save_json(gen_path, data)
    print(f" {Fore.GREEN}-> Fixed: Added eos_token_id: {target_eos_id}{Style.RESET_ALL}")
    return True


def main():
    """Command-line entry point: fill in missing EOS ids for every merge input.

    Reads the mergekit YAML given on the command line and takes
    ``eos_token_id`` from the base model's generation_config.json. That id is
    then written into each listed model's generation_config.json that lacks
    one. Returns early (printing why) when there is no base model or the base
    model has no ``eos_token_id``.
    """
    parser = argparse.ArgumentParser(description="Patch missing EOS IDs in generation_config.json")
    parser.add_argument("config", help="Path to the mergekit yaml config file")
    args = parser.parse_args()
    print(f"{Fore.CYAN}--- GENERATION CONFIG PATCHER ---{Style.RESET_ALL}")

    # 1. Load Config
    config = _read_merge_config(args.config)
    base_model_path = config.get('base_model')
    if not base_model_path:
        print("No base_model found.")
        return

    # 2. Get Target EOS ID from Base Model
    print(f"Reading Base Model: {os.path.basename(os.path.normpath(base_model_path))}")
    base_gen = load_json(os.path.join(base_model_path, "generation_config.json"))
    target_eos_id = base_gen.get("eos_token_id")
    if target_eos_id is None:
        print(f"{Fore.RED}CRITICAL: Base model lacks eos_token_id. Cannot patch.{Style.RESET_ALL}")
        return
    # Reuse the base model's BOS id for files that lack one. Fall back to 1 (the
    # previous hard-coded "standard Mistral" assumption) only if the base has none.
    default_bos_id = base_gen.get("bos_token_id")
    if default_bos_id is None:
        default_bos_id = 1
    print(f"Target EOS ID is: {Fore.GREEN}{target_eos_id}{Style.RESET_ALL}")
    print("-" * 60)

    # 3. Iterate and Patch
    patched_count = 0
    for model_path in _model_paths(config):
        if _patch_model(model_path, target_eos_id, default_bos_id):
            patched_count += 1

    print("-" * 60)
    print(f"Operation Complete. Patched {patched_count} models.")
    print("Run eos_scanner.py again to verify results.")
if __name__ == "__main__":
    # Run the patcher only when executed as a script, not when imported.
    main()