import yaml
import json
import os
import sys
import argparse
from colorama import init, Fore, Style
# Initialize colorama so ANSI colour codes render correctly (notably on
# Windows terminals, per colorama's documented purpose).
init()
class Logger:
    """Tee stream that duplicates everything written to it into a log file.

    Assign an instance to ``sys.stdout`` so each ``print`` call reaches
    both the real terminal and *filename*.
    """

    def __init__(self, filename="eos_audit.log"):
        # Keep a handle on the real stdout so output still hits the screen.
        self.terminal = sys.stdout
        # Truncate any previous audit log; UTF-8 so all token strings survive.
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        """Write *message* to both underlying streams."""
        for stream in (self.terminal, self.log):
            stream.write(message)

    def flush(self):
        """Flush both streams (required for file-like object compliance)."""
        for stream in (self.terminal, self.log):
            stream.flush()
def load_json(path):
    """Best-effort JSON loader for optional model metadata files.

    Returns the parsed object, or None when the file is missing, unreadable,
    or not valid JSON. Callers treat None as "this metadata file is absent".
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: file absent/unreadable (EAFP replaces the old exists()
        # pre-check). ValueError: malformed JSON or bad encoding
        # (json.JSONDecodeError and UnicodeDecodeError are subclasses).
        return None


def get_model_metadata(model_path):
    """Collect EOS-token metadata for one model directory.

    Reads generation_config.json, tokenizer_config.json and tokenizer.json
    (each optional) and cross-checks that the EOS id declared for generation
    matches the id the tokenizer actually assigns to the EOS string.
    Any value that cannot be determined stays as the string "MISSING".

    Args:
        model_path: Path to a HuggingFace-style model folder.

    Returns:
        dict with keys: path, name, gen_eos_id, tok_eos_str, vocab_eos_id,
        vocab_size, internal_consistency (bool).
    """
    # --- NAME FIX LOGIC START ---
    # Normalize path to handle trailing slashes or mixed separators.
    norm_path = os.path.normpath(model_path)
    base_name = os.path.basename(norm_path)
    # If the folder is named "fixed", grab the parent folder name instead
    # so the report shows which model the "fixed" copy belongs to.
    if base_name == "fixed":
        parent_name = os.path.basename(os.path.dirname(norm_path))
        display_name = f"{parent_name}/fixed"
    else:
        display_name = base_name
    # Clean up the huggingface cache prefix.
    display_name = display_name.replace("!models--", "")
    # --- NAME FIX LOGIC END ---

    data = {
        "path": model_path,
        "name": display_name,
        "gen_eos_id": "MISSING",    # From generation_config.json
        "tok_eos_str": "MISSING",   # From tokenizer_config.json
        "vocab_eos_id": "MISSING",  # The actual ID of the string in tokenizer.json
        "vocab_size": "MISSING",
        "internal_consistency": True,
    }

    # 1. Generation Config (what the model uses to stop).
    gen_conf = load_json(os.path.join(model_path, "generation_config.json"))
    if gen_conf:
        data["gen_eos_id"] = gen_conf.get("eos_token_id", "MISSING")
        # Some models declare a list of stop ids; compare on the first.
        # Guard the empty list, which previously raised IndexError.
        if isinstance(data["gen_eos_id"], list):
            data["gen_eos_id"] = data["gen_eos_id"][0] if data["gen_eos_id"] else "MISSING"

    # 2. Tokenizer Config (what the EOS string is).
    tok_conf = load_json(os.path.join(model_path, "tokenizer_config.json"))
    if tok_conf:
        data["tok_eos_str"] = tok_conf.get("eos_token", "MISSING")
        # eos_token may be a dict like {"content": "</s>", ...}.
        if isinstance(data["tok_eos_str"], dict):
            data["tok_eos_str"] = data["tok_eos_str"].get("content", "MISSING")

    # 3. Tokenizer JSON (the actual string -> id map).
    # We prefer tokenizer.json (HuggingFace) over tokenizer.model
    # (SentencePiece) for inspection.
    tok_file = load_json(os.path.join(model_path, "tokenizer.json"))
    if tok_file and data["tok_eos_str"] != "MISSING":
        model_vocab = tok_file.get("model", {}).get("vocab", {})
        data["vocab_size"] = len(model_vocab)
        # Find the ID of the EOS string. EOS tokens are frequently registered
        # as *added* tokens rather than base-vocab entries, so fall back to
        # the top-level "added_tokens" list when the vocab lookup misses
        # (previously these were falsely reported MISSING/BROKEN).
        if data["tok_eos_str"] in model_vocab:
            data["vocab_eos_id"] = model_vocab[data["tok_eos_str"]]
        else:
            for tok in tok_file.get("added_tokens", []):
                if isinstance(tok, dict) and tok.get("content") == data["tok_eos_str"]:
                    data["vocab_eos_id"] = tok.get("id", "MISSING")
                    break

        # Internal consistency: does the ID in generation_config match the
        # ID the tokenizer actually assigns to the EOS string?
        if str(data["gen_eos_id"]) != str(data["vocab_eos_id"]):
            data["internal_consistency"] = False

    return data
def main():
    """Entry point: audit every model in a mergekit YAML for EOS mismatches.

    Compares each donor model's EOS metadata against the base model and
    prints a colour-coded report plus a final recommendation. All output is
    mirrored into eos_audit.log via the Logger tee.
    """
    parser = argparse.ArgumentParser(description="Scan models for EOS/Tokenizer mismatches.")
    parser.add_argument("config", help="Path to the mergekit yaml config file")
    args = parser.parse_args()

    # Mirror everything printed below into eos_audit.log.
    sys.stdout = Logger()

    print(f"{Fore.CYAN}--- EOS & TOKENIZER SCANNER (DEEP SCAN) ---{Style.RESET_ALL}")
    print(f"Scanning config: {args.config}\n")

    with open(args.config, 'r', encoding='utf-8') as f:
        # An empty YAML file parses to None; treat that as an empty mapping
        # instead of crashing on config.get().
        config = yaml.safe_load(f) or {}

    base_model_path = config.get('base_model')

    # Extract model paths from the models list: mergekit allows both
    # "- model: path" dict entries and bare path strings.
    models = []
    for m in config.get('models') or []:
        if isinstance(m, dict) and 'model' in m:
            models.append(m['model'])
        elif isinstance(m, str):
            models.append(m)

    if not base_model_path:
        print(f"{Fore.RED}CRITICAL: No base_model defined in YAML.{Style.RESET_ALL}")
        return

    # 1. Analyze Base Model
    print("Analyzing Base Model...")
    base_data = get_model_metadata(base_model_path)
    print(f"{Fore.GREEN}BASE MODEL: {base_data['name']}{Style.RESET_ALL}")
    print(f"  Gen Config EOS ID: {base_data['gen_eos_id']}")
    print(f"  Tokenizer EOS Str: {base_data['tok_eos_str']}")
    print(f"  Actual Vocab ID:   {base_data['vocab_eos_id']}")
    if not base_data['internal_consistency']:
        print(f"  {Fore.RED}INTERNAL ERROR: Base model generation_config ID does not match tokenizer ID!{Style.RESET_ALL}")
    else:
        print(f"  Internal Consistency: {Fore.GREEN}PASS{Style.RESET_ALL}")
    print("-" * 80)

    # 2. Analyze Donors
    print(f"{'Status':<10} | {'Gen ID':<8} | {'Vocab ID':<8} | {'EOS Str':<10} | {'Model Name'}")
    print("-" * 100)

    mismatches = 0
    for model_path in models:
        d = get_model_metadata(model_path)

        # Compare donor against base on all three EOS facets.
        gen_diff = str(d['gen_eos_id']) != str(base_data['gen_eos_id'])
        vocab_diff = str(d['vocab_eos_id']) != str(base_data['vocab_eos_id'])
        str_diff = d['tok_eos_str'] != base_data['tok_eos_str']
        is_match = not (gen_diff or vocab_diff or str_diff)
        is_broken = not d['internal_consistency']

        # Status: internal inconsistency (BROKEN) trumps a base mismatch (FAIL).
        if is_broken:
            status_color = Fore.MAGENTA
            status_text = "BROKEN"
        elif not is_match:
            status_color = Fore.RED
            status_text = "FAIL"
        else:
            status_color = Fore.GREEN
            status_text = "MATCH"
        # Count each problematic model exactly once (the previous version
        # double-counted a model that was both FAIL and BROKEN).
        if is_broken or not is_match:
            mismatches += 1

        # Column coloring: highlight only the specific mismatching fields,
        # then restore the row's status colour.
        gen_id_str = str(d['gen_eos_id'])
        if gen_diff:
            gen_id_str = f"{Fore.RED}{gen_id_str}{status_color}"
        vocab_id_str = str(d['vocab_eos_id'])
        if vocab_diff:
            vocab_id_str = f"{Fore.RED}{vocab_id_str}{status_color}"
        str_str = str(d['tok_eos_str'])
        if str_diff:
            str_str = f"{Fore.RED}{str_str}{status_color}"

        print(f"{status_color}{status_text:<10} | {gen_id_str:<8} | {vocab_id_str:<8} | {str_str:<10} | {d['name']}{Style.RESET_ALL}")

    print("-" * 100)

    # 3. Final Recommendation
    print(f"\n{Fore.CYAN}--- FINAL VERDICT ---{Style.RESET_ALL}")
    if mismatches == 0:
        print(f"{Fore.GREEN}ALL CLEAR.{Style.RESET_ALL}")
        print("1. Change YAML to: tokenizer: source: base")
        print("2. Remove: chat_template: auto")
        print("3. Ensure your base model path in YAML is correct.")
    else:
        print(f"{Fore.RED}MISMATCHES DETECTED.{Style.RESET_ALL}")
        print("1. You MUST use: tokenizer: source: union")
        print("2. However, 'union' may cause the early termination bug if IDs shift.")
        print("3. Recommendation: Remove the models marked FAIL/BROKEN from the merge.")
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()