""" CLI for ternary quantization of HuggingFace models. """ from __future__ import annotations import argparse import gc import json import math import sys import time from pathlib import Path import torch def cmd_catalog(args): """List the repo's known-good and known-probe model entries.""" from ternary_quant.toolkit import known_models_to_dict, list_known_models entries = list_known_models(status=args.status, family=args.family) if args.json: payload = { "status_filter": args.status, "family_filter": args.family, "models": known_models_to_dict(entries), } text = json.dumps(payload, indent=2) if args.output: output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(text + "\n") print(f"Wrote catalog to {output_path}") return print(text) return if not entries: print("No models matched the requested filters.") return grouped: dict[str, list] = {} for entry in entries: grouped.setdefault(entry.status, []).append(entry) for status, status_entries in grouped.items(): print(status.replace("_", " ").title()) for entry in status_entries: print( f" {entry.model_id:<40} family={entry.family:<18} " f"path={entry.path:<8} runtime={entry.recommended_runtime}" ) print(f" note: {entry.note}") print(f" artifact: {entry.artifact}") if args.show_commands and entry.quickstart_command: print(f" quickstart: {entry.quickstart_command}") print("") def cmd_doctor(args): """Report environment readiness and runtime recommendations.""" from ternary_quant.toolkit import build_doctor_report, doctor_report_to_text report = build_doctor_report() if args.json: text = json.dumps(report, indent=2) if args.output: output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(text + "\n") print(f"Wrote doctor report to {output_path}") return print(text) return print(doctor_report_to_text(report)) if args.output: output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(json.dumps(report, indent=2) + "\n") print(f"\nWrote doctor report to {output_path}") def cmd_quantize(args): """Quantize a HuggingFace model with the legacy full-ternary pipeline.""" from ternary_quant.pipeline import QuantizationConfig, quantize_model from ternary_quant.storage import save_quantized_model config = QuantizationConfig( n_iter=args.n_iter, use_activation_aware=not args.no_activation_aware, block_size=args.block_size, n_samples=args.n_samples, seq_len=args.seq_len, dataset=args.dataset, dataset_config=args.dataset_config, seed=args.seed, ) if args.skip_modules: config.skip_modules = args.skip_modules result = quantize_model( model_name_or_path=args.model, config=config, device=args.device, dtype=_parse_dtype(args.dtype), ) save_quantized_model( ternary_params=result.ternary_params, model_name=result.model_name, model_config=result.model_config, quant_config=result.config, output_dir=args.output, stats=result.stats, ) if args.eval: print("\nRunning perplexity evaluation...") from ternary_quant.eval import evaluate_perplexity from ternary_quant.inference import load_ternary_model model, tokenizer = load_ternary_model( args.output, device=args.device, runtime_mode=getattr(args, "runtime_mode", "packed"), ) ppl = evaluate_perplexity(model, tokenizer, max_samples=args.eval_samples) print(f"Ternary model perplexity: {ppl:.2f}") def cmd_quantize_small(args): """Quantize a small model with the role-aware sparse asymmetric ternary path.""" from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from ternary_quant.data import get_calibration_data from ternary_quant.eval import ( evaluate_perplexity, evaluate_prompt_bank, get_default_prompt_bank, ) from ternary_quant.inference import generate_text, load_ternary_model from ternary_quant.quantizer_small import ( SmallModelQuantizationConfig, build_sensitivity_only_plan, build_role_aware_plan, config_to_dict, plan_to_dict, quantize_small_model_inplace, summarize_small_model_quantization, tune_low_rank_residuals_inplace, ) from ternary_quant.storage import save_quantized_model device = _resolve_device(args.device) dtype = _parse_dtype(args.dtype) tokenizer = AutoTokenizer.from_pretrained(args.model) model_config = AutoConfig.from_pretrained(args.model) calibration_data = get_calibration_data( args.model, tokenizer=tokenizer, n_samples=args.n_samples, seq_len=args.seq_len, dataset_name=args.dataset, dataset_config=args.dataset_config, seed=args.seed, ).to(device) def load_base_model(): model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, low_cpu_mem_usage=True, ).to(device) model.eval() return model def build_behavior_sequences(prompt_bank: dict) -> list[torch.Tensor]: sequences = [] for sample in prompt_bank.get("samples", []): prompt_ids = tokenizer( sample["prompt"], return_tensors="pt", truncation=False, )["input_ids"][0] generated_ids = torch.tensor( sample.get("generated_token_ids", []), dtype=torch.long, ) full_sequence = torch.cat([prompt_ids, generated_ids], dim=0).unsqueeze(0) sequences.append(full_sequence) return sequences @torch.no_grad() def build_hidden_cache(sequences): if sequences is None: return None if isinstance(sequences, torch.Tensor): model = load_base_model() outputs = [] for start in range(0, sequences.shape[0], args.calibration_tune_batch_size): batch = sequences[start : start + args.calibration_tune_batch_size] hidden = model(batch, output_hidden_states=True).hidden_states[-1] outputs.append(hidden.detach().cpu().to(torch.float16)) del model _cleanup_device(device) return torch.cat(outputs, dim=0) outputs = [] model = load_base_model() for seq in sequences: hidden = model(seq.to(device), output_hidden_states=True).hidden_states[-1] outputs.append(hidden.detach().cpu().to(torch.float16)) del model _cleanup_device(device) return outputs @torch.no_grad() def build_topk_logit_cache(sequences, top_k: int): if sequences is None or top_k <= 0: return None top_k = max(1, int(top_k)) if isinstance(sequences, torch.Tensor): model = load_base_model() indices_out = [] logits_out = [] entropy_out = [] for start in range(0, sequences.shape[0], args.calibration_tune_batch_size): batch = sequences[start : start + args.calibration_tune_batch_size] logits = model(batch).logits[:, :-1, :].float() values, indices = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1) log_probs = torch.log_softmax(logits, dim=-1) probs = log_probs.exp() entropy = -(probs * log_probs).sum(dim=-1) / math.log(max(logits.shape[-1], 2)) indices_out.append(indices.detach().cpu().to(torch.int32)) logits_out.append(values.detach().cpu().to(torch.float16)) entropy_out.append(entropy.detach().cpu().to(torch.float16)) del model _cleanup_device(device) return { "indices": torch.cat(indices_out, dim=0), "logits": torch.cat(logits_out, dim=0), "entropy": torch.cat(entropy_out, dim=0), } outputs = [] model = load_base_model() for seq in sequences: logits = model(seq.to(device)).logits[:, :-1, :].float() values, indices = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1) log_probs = torch.log_softmax(logits, dim=-1) probs = log_probs.exp() entropy = -(probs * log_probs).sum(dim=-1) / math.log(max(logits.shape[-1], 2)) outputs.append( { "indices": indices.detach().cpu().to(torch.int32), "logits": values.detach().cpu().to(torch.float16), "entropy": entropy.detach().cpu().to(torch.float16), } ) del model _cleanup_device(device) return outputs def make_config( planner: str, ) -> SmallModelQuantizationConfig: target_average_bits = args.target_average_bits adaptive_salient = args.adaptive_salient role_cost_weights = None if planner == "budgeted" and target_average_bits is None: target_average_bits = 10.5 adaptive_salient = True elif planner == "sensitivity_budget": if target_average_bits is None: target_average_bits = 10.5 adaptive_salient = True role_cost_weights = _uniform_role_weights() elif planner == "practical": target_average_bits = None config = SmallModelQuantizationConfig( group_size=args.group_size, n_iter=args.n_iter, salient_fraction=args.salient_fraction, min_salient_fraction=args.min_salient_fraction, max_salient_fraction=args.max_salient_fraction, adaptive_salient=adaptive_salient, low_rank_rank=args.low_rank_rank, adaptive_low_rank=args.adaptive_low_rank, low_rank_chunk_rank=args.low_rank_chunk_rank, low_rank_target_average_bits=args.low_rank_target_average_bits, low_rank_fit_mode=args.low_rank_fit_mode, low_rank_ridge=args.low_rank_ridge, low_rank_max_samples=args.low_rank_max_samples, n_boundary_layers=args.boundary_layers, calibration_batch_size=args.calibration_batch_size, quantize_attention_output=args.quantize_attention_output, quantize_mlp_output=args.quantize_mlp_output, target_average_bits=target_average_bits, importance_threshold_scale=getattr(args, "importance_threshold_scale", 0.0), role_cost_weights=role_cost_weights if role_cost_weights is not None else SmallModelQuantizationConfig().role_cost_weights, ) config.base_config.n_samples = args.n_samples config.base_config.seq_len = args.seq_len config.base_config.dataset = args.dataset config.base_config.dataset_config = args.dataset_config config.base_config.seed = args.seed return config def build_plan(model, config: SmallModelQuantizationConfig, planner: str): if planner == "sensitivity_budget": return build_sensitivity_only_plan(model, calibration_data, config) return build_role_aware_plan(model, calibration_data, config) behavior_sequences = None calibration_hidden_states = None behavior_hidden_states = None calibration_logit_targets = None behavior_logit_targets = None if args.calibration_tune_steps > 0 and args.behavior_tune_weight > 0.0: behavior_prompt_bank = get_default_prompt_bank( primary_prompt=args.prompt, max_prompts=args.behavior_tune_prompt_count, ) print("Building prompt-bank behavior tuning data...") behavior_model = load_base_model() behavior_reference = evaluate_prompt_bank( behavior_model, tokenizer, prompts=behavior_prompt_bank, max_new_tokens=args.behavior_tune_max_tokens, ) behavior_sequences = build_behavior_sequences(behavior_reference) del behavior_model _cleanup_device(device) if args.calibration_tune_steps > 0 and ( args.distill_weight > 0.0 or args.behavior_hidden_weight > 0.0 ): print("Building teacher hidden-state caches...") calibration_hidden_states = build_hidden_cache(calibration_data) if behavior_sequences is not None: behavior_hidden_states = build_hidden_cache(behavior_sequences) if args.calibration_tune_steps > 0 and ( args.logit_distill_weight > 0.0 or args.behavior_logit_weight > 0.0 or args.entropy_distill_weight > 0.0 or args.behavior_entropy_weight > 0.0 ): print("Building teacher top-k logit caches...") calibration_logit_targets = build_topk_logit_cache( calibration_data, args.logit_distill_topk, ) if behavior_sequences is not None: behavior_logit_targets = build_topk_logit_cache( behavior_sequences, args.logit_distill_topk, ) selection = None auto_tuned = False if args.planner in {"auto", "collapse_auto"}: candidate_planners = ["practical", "sensitivity_budget"] best = None total_quant_time = 0.0 selection_metric = ( "collapse_aware" if args.planner == "collapse_auto" else "ppl" ) selection = { "selection_metric": selection_metric, "candidate_scores": {}, } selection_prompt_bank = get_default_prompt_bank( primary_prompt=args.prompt, max_prompts=args.selection_prompt_count, ) reference_behavior = None if selection_metric == "collapse_aware": print("Measuring FP16 prompt-bank behavior...") reference_model = load_base_model() reference_behavior = evaluate_prompt_bank( reference_model, tokenizer, prompts=selection_prompt_bank, max_new_tokens=args.selection_max_tokens, ) selection["reference_behavior"] = { "avg_collapse_score": reference_behavior["avg_collapse_score"], "worst_collapse_score": reference_behavior["worst_collapse_score"], "avg_distinct_2": reference_behavior["avg_distinct_2"], "avg_repeated_3gram_ratio": reference_behavior[ "avg_repeated_3gram_ratio" ], } del reference_model _cleanup_device(device) for planner in candidate_planners: print(f"Evaluating planner candidate: {planner}") model = load_base_model() config = make_config(planner) t0 = time.time() plan = build_plan(model, config, planner) result = quantize_small_model_inplace( model, calibration_data=calibration_data, config=config, plan=plan, ) total_quant_time += time.time() - t0 summary = summarize_small_model_quantization(result, model) tune_stats = None if args.calibration_tune_steps > 0: tune_stats = tune_low_rank_residuals_inplace( model, result, calibration_data=calibration_data, n_steps=args.calibration_tune_steps, lr=args.calibration_tune_lr, batch_size=args.calibration_tune_batch_size, max_seq_len=args.seq_len, behavior_sequences=behavior_sequences, behavior_weight=args.behavior_tune_weight, calibration_hidden_states=calibration_hidden_states, behavior_hidden_states=behavior_hidden_states, calibration_logit_targets=calibration_logit_targets, behavior_logit_targets=behavior_logit_targets, distill_weight=args.distill_weight, behavior_hidden_weight=args.behavior_hidden_weight, logit_distill_weight=args.logit_distill_weight, behavior_logit_weight=args.behavior_logit_weight, entropy_distill_weight=args.entropy_distill_weight, behavior_entropy_weight=args.behavior_entropy_weight, logit_distill_temperature=args.logit_distill_temperature, seed=args.seed, ) summary = summarize_small_model_quantization(result, model) selection_ppl = evaluate_perplexity( model, tokenizer, seq_len=args.seq_len, max_samples=args.selection_eval_samples, ) selection_score = float(selection_ppl) selection_behavior = None if selection_metric == "collapse_aware": selection_behavior = evaluate_prompt_bank( model, tokenizer, prompts=selection_prompt_bank, max_new_tokens=args.selection_max_tokens, ) reference_avg = ( 0.0 if reference_behavior is None else reference_behavior["avg_collapse_score"] ) reference_worst = ( reference_avg if reference_behavior is None else reference_behavior["worst_collapse_score"] ) collapse_excess = max( selection_behavior["avg_collapse_score"] - reference_avg, 0.0, ) worst_excess = max( selection_behavior["worst_collapse_score"] - reference_worst, 0.0, ) selection_score = selection_ppl * ( 1.0 + args.selection_collapse_weight * collapse_excess + args.selection_worst_weight * worst_excess ) selection["candidate_scores"][planner] = { "selection_ppl": selection_ppl, "selection_score": selection_score, "predicted_average_bits": plan.predicted_average_bits, "full_model_effective_bits": summary["full_model_effective_bits"], } if selection_behavior is not None: selection["candidate_scores"][planner]["selection_behavior"] = { "avg_collapse_score": selection_behavior["avg_collapse_score"], "worst_collapse_score": selection_behavior["worst_collapse_score"], "avg_distinct_2": selection_behavior["avg_distinct_2"], "avg_repeated_3gram_ratio": selection_behavior[ "avg_repeated_3gram_ratio" ], } if tune_stats is not None: selection["candidate_scores"][planner]["calibration_tune"] = tune_stats if best is None or selection_score < best["selection_score"]: if best is not None: del best["model"] _cleanup_device(device) best = { "model": model, "config": config, "plan": plan, "result": result, "summary": summary, "selection_ppl": selection_ppl, "selection_score": selection_score, "selection_behavior": selection_behavior, "planner": planner, } else: del model _cleanup_device(device) if best is None: raise RuntimeError("Auto planner failed to select a candidate.") model = best["model"] config = best["config"] plan = best["plan"] result = best["result"] summary = best["summary"] quant_time = total_quant_time selected_name = "RAST-collapse-auto" if args.planner == "collapse_auto" else "RAST-auto" result.plan.method_name = selected_name summary["method_name"] = selected_name auto_tuned = args.calibration_tune_steps > 0 selection.update( { "selected_planner": best["planner"], "selection_ppl": best["selection_ppl"], "selection_score": best["selection_score"], } ) if best["selection_behavior"] is not None: selection["selected_behavior"] = { "avg_collapse_score": best["selection_behavior"]["avg_collapse_score"], "worst_collapse_score": best["selection_behavior"]["worst_collapse_score"], "avg_distinct_2": best["selection_behavior"]["avg_distinct_2"], "avg_repeated_3gram_ratio": best["selection_behavior"][ "avg_repeated_3gram_ratio" ], } print( f"Selected planner: {best['planner']} | " f"held-out score {best['selection_score']:.2f} | " f"PPL {best['selection_ppl']:.2f} | " f"full-model bits {summary['full_model_effective_bits']:.2f}" ) else: model = load_base_model() config = make_config(args.planner) print("Building role-aware plan...") t0 = time.time() plan = build_plan(model, config, args.planner) print( f"Plan ready in {time.time() - t0:.1f}s | " f"Predicted average bits: {plan.predicted_average_bits:.2f}" ) print("Applying role-aware quantization...") t1 = time.time() result = quantize_small_model_inplace( model, calibration_data=calibration_data, config=config, plan=plan, ) quant_time = time.time() - t1 summary = summarize_small_model_quantization(result, model) if args.calibration_tune_steps > 0 and not auto_tuned: print("Calibrating low-rank residuals...") t2 = time.time() tune_stats = tune_low_rank_residuals_inplace( model, result, calibration_data=calibration_data, n_steps=args.calibration_tune_steps, lr=args.calibration_tune_lr, batch_size=args.calibration_tune_batch_size, max_seq_len=args.seq_len, behavior_sequences=behavior_sequences, behavior_weight=args.behavior_tune_weight, calibration_hidden_states=calibration_hidden_states, behavior_hidden_states=behavior_hidden_states, calibration_logit_targets=calibration_logit_targets, behavior_logit_targets=behavior_logit_targets, distill_weight=args.distill_weight, behavior_hidden_weight=args.behavior_hidden_weight, logit_distill_weight=args.logit_distill_weight, behavior_logit_weight=args.behavior_logit_weight, entropy_distill_weight=args.entropy_distill_weight, behavior_entropy_weight=args.behavior_entropy_weight, logit_distill_temperature=args.logit_distill_temperature, seed=args.seed, ) quant_time += time.time() - t2 summary = summarize_small_model_quantization(result, model) print( f"Calibration tune complete | " f"final loss {tune_stats.get('final_loss', float('nan')):.4f} | " f"wrapped modules {tune_stats['n_wrapped_modules']}" ) save_quantized_model( ternary_params=result.quantized_params, model_name=args.model, model_config=model_config, quant_config=config, output_dir=args.output, stats=result.stats, summary=summary, plan=result.plan, method_name=result.plan.method_name, ) report = { "method": result.plan.method_name, "model": args.model, "quant_time_sec": quant_time, "summary": summary, "plan": plan_to_dict(result.plan), "config": config_to_dict(config), } if selection is not None: report["selection"] = selection report_path = Path(args.output) / "role_aware_report.json" with open(report_path, "w") as f: json.dump(report, f, indent=2) print(f"Wrote role-aware report to {report_path}") if args.eval: print("\nRunning validation on saved model...") quantized_model, tokenizer = load_ternary_model( args.output, device=device, runtime_mode=getattr(args, "runtime_mode", "packed"), ) ppl = evaluate_perplexity( quantized_model, tokenizer, seq_len=args.seq_len, max_samples=args.eval_samples, ) print(f"Role-aware quantized perplexity: {ppl:.2f}") if args.prompt: text = generate_text( quantized_model, tokenizer, prompt=args.prompt, max_new_tokens=args.max_tokens, do_sample=False, ) print(f"Prompt: {args.prompt}") print(f"Output: {text}") def cmd_quantize_ptq(args): """Quantize a small model via a ternary PTQ family or controller.""" from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from ternary_quant.data import get_calibration_data from ternary_quant.eval import ( evaluate_perplexity, evaluate_prompt_bank, get_default_prompt_bank, ) from ternary_quant.inference import generate_text, load_ternary_model from ternary_quant.ptq_families import ( build_family_config, family_config_to_dict, get_default_family_candidates, quantize_family_inplace, summarize_family_quantization, ) from ternary_quant.storage import save_quantized_model device = _resolve_device(args.device) dtype = _parse_dtype(args.dtype) tokenizer = AutoTokenizer.from_pretrained(args.model) model_config = AutoConfig.from_pretrained(args.model) calibration_data = get_calibration_data( args.model, tokenizer=tokenizer, n_samples=args.n_samples, seq_len=args.seq_len, dataset_name=args.dataset, dataset_config=args.dataset_config, seed=args.seed, ).to(device) def load_base_model(): model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=dtype, low_cpu_mem_usage=True, ).to(device) model.eval() return model def build_config(family_name: str): return build_family_config( family_name, target_average_bits=args.target_average_bits, group_size=args.group_size, n_iter=args.n_iter, n_boundary_layers=args.boundary_layers, calibration_batch_size=args.calibration_batch_size, quantize_attention_output=args.quantize_attention_output, quantize_mlp_output=args.quantize_mlp_output, ) selection = None if args.family == "controller": candidate_names = ( args.candidate_families if args.candidate_families else get_default_family_candidates() ) selection_metric = args.selection_metric selection_prompt_bank = get_default_prompt_bank( primary_prompt=args.prompt, max_prompts=args.selection_prompt_count, ) selection = { "selection_metric": selection_metric, "candidate_scores": {}, } reference_behavior = None if selection_metric == "collapse": print("Measuring FP16 prompt-bank behavior for controller selection...") reference_model = load_base_model() reference_behavior = evaluate_prompt_bank( reference_model, tokenizer, prompts=selection_prompt_bank, max_new_tokens=args.selection_max_tokens, ) selection["reference_behavior"] = { "avg_collapse_score": reference_behavior["avg_collapse_score"], "worst_collapse_score": reference_behavior["worst_collapse_score"], "avg_distinct_2": reference_behavior["avg_distinct_2"], "avg_repeated_3gram_ratio": reference_behavior[ "avg_repeated_3gram_ratio" ], } del reference_model _cleanup_device(device) best = None total_quant_time = 0.0 for family_name in candidate_names: print(f"Evaluating ternary PTQ family candidate: {family_name}") family_config = build_config(family_name) model = load_base_model() t0 = time.time() result = quantize_family_inplace( model, calibration_data=calibration_data, config=family_config, ) total_quant_time += time.time() - t0 summary = summarize_family_quantization(result) selection_ppl = evaluate_perplexity( model, tokenizer, seq_len=args.seq_len, max_samples=args.selection_eval_samples, ) selection_score = float(selection_ppl) selection_behavior = None if selection_metric == "collapse": selection_behavior = evaluate_prompt_bank( model, tokenizer, prompts=selection_prompt_bank, max_new_tokens=args.selection_max_tokens, ) reference_avg = ( 0.0 if reference_behavior is None else reference_behavior["avg_collapse_score"] ) reference_worst = ( reference_avg if reference_behavior is None else reference_behavior["worst_collapse_score"] ) collapse_excess = max( selection_behavior["avg_collapse_score"] - reference_avg, 0.0, ) worst_excess = max( selection_behavior["worst_collapse_score"] - reference_worst, 0.0, ) selection_score = selection_ppl * ( 1.0 + args.selection_collapse_weight * collapse_excess + args.selection_worst_weight * worst_excess ) if args.target_average_bits is not None: bits_excess = max( summary["full_model_effective_bits"] - args.target_average_bits, 0.0, ) selection_score *= ( 1.0 + args.selection_bits_weight * bits_excess / max(args.target_average_bits, 1e-6) ) selection["candidate_scores"][family_name] = { "label": family_config.label, "selection_ppl": selection_ppl, "selection_score": selection_score, "full_model_effective_bits": summary["full_model_effective_bits"], "quantized_fraction": summary["quantized_fraction"], } if selection_behavior is not None: selection["candidate_scores"][family_name]["selection_behavior"] = { "avg_collapse_score": selection_behavior["avg_collapse_score"], "worst_collapse_score": selection_behavior["worst_collapse_score"], "avg_distinct_2": selection_behavior["avg_distinct_2"], "avg_repeated_3gram_ratio": selection_behavior[ "avg_repeated_3gram_ratio" ], } if best is None or selection_score < best["selection_score"]: if best is not None: del best["model"] _cleanup_device(device) best = { "model": model, "family_config": family_config, "result": result, "summary": summary, "selection_ppl": selection_ppl, "selection_score": selection_score, "selection_behavior": selection_behavior, "family_name": family_name, } else: del model _cleanup_device(device) if best is None: raise RuntimeError("Controller failed to select a ternary PTQ family.") model = best["model"] family_config = best["family_config"] result = best["result"] summary = best["summary"] quant_time = total_quant_time result.plan.method_name = "Ternary-PTQ-auto" summary["method_name"] = "Ternary-PTQ-auto" summary["selected_family_preset"] = best["family_name"] selection.update( { "selected_family_preset": best["family_name"], "selected_family_label": family_config.label, "selection_ppl": best["selection_ppl"], "selection_score": best["selection_score"], } ) if best["selection_behavior"] is not None: selection["selected_behavior"] = { "avg_collapse_score": best["selection_behavior"]["avg_collapse_score"], "worst_collapse_score": best["selection_behavior"]["worst_collapse_score"], "avg_distinct_2": best["selection_behavior"]["avg_distinct_2"], "avg_repeated_3gram_ratio": best["selection_behavior"][ "avg_repeated_3gram_ratio" ], } print( f"Selected family: {best['family_name']} | " f"held-out score {best['selection_score']:.2f} | " f"PPL {best['selection_ppl']:.2f} | " f"full-model bits {summary['full_model_effective_bits']:.2f}" ) else: family_config = build_config(args.family) print(f"Applying ternary PTQ family: {family_config.label}") model = load_base_model() t0 = time.time() result = quantize_family_inplace( model, calibration_data=calibration_data, config=family_config, ) quant_time = time.time() - t0 summary = summarize_family_quantization(result) save_quantized_model( ternary_params=result.quantized_params, model_name=args.model, model_config=model_config, quant_config=family_config, output_dir=args.output, stats=result.stats, summary=summary, plan=result.plan, method_name=result.plan.method_name, ) report = { "method": result.plan.method_name, "model": args.model, "quant_time_sec": quant_time, "summary": summary, "family_config": family_config_to_dict(family_config), } if selection is not None: report["selection"] = selection report_path = Path(args.output) / "ternary_ptq_report.json" with open(report_path, "w") as f: json.dump(report, f, indent=2) print(f"Wrote ternary PTQ report to {report_path}") if args.eval: print("\nRunning validation on saved model...") quantized_model, tokenizer = load_ternary_model( args.output, device=device, runtime_mode=getattr(args, "runtime_mode", "packed"), ) ppl = evaluate_perplexity( quantized_model, tokenizer, seq_len=args.seq_len, max_samples=args.eval_samples, ) print(f"Ternary PTQ perplexity: {ppl:.2f}") if args.prompt: text = generate_text( quantized_model, tokenizer, prompt=args.prompt, max_new_tokens=args.max_tokens, do_sample=False, ) print(f"Prompt: {args.prompt}") print(f"Output: {text}") def cmd_eval(args): """Evaluate perplexity of a saved quantized model.""" from ternary_quant.eval import evaluate_perplexity from ternary_quant.inference import load_ternary_model model, tokenizer = load_ternary_model( args.model_dir, device=args.device, runtime_mode=getattr(args, "runtime_mode", "packed"), ) ppl = evaluate_perplexity( model, tokenizer, seq_len=args.seq_len, max_samples=args.max_samples, ) print(f"\nPerplexity: {ppl:.2f}") def cmd_compare(args): """Compare original and saved quantized model.""" from ternary_quant.eval import compare_models compare_models( original_model_name=args.original, ternary_model_dir=args.ternary, device=args.device, seq_len=args.seq_len, max_samples=args.max_samples, ) def cmd_generate(args): """Generate text with a saved quantized model.""" import numpy as np from ternary_quant.generative_adapters import inspect_generative_model from ternary_quant.inference import ( generate_generative_output, generate_text, load_ternary_model, ) model, asset = load_ternary_model( args.model_dir, device=args.device, runtime_mode=getattr(args, "runtime_mode", "packed"), ) model_info = inspect_generative_model( model, model_name=str(getattr(model, "name_or_path", "loaded-model")), ) image = None if args.image_path: try: from PIL import Image except Exception as exc: raise RuntimeError( "Reading --image-path requires Pillow. Install pillow or omit the image." ) from exc image = np.array(Image.open(args.image_path).convert("RGB")) if model_info.model_family == "image_text_to_text": output = generate_generative_output( model, asset, prompt=args.prompt, max_new_tokens=args.max_tokens, image=image, ) else: output = generate_text( model, asset, prompt=args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, do_sample=args.temperature > 0, ) print(f"\nPrompt: {args.prompt}") print(f"Output: {output}") def cmd_inspect_generative(args): """Inspect a generative model and list its quantizable components.""" from ternary_quant.generative_adapters import ( generative_model_info_to_dict, load_generative_model, ) device = _resolve_device(args.device) dtype = _parse_dtype(args.dtype) model, _, model_info = load_generative_model( args.model, device=device, dtype=dtype, ) print(f"Model: {model_info.model_name}") print(f"Family: {model_info.model_family}") print(f"Model type: {model_info.model_type}") print(f"Architectures: {', '.join(model_info.architectures) or 'unknown'}") print(f"Default broad components: {', '.join(model_info.default_quantization_components)}") print("\nComponents:") for component in model_info.components: sample = ", ".join(component.sample_linear_like_names[:4]) or "(no linear modules)" print( f" {component.name:<22} path={component.path:<32} " f"linears={component.linear_like_count:<4} params={component.parameter_count:<12}" ) print(f" sample: {sample}") if args.output: payload = generative_model_info_to_dict(model_info) output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w") as f: json.dump(payload, f, indent=2) print(f"\nWrote component inventory to {output_path}") del model _cleanup_device(device) def cmd_quantize_broad(args): """Quantize selected components of a broad generative model.""" from ternary_quant.generative_adapters import ( BroadQuantizationConfig, broad_quant_config_to_dict, build_calibration_batches, evaluate_broad_prompt_bank, generative_model_info_to_dict, load_generative_model, make_demo_image, quantize_components_inplace, ) from ternary_quant.inference import load_ternary_model from ternary_quant.storage import save_quantized_model device = _resolve_device(args.device) dtype = _parse_dtype(args.dtype) model, asset, model_info = load_generative_model( args.model, device=device, dtype=dtype, ) components = ( args.components if args.components else model_info.default_quantization_components ) prompts = [args.prompt] if args.prompt else None broad_config = BroadQuantizationConfig( components=list(components), scheme=args.scheme, group_size=args.group_size, n_iter=args.n_iter, salient_fraction=args.salient_fraction, rescue_fraction=args.rescue_fraction, n_planes=3 if args.scheme == "tritplane3" else 2, allow_all_linear=args.allow_all_linear, max_length=args.seq_len, calibration_batch_size=args.calibration_batch_size, calibration_prompts=list(prompts) if prompts is not None else None, ) demo_image = make_demo_image() calibration_batches = build_calibration_batches( asset, model_info, max_length=args.seq_len, batch_size=args.calibration_batch_size, prompts=prompts, demo_images=[demo_image], ) result = quantize_components_inplace( model, model_info=model_info, calibration_batches=calibration_batches, config=broad_config, ) save_quantized_model( ternary_params=result.quantized_params, model_name=args.model, model_config=model.config, quant_config=broad_config, output_dir=args.output, stats=result.stats, summary=result.summary, method_name=result.summary["method_name"], model_family=model_info.model_family, ) report = { "method": result.summary["method_name"], "model": args.model, "model_info": generative_model_info_to_dict(model_info), "config": broad_quant_config_to_dict(broad_config), "summary": result.summary, } if args.eval: quantized_model, quantized_asset = load_ternary_model( args.output, device=device, runtime_mode=getattr(args, "runtime_mode", "packed"), ) eval_prompts = prompts or None validation = evaluate_broad_prompt_bank( quantized_model, quantized_asset, model_info, prompts=eval_prompts if eval_prompts is not None else ( [args.prompt] if args.prompt else ( ["Describe the image in one short sentence."] if model_info.model_family == "image_text_to_text" else [ "The capital of France is", "Answer briefly: What is 2 + 2?", ] ) ), max_new_tokens=args.max_tokens, demo_image=demo_image, ) report["validation"] = validation print("\nValidation:") print(f" Avg collapse: {validation['avg_collapse_score']:.3f}") print(f" Primary output: {validation['primary_text']}") del quantized_model _cleanup_device(device) report_path = Path(args.output) / "broad_generative_report.json" with open(report_path, "w") as f: json.dump(report, f, indent=2) print(f"Wrote broad generative report to {report_path}") # Print a compact summary s = result.summary print(f"\nQuantization summary:") print(f" Layers quantized: {s['quantized_modules']}") print(f" Full-model effective bits: {s['full_model_effective_bits']:.2f}") print(f" Compression ratio: {s['compression_ratio']:.2f}×") print(f" Avg reconstruction error: {s['avg_relative_error']:.4f}") if getattr(args, "push_to_hub", None): _push_to_hub(args.output, args.push_to_hub, args.model, result.summary, broad_config) def _push_to_hub(output_dir: str, hub_repo: str, source_model: str, summary: dict, config) -> None: """Push a quantized model directory to HuggingFace Hub.""" try: from huggingface_hub import HfApi except ImportError: print("huggingface_hub not installed. Run: pip install huggingface_hub") return output_path = Path(output_dir) # Write a model card model_card = f"""--- tags: - ternary-quant - quantization - ternary base_model: {source_model} --- # {hub_repo} Ternary-quantized version of [{source_model}](https://huggingface.co/{source_model}) produced with [ternary-quant](https://github.com/Asad-Ismail/ternary-quant). ## Quantization details - **Scheme**: {getattr(config, 'scheme', 'unknown')} - **Components**: {', '.join(getattr(config, 'components', []))} - **Full-model effective bits**: {summary.get('full_model_effective_bits', '?'):.2f} - **Compression ratio**: {summary.get('compression_ratio', '?'):.2f}× - **Avg reconstruction error**: {summary.get('avg_relative_error', '?'):.4f} ## Usage ```python from ternary_quant.inference import load_ternary_model model, tokenizer = load_ternary_model("{hub_repo}", runtime_mode="cached") inputs = tokenizer("Hello, world!", return_tensors="pt") outputs = model.generate(**inputs, max_new_tokens=50) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` Or via CLI: ```bash pip install ternary-quant ternary-quant generate {hub_repo} --prompt "Hello" --runtime-mode cached ``` """ card_path = output_path / "README.md" card_path.write_text(model_card) api = HfApi() print(f"Pushing to {hub_repo}...") api.upload_folder( folder_path=str(output_path), repo_id=hub_repo, repo_type="model", ) print(f"Pushed to https://huggingface.co/{hub_repo}") def cmd_check(args): """Quick compatibility check using only the model config (no weights downloaded).""" from ternary_quant.generative_adapters import ( VLM_MODEL_TYPES, _default_components_for_family, detect_model_family_from_config, ) from transformers import AutoConfig print(f"Checking: {args.model}") try: config = AutoConfig.from_pretrained(args.model) except Exception as exc: print(f" Could not load config: {exc}") print(" → Model may be gated (requires HF token) or not found.") return model_type = getattr(config, "model_type", "unknown") architectures = list(getattr(config, "architectures", None) or []) family = detect_model_family_from_config(config) default_components = _default_components_for_family(family) print(f" model_type: {model_type}") print(f" architectures: {', '.join(architectures) or 'unknown'}") print(f" family: {family}") print(f" default components to quantize: {', '.join(default_components)}") is_vlm = model_type in VLM_MODEL_TYPES has_encoder_decoder = bool(getattr(config, "is_encoder_decoder", False)) if is_vlm: print(" → VLM: quantize text_backbone + multimodal_connector") print(f" ternary-quant quantize-broad {args.model} \\") print(f" --output ./$(basename {args.model})-ternary \\") print(f" --components text_backbone multimodal_connector \\") print(f" --scheme tritplane3 --dtype float16") elif has_encoder_decoder: print(" → Seq2seq / audio: quantize decoder") print(f" ternary-quant quantize-broad {args.model} \\") print(f" --output ./$(basename {args.model})-ternary \\") print(f" --components decoder --scheme tritplane3") else: print(" → Causal LM: quantize text_backbone") print(f" ternary-quant quantize-broad {args.model} \\") print(f" --output ./$(basename {args.model})-ternary \\") print(f" --components text_backbone --scheme tritplane3") print() print(" If quantization fails with 'No quantizable linear modules',") print(" add --allow-all-linear to quantize all nn.Linear layers.") def cmd_info(args): """Show info about a saved quantized model.""" model_dir = Path(args.model_dir) meta_path = model_dir / "metadata.json" if not meta_path.exists(): print(f"No quantized model found at {model_dir}") sys.exit(1) with open(meta_path) as f: metadata = json.load(f) print(f"Model: {metadata['model_name']}") print(f"Model family: {metadata.get('model_family', 'causal_lm')}") print(f"Method: {metadata.get('method_name', 'unknown')}") print(f"Format family: {metadata.get('format_family', 'legacy')}") print(f"Format version: {metadata['format_version']}") print(f"Layers quantized: {len(metadata['layer_info'])}") print(f"Packed size: {metadata['total_packed_bytes'] / 1e6:.1f} MB") print(f"FP16 size: {metadata['total_fp16_bytes'] / 1e6:.1f} MB") print(f"Compression: {metadata['compression_ratio']:.1f}x") qc = metadata["quant_config"] print("\nQuantization config:") for key, value in qc.items(): if key == "base_config": continue print(f" {key}: {value}") if metadata.get("summary"): summary = metadata["summary"] print("\nSummary:") for key in [ "quantized_fraction", "avg_relative_error", "avg_effective_bits", "full_model_effective_bits", "total_sparse_nnz", ]: if key in summary: value = summary[key] if isinstance(value, float): if "fraction" in key: print(f" {key}: {value:.1%}") else: print(f" {key}: {value:.4f}") else: print(f" {key}: {value}") if metadata.get("plan"): plan = metadata["plan"] print("\nPlan:") print(f" Method: {plan.get('method_name', 'unknown')}") print(f" Target average bits: {plan.get('target_average_bits')}") print(f" Predicted average bits: {plan.get('predicted_average_bits'):.2f}") def _parse_dtype(s: str) -> torch.dtype: return { "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32, }[s] def _resolve_device(device: str) -> str: if device != "auto": return device if torch.cuda.is_available(): return "cuda" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps" return "cpu" def _add_runtime_mode_arg(parser: argparse.ArgumentParser, *, default: str = "cached") -> None: parser.add_argument( "--runtime-mode", default=default, choices=["packed", "cached", "native", "metal", "triton", "gemlite"], help=( "Inference runtime path for saved quantized layers. " "'cached': dequantize once at load, fastest on GPU/CPU (recommended). " "'native': replace layers with nn.Linear, ~1.0× vs FP16. " "'packed': re-dequantize every forward, minimal live VRAM. " "'gemlite': NVIDIA GPU only — keeps weights 2-bit packed, good batch throughput. " "'triton': NVIDIA GPU only — custom Triton kernel, slightly faster than gemlite at batch=1. " "'metal': Apple Silicon adaptive — Metal kernel with cached fallback." ), ) def _cleanup_device(device: str) -> None: gc.collect() if device == "cuda": torch.cuda.empty_cache() if device == "mps": torch.mps.empty_cache() def _uniform_role_weights() -> dict[str, float]: return { "attention_inputs": 1.0, "attention_output": 1.0, "mlp_inputs": 1.0, "mlp_output": 1.0, } def main(): from ternary_quant.ptq_families import FAMILY_PRESETS, get_default_family_candidates parser = argparse.ArgumentParser( prog="ternary-quant", description="Post-training ternary quantization for HuggingFace generative models", ) subparsers = parser.add_subparsers(dest="command", required=True) p_catalog = subparsers.add_parser( "catalog", help="List validated, probe-only, and special-handling model entries", ) p_catalog.add_argument( "--status", default="all", choices=[ "all", "validated", "component_validated", "research_validated", "probe_only", "special_handling", ], ) p_catalog.add_argument( "--family", default="all", choices=["all", "causal_lm", "seq2seq_lm", "image_text_to_text"], ) p_catalog.add_argument("--show-commands", action="store_true") p_catalog.add_argument("--json", action="store_true") p_catalog.add_argument("--output", default=None) p_catalog.set_defaults(func=cmd_catalog) p_doctor = subparsers.add_parser( "doctor", help="Check environment readiness and runtime recommendations", ) p_doctor.add_argument("--json", action="store_true") p_doctor.add_argument("--output", default=None) p_doctor.set_defaults(func=cmd_doctor) p_quant = subparsers.add_parser("quantize", help="Quantize a model to ternary") p_quant.add_argument("model", help="HuggingFace model ID or local path") p_quant.add_argument("--output", "-o", required=True, help="Output directory") p_quant.add_argument("--device", default="auto", help="Device (auto/cuda/cpu/mps)") p_quant.add_argument( "--dtype", default="float16", choices=["float16", "bfloat16", "float32"], ) p_quant.add_argument("--n-iter", type=int, default=10, help="ITF iterations") p_quant.add_argument( "--no-activation-aware", action="store_true", help="Disable activation-aware quantization", ) p_quant.add_argument("--block-size", type=int, default=0, help="Column block size") p_quant.add_argument("--n-samples", type=int, default=128) p_quant.add_argument("--seq-len", type=int, default=2048) p_quant.add_argument("--dataset", default="wikitext") p_quant.add_argument("--dataset-config", default="wikitext-2-raw-v1") p_quant.add_argument("--seed", type=int, default=42) p_quant.add_argument("--skip-modules", nargs="+", default=None) p_quant.add_argument("--eval", action="store_true") p_quant.add_argument("--eval-samples", type=int, default=40) _add_runtime_mode_arg(p_quant) p_quant.set_defaults(func=cmd_quantize) p_small = subparsers.add_parser( "quantize-small", help="Role-aware sparse asymmetric ternarization for small models", ) p_small.add_argument("model", help="HuggingFace model ID or local path") p_small.add_argument("--output", "-o", required=True, help="Output directory") p_small.add_argument("--device", default="auto") p_small.add_argument( "--dtype", default="float16", choices=["float16", "bfloat16", "float32"], ) p_small.add_argument("--n-samples", type=int, default=16) p_small.add_argument("--seq-len", type=int, default=256) p_small.add_argument("--dataset", default="wikitext") p_small.add_argument("--dataset-config", default="wikitext-2-raw-v1") p_small.add_argument("--seed", type=int, default=42) p_small.add_argument("--group-size", type=int, default=32) p_small.add_argument("--n-iter", type=int, default=10) p_small.add_argument( "--planner", default="practical", choices=["practical", "budgeted", "sensitivity_budget", "auto", "collapse_auto"], help=( "Planner variant: fixed role-aware recipe, role-aware bit-budgeted recipe, " "sensitivity-only matched-bit baseline, held-out PPL selection, or " "held-out prompt-bank collapse-aware selection." ), ) p_small.add_argument("--salient-fraction", type=float, default=0.01) p_small.add_argument("--min-salient-fraction", type=float, default=0.0025) p_small.add_argument("--max-salient-fraction", type=float, default=0.01) p_small.add_argument( "--low-rank-rank", type=int, default=0, help="Optional per-module low-rank residual rank for quantized modules.", ) p_small.add_argument( "--adaptive-low-rank", action="store_true", help="Allocate low-rank rank adaptively per module using residual spectra.", ) p_small.add_argument( "--low-rank-chunk-rank", type=int, default=16, help="Rank chunk used by adaptive low-rank allocation.", ) p_small.add_argument( "--low-rank-target-average-bits", type=float, default=None, help="Optional full-model bit target for adaptive low-rank allocation.", ) p_small.add_argument( "--low-rank-fit-mode", default="activation_regression", choices=["weight_svd", "activation_regression"], help="How to fit optional low-rank residuals for quantized modules.", ) p_small.add_argument( "--low-rank-ridge", type=float, default=1e-4, help="Ridge penalty used for activation-regressed low-rank fitting.", ) p_small.add_argument( "--low-rank-max-samples", type=int, default=4096, help="Maximum captured tokens per module when fitting low-rank residuals.", ) p_small.add_argument( "--calibration-tune-steps", type=int, default=0, help="Optional number of calibration-only LM fine-tune steps for low-rank residuals.", ) p_small.add_argument( "--calibration-tune-lr", type=float, default=5e-5, help="Learning rate for optional low-rank calibration tuning.", ) p_small.add_argument( "--calibration-tune-batch-size", type=int, default=2, help="Batch size for optional low-rank calibration tuning.", ) p_small.add_argument( "--behavior-tune-weight", type=float, default=0.0, help=( "Optional weight for prompt-bank teacher-sequence tuning during low-rank " "calibration. Requires --calibration-tune-steps > 0." ), ) p_small.add_argument( "--behavior-tune-prompt-count", type=int, default=4, help="Number of prompts to use when building behavior-tuning teacher sequences.", ) p_small.add_argument( "--behavior-tune-max-tokens", type=int, default=48, help="Max generated tokens per prompt when building behavior-tuning teacher sequences.", ) p_small.add_argument( "--distill-weight", type=float, default=0.0, help="Optional teacher hidden-state distillation weight for calibration tuning.", ) p_small.add_argument( "--behavior-hidden-weight", type=float, default=0.0, help="Optional teacher hidden-state distillation weight on prompt-bank sequences.", ) p_small.add_argument( "--logit-distill-weight", type=float, default=0.0, help="Optional top-k teacher logit distillation weight for calibration tuning.", ) p_small.add_argument( "--behavior-logit-weight", type=float, default=0.0, help="Optional top-k teacher logit distillation weight on prompt-bank sequences.", ) p_small.add_argument( "--entropy-distill-weight", type=float, default=0.0, help="Optional teacher entropy-floor regularization weight for calibration tuning.", ) p_small.add_argument( "--behavior-entropy-weight", type=float, default=0.0, help="Optional teacher entropy-floor regularization weight on prompt-bank sequences.", ) p_small.add_argument( "--logit-distill-topk", type=int, default=32, help="Teacher top-k to cache for logit distillation.", ) p_small.add_argument( "--logit-distill-temperature", type=float, default=2.0, help="Temperature for top-k teacher logit distillation.", ) p_small.add_argument( "--importance-threshold-scale", type=float, default=0.0, help=( "AWQ-inspired per-channel importance thresholding. When > 0 and activations " "are used, input channels with high activation magnitude get a lower ternary " "threshold (fewer zeros = more signal preserved). 0.0 = uniform (default). " "Typical range: 0.25–0.5." ), ) p_small.add_argument("--adaptive-salient", action="store_true") p_small.add_argument("--boundary-layers", type=int, default=2) p_small.add_argument("--calibration-batch-size", type=int, default=4) p_small.add_argument("--quantize-attention-output", action="store_true") p_small.add_argument("--quantize-mlp-output", action="store_true") p_small.add_argument( "--target-average-bits", type=float, default=None, help="Optional full-model bit budget for the role-aware allocator.", ) p_small.add_argument("--eval", action="store_true") p_small.add_argument("--eval-samples", type=int, default=8) p_small.add_argument("--selection-eval-samples", type=int, default=2) p_small.add_argument("--selection-prompt-count", type=int, default=4) p_small.add_argument("--selection-max-tokens", type=int, default=48) p_small.add_argument("--selection-collapse-weight", type=float, default=2.0) p_small.add_argument("--selection-worst-weight", type=float, default=1.0) p_small.add_argument("--prompt", default=None) p_small.add_argument("--max-tokens", type=int, default=80) _add_runtime_mode_arg(p_small) p_small.set_defaults(func=cmd_quantize_small) p_ptq = subparsers.add_parser( "quantize-ptq", help="Compare or apply broader ternary PTQ families for small models", ) p_ptq.add_argument("model", help="HuggingFace model ID or local path") p_ptq.add_argument("--output", "-o", required=True, help="Output directory") p_ptq.add_argument("--device", default="auto") p_ptq.add_argument( "--dtype", default="float16", choices=["float16", "bfloat16", "float32"], ) p_ptq.add_argument("--n-samples", type=int, default=16) p_ptq.add_argument("--seq-len", type=int, default=256) p_ptq.add_argument("--dataset", default="wikitext") p_ptq.add_argument("--dataset-config", default="wikitext-2-raw-v1") p_ptq.add_argument("--seed", type=int, default=42) p_ptq.add_argument("--group-size", type=int, default=32) p_ptq.add_argument("--n-iter", type=int, default=10) p_ptq.add_argument( "--family", default="controller", choices=["controller", *sorted(FAMILY_PRESETS)], help="PTQ family preset to apply, or controller to select across families.", ) p_ptq.add_argument( "--candidate-families", nargs="*", default=list(get_default_family_candidates()), help="Candidate families considered by the controller.", ) p_ptq.add_argument("--boundary-layers", type=int, default=2) p_ptq.add_argument("--calibration-batch-size", type=int, default=4) p_ptq.add_argument("--quantize-attention-output", action="store_true") p_ptq.add_argument("--quantize-mlp-output", action="store_true") p_ptq.add_argument( "--target-average-bits", type=float, default=None, help="Optional full-model bit target used by budget-aware family presets and selection.", ) p_ptq.add_argument( "--selection-metric", default="ppl", choices=["ppl", "collapse"], help="Controller selection objective.", ) p_ptq.add_argument("--selection-eval-samples", type=int, default=2) p_ptq.add_argument("--selection-prompt-count", type=int, default=4) p_ptq.add_argument("--selection-max-tokens", type=int, default=48) p_ptq.add_argument("--selection-collapse-weight", type=float, default=2.0) p_ptq.add_argument("--selection-worst-weight", type=float, default=1.0) p_ptq.add_argument("--selection-bits-weight", type=float, default=0.25) p_ptq.add_argument("--eval", action="store_true") p_ptq.add_argument("--eval-samples", type=int, default=8) p_ptq.add_argument("--prompt", default=None) p_ptq.add_argument("--max-tokens", type=int, default=80) _add_runtime_mode_arg(p_ptq) p_ptq.set_defaults(func=cmd_quantize_ptq) p_broad = subparsers.add_parser( "quantize-broad", help="Quantize selected components of a broader generative model family", ) p_broad.add_argument("model", help="HuggingFace model ID or local path") p_broad.add_argument("--output", "-o", required=True, help="Output directory") p_broad.add_argument("--device", default="auto") p_broad.add_argument( "--dtype", default="float32", choices=["float16", "bfloat16", "float32"], ) p_broad.add_argument( "--components", nargs="*", default=None, help="Component names to quantize. Defaults to the family-specific broad preset.", ) p_broad.add_argument( "--scheme", default="groupwise", choices=["groupwise", "tritplane2", "tritplane3"], help="Broad quantization scheme.", ) p_broad.add_argument("--group-size", type=int, default=32) p_broad.add_argument("--n-iter", type=int, default=10) p_broad.add_argument("--salient-fraction", type=float, default=0.0) p_broad.add_argument("--rescue-fraction", type=float, default=0.0) p_broad.add_argument("--allow-all-linear", action="store_true") p_broad.add_argument("--seq-len", type=int, default=160) p_broad.add_argument("--calibration-batch-size", type=int, default=2) p_broad.add_argument("--prompt", default=None) p_broad.add_argument("--max-tokens", type=int, default=64) p_broad.add_argument("--eval", action="store_true") p_broad.add_argument( "--push-to-hub", default=None, metavar="REPO_ID", help="Push the quantized model to HuggingFace Hub (e.g. username/my-model-ternary).", ) _add_runtime_mode_arg(p_broad) p_broad.set_defaults( func=cmd_quantize_broad, n_planes=2, ) p_inspect = subparsers.add_parser( "inspect-generative", help="Inspect the generative-family components of a model", ) p_inspect.add_argument("model", help="HuggingFace model ID or local path") p_inspect.add_argument("--device", default="auto") p_inspect.add_argument( "--dtype", default="float32", choices=["float16", "bfloat16", "float32"], ) p_inspect.add_argument( "--output", default=None, help="Optional JSON output path for the component inventory.", ) p_inspect.set_defaults(func=cmd_inspect_generative) p_check = subparsers.add_parser( "check", help="Quick compatibility check for a model (no weights downloaded)", ) p_check.add_argument("model", help="HuggingFace model ID") p_check.set_defaults(func=cmd_check) p_eval = subparsers.add_parser("eval", help="Evaluate saved model perplexity") p_eval.add_argument("model_dir") p_eval.add_argument("--device", default="auto") p_eval.add_argument("--seq-len", type=int, default=2048) p_eval.add_argument("--max-samples", type=int, default=None) _add_runtime_mode_arg(p_eval) p_eval.set_defaults(func=cmd_eval) p_cmp = subparsers.add_parser("compare", help="Compare original vs quantized") p_cmp.add_argument("original") p_cmp.add_argument("ternary") p_cmp.add_argument("--device", default="auto") p_cmp.add_argument("--seq-len", type=int, default=2048) p_cmp.add_argument("--max-samples", type=int, default=40) p_cmp.set_defaults(func=cmd_compare) p_gen = subparsers.add_parser("generate", help="Generate text with saved model") p_gen.add_argument("model_dir") p_gen.add_argument("--prompt", "-p", required=True) p_gen.add_argument("--max-tokens", type=int, default=256) p_gen.add_argument("--temperature", type=float, default=0.7) p_gen.add_argument("--device", default="auto") p_gen.add_argument( "--image-path", default=None, help="Optional image path for image-text-to-text models. If omitted, a demo image is used.", ) _add_runtime_mode_arg(p_gen) p_gen.set_defaults(func=cmd_generate) p_info = subparsers.add_parser("info", help="Show info about a saved model") p_info.add_argument("model_dir") p_info.set_defaults(func=cmd_info) args = parser.parse_args() args.func(args) if __name__ == "__main__": main()