# model_tools / eos_scanner.py
# Naphula's picture
# Upload 8 files
# 5f463e1 verified
import yaml
import json
import os
import sys
import argparse
from colorama import init, Fore, Style
# Initialize colorama once at import time, before main() prints any of the
# Fore/Style color codes used throughout the report.
init()
class Logger:
    """Tee-style stream: mirrors every write to the terminal and to a log file.

    An instance is meant to be assigned to ``sys.stdout`` so that ``print``
    output is shown on screen and recorded on disk at the same time.

    The terminal receives the text untouched (including ANSI color codes).
    The log file receives the same text with ANSI escape sequences removed,
    so the audit log stays readable in a plain text editor.

    Args:
        filename: Path of the log file. It is opened for writing (truncating
            any existing file) with UTF-8 encoding.
    """

    def __init__(self, filename="eos_audit.log"):
        # Local import: the file's import block does not provide ``re``.
        import re

        # Capture whatever stdout is at construction time; writes are
        # forwarded to it verbatim.
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding="utf-8")
        # Matches CSI escape sequences such as "\x1b[31m" (color) or "\x1b[0m"
        # (reset), which are what the Fore/Style constants expand to.
        self._ansi_escape = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")

    def write(self, message):
        """Write *message* to the terminal and, stripped of colors, to the log."""
        self.terminal.write(message)
        # Skip the file once it has been closed so late writes do not raise.
        if not self.log.closed:
            self.log.write(self._ansi_escape.sub("", message))

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file. Safe to call more than once."""
        if not self.log.closed:
            self.log.close()
def load_json(path):
    """Read and parse a UTF-8 JSON file, returning ``None`` on any failure.

    Args:
        path: Filesystem path of the JSON file.

    Returns:
        The decoded JSON value, or ``None`` when the file does not exist,
        cannot be read, or does not contain valid JSON.

    A file that is simply absent is treated as a normal condition and
    produces no output. A file that exists but cannot be read or parsed
    triggers a one-line warning on stderr, so that a corrupt config is not
    silently mistaken for a missing one.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Missing optional config files are expected; stay quiet.
        return None
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"WARNING: could not load JSON from {path}: {exc}", file=sys.stderr)
        return None
def get_model_metadata(model_path):
    """Collect EOS-related metadata for the model stored in *model_path*.

    Three files inside the model directory are inspected, each optional:

    * ``generation_config.json`` - the ``eos_token_id`` entry.
    * ``tokenizer_config.json``  - the ``eos_token`` entry (string or dict
      with a ``content`` key).
    * ``tokenizer.json``         - used to resolve the EOS string to an ID.

    Args:
        model_path: Directory containing the model's config files.

    Returns:
        A dict with these keys:

        * ``path``: *model_path* as given.
        * ``name``: display name derived from the directory name. A folder
          called ``fixed`` is shown as ``<parent>/fixed``, and the
          ``!models--`` prefix is removed.
        * ``gen_eos_id``: ``eos_token_id`` from the generation config. If it
          is a list, the first entry is reported. ``"MISSING"`` if absent.
        * ``tok_eos_str``: EOS token string, or ``"MISSING"``.
        * ``vocab_eos_id``: ID of the EOS string in ``tokenizer.json``, or
          ``"MISSING"``.
        * ``vocab_size``: number of entries in ``model.vocab``, or
          ``"MISSING"`` when the tokenizer file was not inspected.
        * ``internal_consistency``: ``False`` when the generation-config EOS
          ID(s) and the tokenizer's EOS ID disagree.
    """
    # --- Display name ---
    # Normalize the path to handle trailing slashes or mixed separators.
    norm_path = os.path.normpath(model_path)
    base_name = os.path.basename(norm_path)

    # If the folder is named "fixed", show the parent folder name as well.
    if base_name == "fixed":
        parent_name = os.path.basename(os.path.dirname(norm_path))
        display_name = f"{parent_name}/fixed"
    else:
        display_name = base_name

    # Clean up the huggingface cache prefix.
    display_name = display_name.replace("!models--", "")

    data = {
        "path": model_path,
        "name": display_name,
        "gen_eos_id": "MISSING",    # From generation_config.json
        "tok_eos_str": "MISSING",   # From tokenizer_config.json
        "vocab_eos_id": "MISSING",  # The actual ID of the string in tokenizer.json
        "vocab_size": "MISSING",
        "internal_consistency": True,
    }

    # 1. Generation config (what the model uses to stop).
    # Every declared EOS id is remembered so the consistency check below can
    # accept a tokenizer EOS that matches any of them, not just the first.
    gen_eos_ids = []
    gen_conf = load_json(os.path.join(model_path, "generation_config.json"))
    if isinstance(gen_conf, dict):
        raw_eos = gen_conf.get("eos_token_id", "MISSING")
        if isinstance(raw_eos, list):
            gen_eos_ids = list(raw_eos)
            # Report the first id for comparison; an empty list counts as missing.
            data["gen_eos_id"] = raw_eos[0] if raw_eos else "MISSING"
        else:
            data["gen_eos_id"] = raw_eos
            if raw_eos != "MISSING":
                gen_eos_ids = [raw_eos]

    # 2. Tokenizer config (what the EOS string is).
    tok_conf = load_json(os.path.join(model_path, "tokenizer_config.json"))
    if isinstance(tok_conf, dict):
        eos_entry = tok_conf.get("eos_token", "MISSING")
        if isinstance(eos_entry, dict):
            eos_entry = eos_entry.get("content", "MISSING")
        data["tok_eos_str"] = eos_entry

    # 3. Tokenizer JSON (the actual string -> id map).
    # We prefer tokenizer.json (HuggingFace) over tokenizer.model
    # (SentencePiece) for inspection.
    tok_file = load_json(os.path.join(model_path, "tokenizer.json"))
    if isinstance(tok_file, dict) and data["tok_eos_str"] != "MISSING":
        model_section = tok_file.get("model")
        raw_vocab = model_section.get("vocab") if isinstance(model_section, dict) else None

        if isinstance(raw_vocab, dict):
            vocab = raw_vocab
            data["vocab_size"] = len(raw_vocab)
        elif isinstance(raw_vocab, list):
            # List-shaped vocab: entries are [token, score] pairs and the
            # token id is the position in the list.
            vocab = {}
            for idx, entry in enumerate(raw_vocab):
                if isinstance(entry, (list, tuple)) and entry and isinstance(entry[0], str):
                    vocab.setdefault(entry[0], idx)
            data["vocab_size"] = len(raw_vocab)
        else:
            vocab = {}
            data["vocab_size"] = 0

        eos_str = data["tok_eos_str"]
        if isinstance(eos_str, str):
            if eos_str in vocab:
                data["vocab_eos_id"] = vocab[eos_str]
            else:
                # Special tokens may be declared only under "added_tokens".
                for added in tok_file.get("added_tokens") or []:
                    if (
                        isinstance(added, dict)
                        and added.get("content") == eos_str
                        and "id" in added
                    ):
                        data["vocab_eos_id"] = added["id"]
                        break

    # Internal consistency: does an ID in generation_config match the ID of
    # the EOS string in tokenizer.json?
    vocab_eos_id = data["vocab_eos_id"]
    if gen_eos_ids and vocab_eos_id != "MISSING":
        consistent = any(str(gen_id) == str(vocab_eos_id) for gen_id in gen_eos_ids)
    else:
        # One or both sides missing: consistent only if both are missing.
        consistent = str(data["gen_eos_id"]) == str(vocab_eos_id)
    data["internal_consistency"] = consistent

    return data
def _format_cell(value, expected, width, row_color):
    """Left-pad *value* to *width*; paint it red if it differs from *expected*.

    Padding is applied before the color codes are added, so the escape
    sequences do not count toward the column width. After a red cell the
    row's own color is restored.
    """
    text = str(value)
    padded = f"{text:<{width}}"
    if text != str(expected):
        return f"{Fore.RED}{padded}{row_color}"
    return padded


def _run_scan(config_path):
    """Load the YAML config at *config_path* and print the full EOS report."""
    print(f"{Fore.CYAN}--- EOS & TOKENIZER SCANNER (DEEP SCAN) ---{Style.RESET_ALL}")
    print(f"Scanning config: {config_path}\n")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # An empty YAML file loads as None; anything but a mapping is unusable.
    if not isinstance(config, dict):
        print(f"{Fore.RED}CRITICAL: YAML config is empty or not a mapping.{Style.RESET_ALL}")
        return

    base_model_path = config.get('base_model')

    # Extract model paths; entries may be dicts with a 'model' key or plain strings.
    models = []
    for entry in config.get('models') or []:
        if isinstance(entry, dict) and 'model' in entry:
            models.append(entry['model'])
        elif isinstance(entry, str):
            models.append(entry)

    if not base_model_path:
        print(f"{Fore.RED}CRITICAL: No base_model defined in YAML.{Style.RESET_ALL}")
        return

    # 1. Analyze Base Model
    print("Analyzing Base Model...")
    base_data = get_model_metadata(base_model_path)

    print(f"{Fore.GREEN}BASE MODEL: {base_data['name']}{Style.RESET_ALL}")
    print(f" Gen Config EOS ID: {base_data['gen_eos_id']}")
    print(f" Tokenizer EOS Str: {base_data['tok_eos_str']}")
    print(f" Actual Vocab ID: {base_data['vocab_eos_id']}")

    if not base_data['internal_consistency']:
        print(f" {Fore.RED}INTERNAL ERROR: Base model generation_config ID does not match tokenizer ID!{Style.RESET_ALL}")
    else:
        print(f" Internal Consistency: {Fore.GREEN}PASS{Style.RESET_ALL}")
    print("-" * 80)

    # 2. Analyze Donors
    print(f"{'Status':<10} | {'Gen ID':<8} | {'Vocab ID':<8} | {'EOS Str':<10} | {'Model Name'}")
    print("-" * 100)

    mismatches = 0
    for model_path in models:
        d = get_model_metadata(model_path)

        # Compare against the base model.
        matches_base = (
            str(d['gen_eos_id']) == str(base_data['gen_eos_id'])
            and str(d['vocab_eos_id']) == str(base_data['vocab_eos_id'])
            and d['tok_eos_str'] == base_data['tok_eos_str']
        )

        status_color = Fore.GREEN
        status_text = "MATCH"
        if not matches_base:
            status_color = Fore.RED
            status_text = "FAIL"
        # An internally inconsistent model overrides the base-comparison status.
        if not d['internal_consistency']:
            status_color = Fore.MAGENTA
            status_text = "BROKEN"

        # Count each problematic model once, even if it is both FAIL and BROKEN.
        if status_text != "MATCH":
            mismatches += 1

        gen_cell = _format_cell(d['gen_eos_id'], base_data['gen_eos_id'], 8, status_color)
        vocab_cell = _format_cell(d['vocab_eos_id'], base_data['vocab_eos_id'], 8, status_color)
        str_cell = _format_cell(d['tok_eos_str'], base_data['tok_eos_str'], 10, status_color)

        print(f"{status_color}{status_text:<10} | {gen_cell} | {vocab_cell} | {str_cell} | {d['name']}{Style.RESET_ALL}")

    print("-" * 100)

    # 3. Final Recommendation
    print(f"\n{Fore.CYAN}--- FINAL VERDICT ---{Style.RESET_ALL}")
    if mismatches == 0:
        print(f"{Fore.GREEN}ALL CLEAR.{Style.RESET_ALL}")
        print("1. Change YAML to: tokenizer: source: base")
        print("2. Remove: chat_template: auto")
        print("3. Ensure your base model path in YAML is correct.")
    else:
        print(f"{Fore.RED}MISMATCHES DETECTED.{Style.RESET_ALL}")
        print("1. You MUST use: tokenizer: source: union")
        print("2. However, 'union' may cause the early termination bug if IDs shift.")
        print("3. Recommendation: Remove the models marked FAIL/BROKEN from the merge.")


def main():
    """Command-line entry point: scan a mergekit YAML config for EOS mismatches.

    Parses the ``config`` positional argument, redirects ``sys.stdout``
    through a ``Logger`` so the report is also written to the log file, and
    runs the scan. Stdout is restored and the log file is closed afterwards,
    even if the scan raises.
    """
    parser = argparse.ArgumentParser(description="Scan models for EOS/Tokenizer mismatches.")
    parser.add_argument("config", help="Path to the mergekit yaml config file")
    args = parser.parse_args()

    original_stdout = sys.stdout
    logger = Logger()
    sys.stdout = logger
    try:
        _run_scan(args.config)
    finally:
        # Put the real stdout back before closing the file so nothing can
        # write to a closed log.
        sys.stdout = original_stdout
        logger.flush()
        logger.log.close()


if __name__ == "__main__":
    main()