"""
CLI for ternary quantization of HuggingFace models.
"""
from __future__ import annotations
import argparse
import gc
import json
import math
import sys
import time
from pathlib import Path
import torch
def cmd_catalog(args):
"""List the repo's known-good and known-probe model entries."""
from ternary_quant.toolkit import known_models_to_dict, list_known_models
entries = list_known_models(status=args.status, family=args.family)
if args.json:
payload = {
"status_filter": args.status,
"family_filter": args.family,
"models": known_models_to_dict(entries),
}
text = json.dumps(payload, indent=2)
if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(text + "\n")
print(f"Wrote catalog to {output_path}")
return
print(text)
return
if not entries:
print("No models matched the requested filters.")
return
grouped: dict[str, list] = {}
for entry in entries:
grouped.setdefault(entry.status, []).append(entry)
for status, status_entries in grouped.items():
print(status.replace("_", " ").title())
for entry in status_entries:
print(
f" {entry.model_id:<40} family={entry.family:<18} "
f"path={entry.path:<8} runtime={entry.recommended_runtime}"
)
print(f" note: {entry.note}")
print(f" artifact: {entry.artifact}")
if args.show_commands and entry.quickstart_command:
print(f" quickstart: {entry.quickstart_command}")
print("")
def cmd_doctor(args):
"""Report environment readiness and runtime recommendations."""
from ternary_quant.toolkit import build_doctor_report, doctor_report_to_text
report = build_doctor_report()
if args.json:
text = json.dumps(report, indent=2)
if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(text + "\n")
print(f"Wrote doctor report to {output_path}")
return
print(text)
return
print(doctor_report_to_text(report))
if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(report, indent=2) + "\n")
print(f"\nWrote doctor report to {output_path}")
def cmd_quantize(args):
"""Quantize a HuggingFace model with the legacy full-ternary pipeline."""
from ternary_quant.pipeline import QuantizationConfig, quantize_model
from ternary_quant.storage import save_quantized_model
config = QuantizationConfig(
n_iter=args.n_iter,
use_activation_aware=not args.no_activation_aware,
block_size=args.block_size,
n_samples=args.n_samples,
seq_len=args.seq_len,
dataset=args.dataset,
dataset_config=args.dataset_config,
seed=args.seed,
)
if args.skip_modules:
config.skip_modules = args.skip_modules
result = quantize_model(
model_name_or_path=args.model,
config=config,
device=args.device,
dtype=_parse_dtype(args.dtype),
)
save_quantized_model(
ternary_params=result.ternary_params,
model_name=result.model_name,
model_config=result.model_config,
quant_config=result.config,
output_dir=args.output,
stats=result.stats,
)
if args.eval:
print("\nRunning perplexity evaluation...")
from ternary_quant.eval import evaluate_perplexity
from ternary_quant.inference import load_ternary_model
model, tokenizer = load_ternary_model(
args.output,
device=args.device,
runtime_mode=getattr(args, "runtime_mode", "packed"),
)
ppl = evaluate_perplexity(model, tokenizer, max_samples=args.eval_samples)
print(f"Ternary model perplexity: {ppl:.2f}")
def cmd_quantize_small(args):
"""Quantize a small model with the role-aware sparse asymmetric ternary path."""
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from ternary_quant.data import get_calibration_data
from ternary_quant.eval import (
evaluate_perplexity,
evaluate_prompt_bank,
get_default_prompt_bank,
)
from ternary_quant.inference import generate_text, load_ternary_model
from ternary_quant.quantizer_small import (
SmallModelQuantizationConfig,
build_sensitivity_only_plan,
build_role_aware_plan,
config_to_dict,
plan_to_dict,
quantize_small_model_inplace,
summarize_small_model_quantization,
tune_low_rank_residuals_inplace,
)
from ternary_quant.storage import save_quantized_model
device = _resolve_device(args.device)
dtype = _parse_dtype(args.dtype)
tokenizer = AutoTokenizer.from_pretrained(args.model)
model_config = AutoConfig.from_pretrained(args.model)
calibration_data = get_calibration_data(
args.model,
tokenizer=tokenizer,
n_samples=args.n_samples,
seq_len=args.seq_len,
dataset_name=args.dataset,
dataset_config=args.dataset_config,
seed=args.seed,
).to(device)
def load_base_model():
model = AutoModelForCausalLM.from_pretrained(
args.model,
torch_dtype=dtype,
low_cpu_mem_usage=True,
).to(device)
model.eval()
return model
def build_behavior_sequences(prompt_bank: dict) -> list[torch.Tensor]:
sequences = []
for sample in prompt_bank.get("samples", []):
prompt_ids = tokenizer(
sample["prompt"],
return_tensors="pt",
truncation=False,
)["input_ids"][0]
generated_ids = torch.tensor(
sample.get("generated_token_ids", []),
dtype=torch.long,
)
full_sequence = torch.cat([prompt_ids, generated_ids], dim=0).unsqueeze(0)
sequences.append(full_sequence)
return sequences
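    # The two cache builders below snapshot the FP16 teacher (hidden states /
    # top-k logits) and then free it, so low-rank tuning of the quantized
    # student never needs teacher and student resident at the same time.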
@torch.no_grad()
def build_hidden_cache(sequences):
if sequences is None:
return None
if isinstance(sequences, torch.Tensor):
model = load_base_model()
outputs = []
for start in range(0, sequences.shape[0], args.calibration_tune_batch_size):
batch = sequences[start : start + args.calibration_tune_batch_size]
hidden = model(batch, output_hidden_states=True).hidden_states[-1]
outputs.append(hidden.detach().cpu().to(torch.float16))
del model
_cleanup_device(device)
return torch.cat(outputs, dim=0)
outputs = []
model = load_base_model()
for seq in sequences:
hidden = model(seq.to(device), output_hidden_states=True).hidden_states[-1]
outputs.append(hidden.detach().cpu().to(torch.float16))
del model
_cleanup_device(device)
return outputs
@torch.no_grad()
def build_topk_logit_cache(sequences, top_k: int):
if sequences is None or top_k <= 0:
return None
top_k = max(1, int(top_k))
if isinstance(sequences, torch.Tensor):
model = load_base_model()
indices_out = []
logits_out = []
entropy_out = []
for start in range(0, sequences.shape[0], args.calibration_tune_batch_size):
batch = sequences[start : start + args.calibration_tune_batch_size]
logits = model(batch).logits[:, :-1, :].float()
values, indices = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1)
log_probs = torch.log_softmax(logits, dim=-1)
probs = log_probs.exp()
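                # Shannon entropy of each next-token distribution, normalized
                # to [0, 1] by log(vocab_size); max(..., 2) guards the
                # degenerate one-token vocabulary where log(1) == 0.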
entropy = -(probs * log_probs).sum(dim=-1) / math.log(max(logits.shape[-1], 2))
indices_out.append(indices.detach().cpu().to(torch.int32))
logits_out.append(values.detach().cpu().to(torch.float16))
entropy_out.append(entropy.detach().cpu().to(torch.float16))
del model
_cleanup_device(device)
return {
"indices": torch.cat(indices_out, dim=0),
"logits": torch.cat(logits_out, dim=0),
"entropy": torch.cat(entropy_out, dim=0),
}
outputs = []
model = load_base_model()
for seq in sequences:
logits = model(seq.to(device)).logits[:, :-1, :].float()
values, indices = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1)
log_probs = torch.log_softmax(logits, dim=-1)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=-1) / math.log(max(logits.shape[-1], 2))
outputs.append(
{
"indices": indices.detach().cpu().to(torch.int32),
"logits": values.detach().cpu().to(torch.float16),
"entropy": entropy.detach().cpu().to(torch.float16),
}
)
del model
_cleanup_device(device)
return outputs
def make_config(
planner: str,
) -> SmallModelQuantizationConfig:
target_average_bits = args.target_average_bits
adaptive_salient = args.adaptive_salient
role_cost_weights = None
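        # Planner presets: "budgeted" and "sensitivity_budget" fall back to a
        # 10.5-bit full-model budget with adaptive salient fractions when no
        # explicit --target-average-bits is given; "practical" ignores budgets.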
if planner == "budgeted" and target_average_bits is None:
target_average_bits = 10.5
adaptive_salient = True
elif planner == "sensitivity_budget":
if target_average_bits is None:
target_average_bits = 10.5
adaptive_salient = True
role_cost_weights = _uniform_role_weights()
elif planner == "practical":
target_average_bits = None
config = SmallModelQuantizationConfig(
group_size=args.group_size,
n_iter=args.n_iter,
salient_fraction=args.salient_fraction,
min_salient_fraction=args.min_salient_fraction,
max_salient_fraction=args.max_salient_fraction,
adaptive_salient=adaptive_salient,
low_rank_rank=args.low_rank_rank,
adaptive_low_rank=args.adaptive_low_rank,
low_rank_chunk_rank=args.low_rank_chunk_rank,
low_rank_target_average_bits=args.low_rank_target_average_bits,
low_rank_fit_mode=args.low_rank_fit_mode,
low_rank_ridge=args.low_rank_ridge,
low_rank_max_samples=args.low_rank_max_samples,
n_boundary_layers=args.boundary_layers,
calibration_batch_size=args.calibration_batch_size,
quantize_attention_output=args.quantize_attention_output,
quantize_mlp_output=args.quantize_mlp_output,
target_average_bits=target_average_bits,
importance_threshold_scale=getattr(args, "importance_threshold_scale", 0.0),
role_cost_weights=role_cost_weights
if role_cost_weights is not None
else SmallModelQuantizationConfig().role_cost_weights,
)
config.base_config.n_samples = args.n_samples
config.base_config.seq_len = args.seq_len
config.base_config.dataset = args.dataset
config.base_config.dataset_config = args.dataset_config
config.base_config.seed = args.seed
return config
def build_plan(model, config: SmallModelQuantizationConfig, planner: str):
if planner == "sensitivity_budget":
return build_sensitivity_only_plan(model, calibration_data, config)
return build_role_aware_plan(model, calibration_data, config)
behavior_sequences = None
calibration_hidden_states = None
behavior_hidden_states = None
calibration_logit_targets = None
behavior_logit_targets = None
if args.calibration_tune_steps > 0 and args.behavior_tune_weight > 0.0:
behavior_prompt_bank = get_default_prompt_bank(
primary_prompt=args.prompt,
max_prompts=args.behavior_tune_prompt_count,
)
print("Building prompt-bank behavior tuning data...")
behavior_model = load_base_model()
behavior_reference = evaluate_prompt_bank(
behavior_model,
tokenizer,
prompts=behavior_prompt_bank,
max_new_tokens=args.behavior_tune_max_tokens,
)
behavior_sequences = build_behavior_sequences(behavior_reference)
del behavior_model
_cleanup_device(device)
if args.calibration_tune_steps > 0 and (
args.distill_weight > 0.0 or args.behavior_hidden_weight > 0.0
):
print("Building teacher hidden-state caches...")
calibration_hidden_states = build_hidden_cache(calibration_data)
if behavior_sequences is not None:
behavior_hidden_states = build_hidden_cache(behavior_sequences)
if args.calibration_tune_steps > 0 and (
args.logit_distill_weight > 0.0
or args.behavior_logit_weight > 0.0
or args.entropy_distill_weight > 0.0
or args.behavior_entropy_weight > 0.0
):
print("Building teacher top-k logit caches...")
calibration_logit_targets = build_topk_logit_cache(
calibration_data,
args.logit_distill_topk,
)
if behavior_sequences is not None:
behavior_logit_targets = build_topk_logit_cache(
behavior_sequences,
args.logit_distill_topk,
)
selection = None
auto_tuned = False
if args.planner in {"auto", "collapse_auto"}:
candidate_planners = ["practical", "sensitivity_budget"]
best = None
total_quant_time = 0.0
selection_metric = (
"collapse_aware" if args.planner == "collapse_auto" else "ppl"
)
selection = {
"selection_metric": selection_metric,
"candidate_scores": {},
}
selection_prompt_bank = get_default_prompt_bank(
primary_prompt=args.prompt,
max_prompts=args.selection_prompt_count,
)
reference_behavior = None
if selection_metric == "collapse_aware":
print("Measuring FP16 prompt-bank behavior...")
reference_model = load_base_model()
reference_behavior = evaluate_prompt_bank(
reference_model,
tokenizer,
prompts=selection_prompt_bank,
max_new_tokens=args.selection_max_tokens,
)
selection["reference_behavior"] = {
"avg_collapse_score": reference_behavior["avg_collapse_score"],
"worst_collapse_score": reference_behavior["worst_collapse_score"],
"avg_distinct_2": reference_behavior["avg_distinct_2"],
"avg_repeated_3gram_ratio": reference_behavior[
"avg_repeated_3gram_ratio"
],
}
del reference_model
_cleanup_device(device)
for planner in candidate_planners:
print(f"Evaluating planner candidate: {planner}")
model = load_base_model()
config = make_config(planner)
t0 = time.time()
plan = build_plan(model, config, planner)
result = quantize_small_model_inplace(
model,
calibration_data=calibration_data,
config=config,
plan=plan,
)
total_quant_time += time.time() - t0
summary = summarize_small_model_quantization(result, model)
tune_stats = None
if args.calibration_tune_steps > 0:
tune_stats = tune_low_rank_residuals_inplace(
model,
result,
calibration_data=calibration_data,
n_steps=args.calibration_tune_steps,
lr=args.calibration_tune_lr,
batch_size=args.calibration_tune_batch_size,
max_seq_len=args.seq_len,
behavior_sequences=behavior_sequences,
behavior_weight=args.behavior_tune_weight,
calibration_hidden_states=calibration_hidden_states,
behavior_hidden_states=behavior_hidden_states,
calibration_logit_targets=calibration_logit_targets,
behavior_logit_targets=behavior_logit_targets,
distill_weight=args.distill_weight,
behavior_hidden_weight=args.behavior_hidden_weight,
logit_distill_weight=args.logit_distill_weight,
behavior_logit_weight=args.behavior_logit_weight,
entropy_distill_weight=args.entropy_distill_weight,
behavior_entropy_weight=args.behavior_entropy_weight,
logit_distill_temperature=args.logit_distill_temperature,
seed=args.seed,
)
summary = summarize_small_model_quantization(result, model)
selection_ppl = evaluate_perplexity(
model,
tokenizer,
seq_len=args.seq_len,
max_samples=args.selection_eval_samples,
)
selection_score = float(selection_ppl)
selection_behavior = None
if selection_metric == "collapse_aware":
selection_behavior = evaluate_prompt_bank(
model,
tokenizer,
prompts=selection_prompt_bank,
max_new_tokens=args.selection_max_tokens,
)
reference_avg = (
0.0 if reference_behavior is None else reference_behavior["avg_collapse_score"]
)
reference_worst = (
reference_avg
if reference_behavior is None
else reference_behavior["worst_collapse_score"]
)
collapse_excess = max(
selection_behavior["avg_collapse_score"] - reference_avg,
0.0,
)
worst_excess = max(
selection_behavior["worst_collapse_score"] - reference_worst,
0.0,
)
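                # Collapse-aware selection: inflate the held-out PPL by how far
                # the candidate's prompt-bank collapse exceeds the FP16
                # reference.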
selection_score = selection_ppl * (
1.0
+ args.selection_collapse_weight * collapse_excess
+ args.selection_worst_weight * worst_excess
)
selection["candidate_scores"][planner] = {
"selection_ppl": selection_ppl,
"selection_score": selection_score,
"predicted_average_bits": plan.predicted_average_bits,
"full_model_effective_bits": summary["full_model_effective_bits"],
}
if selection_behavior is not None:
selection["candidate_scores"][planner]["selection_behavior"] = {
"avg_collapse_score": selection_behavior["avg_collapse_score"],
"worst_collapse_score": selection_behavior["worst_collapse_score"],
"avg_distinct_2": selection_behavior["avg_distinct_2"],
"avg_repeated_3gram_ratio": selection_behavior[
"avg_repeated_3gram_ratio"
],
}
if tune_stats is not None:
selection["candidate_scores"][planner]["calibration_tune"] = tune_stats
if best is None or selection_score < best["selection_score"]:
if best is not None:
del best["model"]
_cleanup_device(device)
best = {
"model": model,
"config": config,
"plan": plan,
"result": result,
"summary": summary,
"selection_ppl": selection_ppl,
"selection_score": selection_score,
"selection_behavior": selection_behavior,
"planner": planner,
}
else:
del model
_cleanup_device(device)
if best is None:
raise RuntimeError("Auto planner failed to select a candidate.")
model = best["model"]
config = best["config"]
plan = best["plan"]
result = best["result"]
summary = best["summary"]
quant_time = total_quant_time
selected_name = "RAST-collapse-auto" if args.planner == "collapse_auto" else "RAST-auto"
result.plan.method_name = selected_name
summary["method_name"] = selected_name
auto_tuned = args.calibration_tune_steps > 0
selection.update(
{
"selected_planner": best["planner"],
"selection_ppl": best["selection_ppl"],
"selection_score": best["selection_score"],
}
)
if best["selection_behavior"] is not None:
selection["selected_behavior"] = {
"avg_collapse_score": best["selection_behavior"]["avg_collapse_score"],
"worst_collapse_score": best["selection_behavior"]["worst_collapse_score"],
"avg_distinct_2": best["selection_behavior"]["avg_distinct_2"],
"avg_repeated_3gram_ratio": best["selection_behavior"][
"avg_repeated_3gram_ratio"
],
}
print(
f"Selected planner: {best['planner']} | "
f"held-out score {best['selection_score']:.2f} | "
f"PPL {best['selection_ppl']:.2f} | "
f"full-model bits {summary['full_model_effective_bits']:.2f}"
)
else:
model = load_base_model()
config = make_config(args.planner)
print("Building role-aware plan...")
t0 = time.time()
plan = build_plan(model, config, args.planner)
print(
f"Plan ready in {time.time() - t0:.1f}s | "
f"Predicted average bits: {plan.predicted_average_bits:.2f}"
)
print("Applying role-aware quantization...")
t1 = time.time()
result = quantize_small_model_inplace(
model,
calibration_data=calibration_data,
config=config,
plan=plan,
)
quant_time = time.time() - t1
summary = summarize_small_model_quantization(result, model)
if args.calibration_tune_steps > 0 and not auto_tuned:
print("Calibrating low-rank residuals...")
t2 = time.time()
tune_stats = tune_low_rank_residuals_inplace(
model,
result,
calibration_data=calibration_data,
n_steps=args.calibration_tune_steps,
lr=args.calibration_tune_lr,
batch_size=args.calibration_tune_batch_size,
max_seq_len=args.seq_len,
behavior_sequences=behavior_sequences,
behavior_weight=args.behavior_tune_weight,
calibration_hidden_states=calibration_hidden_states,
behavior_hidden_states=behavior_hidden_states,
calibration_logit_targets=calibration_logit_targets,
behavior_logit_targets=behavior_logit_targets,
distill_weight=args.distill_weight,
behavior_hidden_weight=args.behavior_hidden_weight,
logit_distill_weight=args.logit_distill_weight,
behavior_logit_weight=args.behavior_logit_weight,
entropy_distill_weight=args.entropy_distill_weight,
behavior_entropy_weight=args.behavior_entropy_weight,
logit_distill_temperature=args.logit_distill_temperature,
seed=args.seed,
)
quant_time += time.time() - t2
summary = summarize_small_model_quantization(result, model)
print(
f"Calibration tune complete | "
f"final loss {tune_stats.get('final_loss', float('nan')):.4f} | "
f"wrapped modules {tune_stats['n_wrapped_modules']}"
)
save_quantized_model(
ternary_params=result.quantized_params,
model_name=args.model,
model_config=model_config,
quant_config=config,
output_dir=args.output,
stats=result.stats,
summary=summary,
plan=result.plan,
method_name=result.plan.method_name,
)
report = {
"method": result.plan.method_name,
"model": args.model,
"quant_time_sec": quant_time,
"summary": summary,
"plan": plan_to_dict(result.plan),
"config": config_to_dict(config),
}
if selection is not None:
report["selection"] = selection
report_path = Path(args.output) / "role_aware_report.json"
with open(report_path, "w") as f:
json.dump(report, f, indent=2)
print(f"Wrote role-aware report to {report_path}")
if args.eval:
print("\nRunning validation on saved model...")
quantized_model, tokenizer = load_ternary_model(
args.output,
device=device,
runtime_mode=getattr(args, "runtime_mode", "packed"),
)
ppl = evaluate_perplexity(
quantized_model,
tokenizer,
seq_len=args.seq_len,
max_samples=args.eval_samples,
)
print(f"Role-aware quantized perplexity: {ppl:.2f}")
if args.prompt:
text = generate_text(
quantized_model,
tokenizer,
prompt=args.prompt,
max_new_tokens=args.max_tokens,
do_sample=False,
)
print(f"Prompt: {args.prompt}")
print(f"Output: {text}")
def cmd_quantize_ptq(args):
"""Quantize a small model via a ternary PTQ family or controller."""
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from ternary_quant.data import get_calibration_data
from ternary_quant.eval import (
evaluate_perplexity,
evaluate_prompt_bank,
get_default_prompt_bank,
)
from ternary_quant.inference import generate_text, load_ternary_model
from ternary_quant.ptq_families import (
build_family_config,
family_config_to_dict,
get_default_family_candidates,
quantize_family_inplace,
summarize_family_quantization,
)
from ternary_quant.storage import save_quantized_model
device = _resolve_device(args.device)
dtype = _parse_dtype(args.dtype)
tokenizer = AutoTokenizer.from_pretrained(args.model)
model_config = AutoConfig.from_pretrained(args.model)
calibration_data = get_calibration_data(
args.model,
tokenizer=tokenizer,
n_samples=args.n_samples,
seq_len=args.seq_len,
dataset_name=args.dataset,
dataset_config=args.dataset_config,
seed=args.seed,
).to(device)
def load_base_model():
model = AutoModelForCausalLM.from_pretrained(
args.model,
torch_dtype=dtype,
low_cpu_mem_usage=True,
).to(device)
model.eval()
return model
def build_config(family_name: str):
return build_family_config(
family_name,
target_average_bits=args.target_average_bits,
group_size=args.group_size,
n_iter=args.n_iter,
n_boundary_layers=args.boundary_layers,
calibration_batch_size=args.calibration_batch_size,
quantize_attention_output=args.quantize_attention_output,
quantize_mlp_output=args.quantize_mlp_output,
)
selection = None
if args.family == "controller":
candidate_names = (
args.candidate_families
if args.candidate_families
else get_default_family_candidates()
)
selection_metric = args.selection_metric
selection_prompt_bank = get_default_prompt_bank(
primary_prompt=args.prompt,
max_prompts=args.selection_prompt_count,
)
selection = {
"selection_metric": selection_metric,
"candidate_scores": {},
}
reference_behavior = None
if selection_metric == "collapse":
print("Measuring FP16 prompt-bank behavior for controller selection...")
reference_model = load_base_model()
reference_behavior = evaluate_prompt_bank(
reference_model,
tokenizer,
prompts=selection_prompt_bank,
max_new_tokens=args.selection_max_tokens,
)
selection["reference_behavior"] = {
"avg_collapse_score": reference_behavior["avg_collapse_score"],
"worst_collapse_score": reference_behavior["worst_collapse_score"],
"avg_distinct_2": reference_behavior["avg_distinct_2"],
"avg_repeated_3gram_ratio": reference_behavior[
"avg_repeated_3gram_ratio"
],
}
del reference_model
_cleanup_device(device)
best = None
total_quant_time = 0.0
for family_name in candidate_names:
print(f"Evaluating ternary PTQ family candidate: {family_name}")
family_config = build_config(family_name)
model = load_base_model()
t0 = time.time()
result = quantize_family_inplace(
model,
calibration_data=calibration_data,
config=family_config,
)
total_quant_time += time.time() - t0
summary = summarize_family_quantization(result)
selection_ppl = evaluate_perplexity(
model,
tokenizer,
seq_len=args.seq_len,
max_samples=args.selection_eval_samples,
)
selection_score = float(selection_ppl)
selection_behavior = None
if selection_metric == "collapse":
selection_behavior = evaluate_prompt_bank(
model,
tokenizer,
prompts=selection_prompt_bank,
max_new_tokens=args.selection_max_tokens,
)
reference_avg = (
0.0
if reference_behavior is None
else reference_behavior["avg_collapse_score"]
)
reference_worst = (
reference_avg
if reference_behavior is None
else reference_behavior["worst_collapse_score"]
)
collapse_excess = max(
selection_behavior["avg_collapse_score"] - reference_avg,
0.0,
)
worst_excess = max(
selection_behavior["worst_collapse_score"] - reference_worst,
0.0,
)
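                # Same collapse-aware score as quantize-small's auto planner:
                # held-out PPL inflated by collapse excess over the FP16
                # reference; a bit-budget overshoot penalty follows below.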
selection_score = selection_ppl * (
1.0
+ args.selection_collapse_weight * collapse_excess
+ args.selection_worst_weight * worst_excess
)
if args.target_average_bits is not None:
bits_excess = max(
summary["full_model_effective_bits"] - args.target_average_bits,
0.0,
)
selection_score *= (
1.0
+ args.selection_bits_weight
* bits_excess
/ max(args.target_average_bits, 1e-6)
)
selection["candidate_scores"][family_name] = {
"label": family_config.label,
"selection_ppl": selection_ppl,
"selection_score": selection_score,
"full_model_effective_bits": summary["full_model_effective_bits"],
"quantized_fraction": summary["quantized_fraction"],
}
if selection_behavior is not None:
selection["candidate_scores"][family_name]["selection_behavior"] = {
"avg_collapse_score": selection_behavior["avg_collapse_score"],
"worst_collapse_score": selection_behavior["worst_collapse_score"],
"avg_distinct_2": selection_behavior["avg_distinct_2"],
"avg_repeated_3gram_ratio": selection_behavior[
"avg_repeated_3gram_ratio"
],
}
if best is None or selection_score < best["selection_score"]:
if best is not None:
del best["model"]
_cleanup_device(device)
best = {
"model": model,
"family_config": family_config,
"result": result,
"summary": summary,
"selection_ppl": selection_ppl,
"selection_score": selection_score,
"selection_behavior": selection_behavior,
"family_name": family_name,
}
else:
del model
_cleanup_device(device)
if best is None:
raise RuntimeError("Controller failed to select a ternary PTQ family.")
model = best["model"]
family_config = best["family_config"]
result = best["result"]
summary = best["summary"]
quant_time = total_quant_time
result.plan.method_name = "Ternary-PTQ-auto"
summary["method_name"] = "Ternary-PTQ-auto"
summary["selected_family_preset"] = best["family_name"]
selection.update(
{
"selected_family_preset": best["family_name"],
"selected_family_label": family_config.label,
"selection_ppl": best["selection_ppl"],
"selection_score": best["selection_score"],
}
)
if best["selection_behavior"] is not None:
selection["selected_behavior"] = {
"avg_collapse_score": best["selection_behavior"]["avg_collapse_score"],
"worst_collapse_score": best["selection_behavior"]["worst_collapse_score"],
"avg_distinct_2": best["selection_behavior"]["avg_distinct_2"],
"avg_repeated_3gram_ratio": best["selection_behavior"][
"avg_repeated_3gram_ratio"
],
}
print(
f"Selected family: {best['family_name']} | "
f"held-out score {best['selection_score']:.2f} | "
f"PPL {best['selection_ppl']:.2f} | "
f"full-model bits {summary['full_model_effective_bits']:.2f}"
)
else:
family_config = build_config(args.family)
print(f"Applying ternary PTQ family: {family_config.label}")
model = load_base_model()
t0 = time.time()
result = quantize_family_inplace(
model,
calibration_data=calibration_data,
config=family_config,
)
quant_time = time.time() - t0
summary = summarize_family_quantization(result)
save_quantized_model(
ternary_params=result.quantized_params,
model_name=args.model,
model_config=model_config,
quant_config=family_config,
output_dir=args.output,
stats=result.stats,
summary=summary,
plan=result.plan,
method_name=result.plan.method_name,
)
report = {
"method": result.plan.method_name,
"model": args.model,
"quant_time_sec": quant_time,
"summary": summary,
"family_config": family_config_to_dict(family_config),
}
if selection is not None:
report["selection"] = selection
report_path = Path(args.output) / "ternary_ptq_report.json"
with open(report_path, "w") as f:
json.dump(report, f, indent=2)
print(f"Wrote ternary PTQ report to {report_path}")
if args.eval:
print("\nRunning validation on saved model...")
quantized_model, tokenizer = load_ternary_model(
args.output,
device=device,
runtime_mode=getattr(args, "runtime_mode", "packed"),
)
ppl = evaluate_perplexity(
quantized_model,
tokenizer,
seq_len=args.seq_len,
max_samples=args.eval_samples,
)
print(f"Ternary PTQ perplexity: {ppl:.2f}")
if args.prompt:
text = generate_text(
quantized_model,
tokenizer,
prompt=args.prompt,
max_new_tokens=args.max_tokens,
do_sample=False,
)
print(f"Prompt: {args.prompt}")
print(f"Output: {text}")
def cmd_eval(args):
"""Evaluate perplexity of a saved quantized model."""
from ternary_quant.eval import evaluate_perplexity
from ternary_quant.inference import load_ternary_model
model, tokenizer = load_ternary_model(
args.model_dir,
device=args.device,
runtime_mode=getattr(args, "runtime_mode", "packed"),
)
ppl = evaluate_perplexity(
model,
tokenizer,
seq_len=args.seq_len,
max_samples=args.max_samples,
)
print(f"\nPerplexity: {ppl:.2f}")
def cmd_compare(args):
"""Compare original and saved quantized model."""
from ternary_quant.eval import compare_models
compare_models(
original_model_name=args.original,
ternary_model_dir=args.ternary,
device=args.device,
seq_len=args.seq_len,
max_samples=args.max_samples,
)
def cmd_generate(args):
"""Generate text with a saved quantized model."""
import numpy as np
from ternary_quant.generative_adapters import inspect_generative_model
from ternary_quant.inference import (
generate_generative_output,
generate_text,
load_ternary_model,
)
model, asset = load_ternary_model(
args.model_dir,
device=args.device,
runtime_mode=getattr(args, "runtime_mode", "packed"),
)
model_info = inspect_generative_model(
model,
model_name=str(getattr(model, "name_or_path", "loaded-model")),
)
image = None
if args.image_path:
try:
from PIL import Image
except Exception as exc:
raise RuntimeError(
"Reading --image-path requires Pillow. Install pillow or omit the image."
) from exc
image = np.array(Image.open(args.image_path).convert("RGB"))
if model_info.model_family == "image_text_to_text":
output = generate_generative_output(
model,
asset,
prompt=args.prompt,
max_new_tokens=args.max_tokens,
image=image,
)
else:
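        # Text-only path; temperature 0 falls back to greedy decoding because
        # do_sample is derived from the temperature below.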
output = generate_text(
model,
asset,
prompt=args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
do_sample=args.temperature > 0,
)
print(f"\nPrompt: {args.prompt}")
print(f"Output: {output}")
def cmd_inspect_generative(args):
"""Inspect a generative model and list its quantizable components."""
from ternary_quant.generative_adapters import (
generative_model_info_to_dict,
load_generative_model,
)
device = _resolve_device(args.device)
dtype = _parse_dtype(args.dtype)
model, _, model_info = load_generative_model(
args.model,
device=device,
dtype=dtype,
)
print(f"Model: {model_info.model_name}")
print(f"Family: {model_info.model_family}")
print(f"Model type: {model_info.model_type}")
print(f"Architectures: {', '.join(model_info.architectures) or 'unknown'}")
print(f"Default broad components: {', '.join(model_info.default_quantization_components)}")
print("\nComponents:")
for component in model_info.components:
sample = ", ".join(component.sample_linear_like_names[:4]) or "(no linear modules)"
print(
f" {component.name:<22} path={component.path:<32} "
f"linears={component.linear_like_count:<4} params={component.parameter_count:<12}"
)
print(f" sample: {sample}")
if args.output:
payload = generative_model_info_to_dict(model_info)
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(payload, f, indent=2)
print(f"\nWrote component inventory to {output_path}")
del model
_cleanup_device(device)
def cmd_quantize_broad(args):
"""Quantize selected components of a broad generative model."""
from ternary_quant.generative_adapters import (
BroadQuantizationConfig,
broad_quant_config_to_dict,
build_calibration_batches,
evaluate_broad_prompt_bank,
generative_model_info_to_dict,
load_generative_model,
make_demo_image,
quantize_components_inplace,
)
from ternary_quant.inference import load_ternary_model
from ternary_quant.storage import save_quantized_model
device = _resolve_device(args.device)
dtype = _parse_dtype(args.dtype)
model, asset, model_info = load_generative_model(
args.model,
device=device,
dtype=dtype,
)
components = (
args.components if args.components else model_info.default_quantization_components
)
prompts = [args.prompt] if args.prompt else None
broad_config = BroadQuantizationConfig(
components=list(components),
scheme=args.scheme,
group_size=args.group_size,
n_iter=args.n_iter,
salient_fraction=args.salient_fraction,
rescue_fraction=args.rescue_fraction,
n_planes=3 if args.scheme == "tritplane3" else 2,
allow_all_linear=args.allow_all_linear,
max_length=args.seq_len,
calibration_batch_size=args.calibration_batch_size,
calibration_prompts=list(prompts) if prompts is not None else None,
)
demo_image = make_demo_image()
calibration_batches = build_calibration_batches(
asset,
model_info,
max_length=args.seq_len,
batch_size=args.calibration_batch_size,
prompts=prompts,
demo_images=[demo_image],
)
result = quantize_components_inplace(
model,
model_info=model_info,
calibration_batches=calibration_batches,
config=broad_config,
)
save_quantized_model(
ternary_params=result.quantized_params,
model_name=args.model,
model_config=model.config,
quant_config=broad_config,
output_dir=args.output,
stats=result.stats,
summary=result.summary,
method_name=result.summary["method_name"],
model_family=model_info.model_family,
)
report = {
"method": result.summary["method_name"],
"model": args.model,
"model_info": generative_model_info_to_dict(model_info),
"config": broad_quant_config_to_dict(broad_config),
"summary": result.summary,
}
if args.eval:
quantized_model, quantized_asset = load_ternary_model(
args.output,
device=device,
runtime_mode=getattr(args, "runtime_mode", "packed"),
)
        eval_prompts = prompts or (
            ["Describe the image in one short sentence."]
            if model_info.model_family == "image_text_to_text"
            else [
                "The capital of France is",
                "Answer briefly: What is 2 + 2?",
            ]
        )
        validation = evaluate_broad_prompt_bank(
            quantized_model,
            quantized_asset,
            model_info,
            prompts=eval_prompts,
            max_new_tokens=args.max_tokens,
            demo_image=demo_image,
        )
report["validation"] = validation
print("\nValidation:")
print(f" Avg collapse: {validation['avg_collapse_score']:.3f}")
print(f" Primary output: {validation['primary_text']}")
del quantized_model
_cleanup_device(device)
report_path = Path(args.output) / "broad_generative_report.json"
with open(report_path, "w") as f:
json.dump(report, f, indent=2)
print(f"Wrote broad generative report to {report_path}")
# Print a compact summary
s = result.summary
print(f"\nQuantization summary:")
print(f" Layers quantized: {s['quantized_modules']}")
print(f" Full-model effective bits: {s['full_model_effective_bits']:.2f}")
print(f" Compression ratio: {s['compression_ratio']:.2f}×")
print(f" Avg reconstruction error: {s['avg_relative_error']:.4f}")
if getattr(args, "push_to_hub", None):
_push_to_hub(args.output, args.push_to_hub, args.model, result.summary, broad_config)
def _push_to_hub(output_dir: str, hub_repo: str, source_model: str, summary: dict, config) -> None:
"""Push a quantized model directory to HuggingFace Hub."""
try:
from huggingface_hub import HfApi
except ImportError:
print("huggingface_hub not installed. Run: pip install huggingface_hub")
return
output_path = Path(output_dir)
# Write a model card
model_card = f"""---
tags:
- ternary-quant
- quantization
- ternary
base_model: {source_model}
---
# {hub_repo}
Ternary-quantized version of [{source_model}](https://huggingface.co/{source_model})
produced with [ternary-quant](https://github.com/Asad-Ismail/ternary-quant).
## Quantization details
- **Scheme**: {getattr(config, 'scheme', 'unknown')}
- **Components**: {', '.join(getattr(config, 'components', []))}
- **Full-model effective bits**: {summary.get('full_model_effective_bits', float('nan')):.2f}
- **Compression ratio**: {summary.get('compression_ratio', float('nan')):.2f}×
- **Avg reconstruction error**: {summary.get('avg_relative_error', float('nan')):.4f}
## Usage
```python
from ternary_quant.inference import load_ternary_model
model, tokenizer = load_ternary_model("{hub_repo}", runtime_mode="cached")
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
Or via CLI:
```bash
pip install ternary-quant
ternary-quant generate {hub_repo} --prompt "Hello" --runtime-mode cached
```
"""
card_path = output_path / "README.md"
card_path.write_text(model_card)
    api = HfApi()
    # Ensure the target repo exists before uploading (no-op if it already does).
    api.create_repo(repo_id=hub_repo, repo_type="model", exist_ok=True)
    print(f"Pushing to {hub_repo}...")
api.upload_folder(
folder_path=str(output_path),
repo_id=hub_repo,
repo_type="model",
)
print(f"Pushed to https://huggingface.co/{hub_repo}")
def cmd_check(args):
"""Quick compatibility check using only the model config (no weights downloaded)."""
from ternary_quant.generative_adapters import (
VLM_MODEL_TYPES,
_default_components_for_family,
detect_model_family_from_config,
)
from transformers import AutoConfig
print(f"Checking: {args.model}")
try:
config = AutoConfig.from_pretrained(args.model)
except Exception as exc:
print(f" Could not load config: {exc}")
print(" → Model may be gated (requires HF token) or not found.")
return
model_type = getattr(config, "model_type", "unknown")
architectures = list(getattr(config, "architectures", None) or [])
family = detect_model_family_from_config(config)
default_components = _default_components_for_family(family)
print(f" model_type: {model_type}")
print(f" architectures: {', '.join(architectures) or 'unknown'}")
print(f" family: {family}")
print(f" default components to quantize: {', '.join(default_components)}")
is_vlm = model_type in VLM_MODEL_TYPES
has_encoder_decoder = bool(getattr(config, "is_encoder_decoder", False))
if is_vlm:
print(" → VLM: quantize text_backbone + multimodal_connector")
print(f" ternary-quant quantize-broad {args.model} \\")
print(f" --output ./$(basename {args.model})-ternary \\")
print(f" --components text_backbone multimodal_connector \\")
print(f" --scheme tritplane3 --dtype float16")
elif has_encoder_decoder:
print(" → Seq2seq / audio: quantize decoder")
print(f" ternary-quant quantize-broad {args.model} \\")
print(f" --output ./$(basename {args.model})-ternary \\")
print(f" --components decoder --scheme tritplane3")
else:
print(" → Causal LM: quantize text_backbone")
print(f" ternary-quant quantize-broad {args.model} \\")
print(f" --output ./$(basename {args.model})-ternary \\")
print(f" --components text_backbone --scheme tritplane3")
print()
print(" If quantization fails with 'No quantizable linear modules',")
print(" add --allow-all-linear to quantize all nn.Linear layers.")
def cmd_info(args):
"""Show info about a saved quantized model."""
model_dir = Path(args.model_dir)
meta_path = model_dir / "metadata.json"
if not meta_path.exists():
print(f"No quantized model found at {model_dir}")
sys.exit(1)
with open(meta_path) as f:
metadata = json.load(f)
print(f"Model: {metadata['model_name']}")
print(f"Model family: {metadata.get('model_family', 'causal_lm')}")
print(f"Method: {metadata.get('method_name', 'unknown')}")
print(f"Format family: {metadata.get('format_family', 'legacy')}")
print(f"Format version: {metadata['format_version']}")
print(f"Layers quantized: {len(metadata['layer_info'])}")
print(f"Packed size: {metadata['total_packed_bytes'] / 1e6:.1f} MB")
print(f"FP16 size: {metadata['total_fp16_bytes'] / 1e6:.1f} MB")
print(f"Compression: {metadata['compression_ratio']:.1f}x")
qc = metadata["quant_config"]
print("\nQuantization config:")
for key, value in qc.items():
if key == "base_config":
continue
print(f" {key}: {value}")
if metadata.get("summary"):
summary = metadata["summary"]
print("\nSummary:")
for key in [
"quantized_fraction",
"avg_relative_error",
"avg_effective_bits",
"full_model_effective_bits",
"total_sparse_nnz",
]:
if key in summary:
value = summary[key]
if isinstance(value, float):
if "fraction" in key:
print(f" {key}: {value:.1%}")
else:
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")
if metadata.get("plan"):
plan = metadata["plan"]
print("\nPlan:")
print(f" Method: {plan.get('method_name', 'unknown')}")
print(f" Target average bits: {plan.get('target_average_bits')}")
print(f" Predicted average bits: {plan.get('predicted_average_bits'):.2f}")
def _parse_dtype(s: str) -> torch.dtype:
return {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}[s]
def _resolve_device(device: str) -> str:
if device != "auto":
return device
if torch.cuda.is_available():
return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
def _add_runtime_mode_arg(parser: argparse.ArgumentParser, *, default: str = "cached") -> None:
parser.add_argument(
"--runtime-mode",
default=default,
choices=["packed", "cached", "native", "metal", "triton", "gemlite"],
help=(
"Inference runtime path for saved quantized layers. "
"'cached': dequantize once at load, fastest on GPU/CPU (recommended). "
"'native': replace layers with nn.Linear, ~1.0× vs FP16. "
"'packed': re-dequantize every forward, minimal live VRAM. "
"'gemlite': NVIDIA GPU only — keeps weights 2-bit packed, good batch throughput. "
"'triton': NVIDIA GPU only — custom Triton kernel, slightly faster than gemlite at batch=1. "
"'metal': Apple Silicon adaptive — Metal kernel with cached fallback."
),
)
def _cleanup_device(device: str) -> None:
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
if device == "mps":
torch.mps.empty_cache()
def _uniform_role_weights() -> dict[str, float]:
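    """Equal role-cost weights so the sensitivity_budget planner allocates
    bits by measured sensitivity alone, without role priors."""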
return {
"attention_inputs": 1.0,
"attention_output": 1.0,
"mlp_inputs": 1.0,
"mlp_output": 1.0,
}
def main():
from ternary_quant.ptq_families import FAMILY_PRESETS, get_default_family_candidates
parser = argparse.ArgumentParser(
prog="ternary-quant",
description="Post-training ternary quantization for HuggingFace generative models",
)
subparsers = parser.add_subparsers(dest="command", required=True)
p_catalog = subparsers.add_parser(
"catalog",
help="List validated, probe-only, and special-handling model entries",
)
p_catalog.add_argument(
"--status",
default="all",
choices=[
"all",
"validated",
"component_validated",
"research_validated",
"probe_only",
"special_handling",
],
)
p_catalog.add_argument(
"--family",
default="all",
choices=["all", "causal_lm", "seq2seq_lm", "image_text_to_text"],
)
p_catalog.add_argument("--show-commands", action="store_true")
p_catalog.add_argument("--json", action="store_true")
p_catalog.add_argument("--output", default=None)
p_catalog.set_defaults(func=cmd_catalog)
p_doctor = subparsers.add_parser(
"doctor",
help="Check environment readiness and runtime recommendations",
)
p_doctor.add_argument("--json", action="store_true")
p_doctor.add_argument("--output", default=None)
p_doctor.set_defaults(func=cmd_doctor)
p_quant = subparsers.add_parser("quantize", help="Quantize a model to ternary")
p_quant.add_argument("model", help="HuggingFace model ID or local path")
p_quant.add_argument("--output", "-o", required=True, help="Output directory")
p_quant.add_argument("--device", default="auto", help="Device (auto/cuda/cpu/mps)")
p_quant.add_argument(
"--dtype",
default="float16",
choices=["float16", "bfloat16", "float32"],
)
p_quant.add_argument("--n-iter", type=int, default=10, help="ITF iterations")
p_quant.add_argument(
"--no-activation-aware",
action="store_true",
help="Disable activation-aware quantization",
)
p_quant.add_argument("--block-size", type=int, default=0, help="Column block size")
p_quant.add_argument("--n-samples", type=int, default=128)
p_quant.add_argument("--seq-len", type=int, default=2048)
p_quant.add_argument("--dataset", default="wikitext")
p_quant.add_argument("--dataset-config", default="wikitext-2-raw-v1")
p_quant.add_argument("--seed", type=int, default=42)
p_quant.add_argument("--skip-modules", nargs="+", default=None)
p_quant.add_argument("--eval", action="store_true")
p_quant.add_argument("--eval-samples", type=int, default=40)
_add_runtime_mode_arg(p_quant)
p_quant.set_defaults(func=cmd_quantize)
p_small = subparsers.add_parser(
"quantize-small",
help="Role-aware sparse asymmetric ternarization for small models",
)
p_small.add_argument("model", help="HuggingFace model ID or local path")
p_small.add_argument("--output", "-o", required=True, help="Output directory")
p_small.add_argument("--device", default="auto")
p_small.add_argument(
"--dtype",
default="float16",
choices=["float16", "bfloat16", "float32"],
)
p_small.add_argument("--n-samples", type=int, default=16)
p_small.add_argument("--seq-len", type=int, default=256)
p_small.add_argument("--dataset", default="wikitext")
p_small.add_argument("--dataset-config", default="wikitext-2-raw-v1")
p_small.add_argument("--seed", type=int, default=42)
p_small.add_argument("--group-size", type=int, default=32)
p_small.add_argument("--n-iter", type=int, default=10)
p_small.add_argument(
"--planner",
default="practical",
choices=["practical", "budgeted", "sensitivity_budget", "auto", "collapse_auto"],
help=(
"Planner variant: fixed role-aware recipe, role-aware bit-budgeted recipe, "
"sensitivity-only matched-bit baseline, held-out PPL selection, or "
"held-out prompt-bank collapse-aware selection."
),
)
p_small.add_argument("--salient-fraction", type=float, default=0.01)
p_small.add_argument("--min-salient-fraction", type=float, default=0.0025)
p_small.add_argument("--max-salient-fraction", type=float, default=0.01)
p_small.add_argument(
"--low-rank-rank",
type=int,
default=0,
help="Optional per-module low-rank residual rank for quantized modules.",
)
p_small.add_argument(
"--adaptive-low-rank",
action="store_true",
help="Allocate low-rank rank adaptively per module using residual spectra.",
)
p_small.add_argument(
"--low-rank-chunk-rank",
type=int,
default=16,
help="Rank chunk used by adaptive low-rank allocation.",
)
p_small.add_argument(
"--low-rank-target-average-bits",
type=float,
default=None,
help="Optional full-model bit target for adaptive low-rank allocation.",
)
p_small.add_argument(
"--low-rank-fit-mode",
default="activation_regression",
choices=["weight_svd", "activation_regression"],
help="How to fit optional low-rank residuals for quantized modules.",
)
p_small.add_argument(
"--low-rank-ridge",
type=float,
default=1e-4,
help="Ridge penalty used for activation-regressed low-rank fitting.",
)
p_small.add_argument(
"--low-rank-max-samples",
type=int,
default=4096,
help="Maximum captured tokens per module when fitting low-rank residuals.",
)
p_small.add_argument(
"--calibration-tune-steps",
type=int,
default=0,
help="Optional number of calibration-only LM fine-tune steps for low-rank residuals.",
)
p_small.add_argument(
"--calibration-tune-lr",
type=float,
default=5e-5,
help="Learning rate for optional low-rank calibration tuning.",
)
p_small.add_argument(
"--calibration-tune-batch-size",
type=int,
default=2,
help="Batch size for optional low-rank calibration tuning.",
)
p_small.add_argument(
"--behavior-tune-weight",
type=float,
default=0.0,
help=(
"Optional weight for prompt-bank teacher-sequence tuning during low-rank "
"calibration. Requires --calibration-tune-steps > 0."
),
)
p_small.add_argument(
"--behavior-tune-prompt-count",
type=int,
default=4,
help="Number of prompts to use when building behavior-tuning teacher sequences.",
)
p_small.add_argument(
"--behavior-tune-max-tokens",
type=int,
default=48,
help="Max generated tokens per prompt when building behavior-tuning teacher sequences.",
)
p_small.add_argument(
"--distill-weight",
type=float,
default=0.0,
help="Optional teacher hidden-state distillation weight for calibration tuning.",
)
p_small.add_argument(
"--behavior-hidden-weight",
type=float,
default=0.0,
help="Optional teacher hidden-state distillation weight on prompt-bank sequences.",
)
p_small.add_argument(
"--logit-distill-weight",
type=float,
default=0.0,
help="Optional top-k teacher logit distillation weight for calibration tuning.",
)
p_small.add_argument(
"--behavior-logit-weight",
type=float,
default=0.0,
help="Optional top-k teacher logit distillation weight on prompt-bank sequences.",
)
p_small.add_argument(
"--entropy-distill-weight",
type=float,
default=0.0,
help="Optional teacher entropy-floor regularization weight for calibration tuning.",
)
p_small.add_argument(
"--behavior-entropy-weight",
type=float,
default=0.0,
help="Optional teacher entropy-floor regularization weight on prompt-bank sequences.",
)
p_small.add_argument(
"--logit-distill-topk",
type=int,
default=32,
help="Teacher top-k to cache for logit distillation.",
)
p_small.add_argument(
"--logit-distill-temperature",
type=float,
default=2.0,
help="Temperature for top-k teacher logit distillation.",
)
p_small.add_argument(
"--importance-threshold-scale",
type=float,
default=0.0,
help=(
"AWQ-inspired per-channel importance thresholding. When > 0 and activations "
"are used, input channels with high activation magnitude get a lower ternary "
"threshold (fewer zeros = more signal preserved). 0.0 = uniform (default). "
"Typical range: 0.25–0.5."
),
)
p_small.add_argument("--adaptive-salient", action="store_true")
p_small.add_argument("--boundary-layers", type=int, default=2)
p_small.add_argument("--calibration-batch-size", type=int, default=4)
p_small.add_argument("--quantize-attention-output", action="store_true")
p_small.add_argument("--quantize-mlp-output", action="store_true")
p_small.add_argument(
"--target-average-bits",
type=float,
default=None,
help="Optional full-model bit budget for the role-aware allocator.",
)
p_small.add_argument("--eval", action="store_true")
p_small.add_argument("--eval-samples", type=int, default=8)
p_small.add_argument("--selection-eval-samples", type=int, default=2)
p_small.add_argument("--selection-prompt-count", type=int, default=4)
p_small.add_argument("--selection-max-tokens", type=int, default=48)
p_small.add_argument("--selection-collapse-weight", type=float, default=2.0)
p_small.add_argument("--selection-worst-weight", type=float, default=1.0)
p_small.add_argument("--prompt", default=None)
p_small.add_argument("--max-tokens", type=int, default=80)
_add_runtime_mode_arg(p_small)
p_small.set_defaults(func=cmd_quantize_small)
p_ptq = subparsers.add_parser(
"quantize-ptq",
help="Compare or apply broader ternary PTQ families for small models",
)
p_ptq.add_argument("model", help="HuggingFace model ID or local path")
p_ptq.add_argument("--output", "-o", required=True, help="Output directory")
p_ptq.add_argument("--device", default="auto")
p_ptq.add_argument(
"--dtype",
default="float16",
choices=["float16", "bfloat16", "float32"],
)
p_ptq.add_argument("--n-samples", type=int, default=16)
p_ptq.add_argument("--seq-len", type=int, default=256)
p_ptq.add_argument("--dataset", default="wikitext")
p_ptq.add_argument("--dataset-config", default="wikitext-2-raw-v1")
p_ptq.add_argument("--seed", type=int, default=42)
p_ptq.add_argument("--group-size", type=int, default=32)
p_ptq.add_argument("--n-iter", type=int, default=10)
p_ptq.add_argument(
"--family",
default="controller",
choices=["controller", *sorted(FAMILY_PRESETS)],
help="PTQ family preset to apply, or controller to select across families.",
)
p_ptq.add_argument(
"--candidate-families",
nargs="*",
default=list(get_default_family_candidates()),
help="Candidate families considered by the controller.",
)
p_ptq.add_argument("--boundary-layers", type=int, default=2)
p_ptq.add_argument("--calibration-batch-size", type=int, default=4)
p_ptq.add_argument("--quantize-attention-output", action="store_true")
p_ptq.add_argument("--quantize-mlp-output", action="store_true")
p_ptq.add_argument(
"--target-average-bits",
type=float,
default=None,
help="Optional full-model bit target used by budget-aware family presets and selection.",
)
p_ptq.add_argument(
"--selection-metric",
default="ppl",
choices=["ppl", "collapse"],
help="Controller selection objective.",
)
p_ptq.add_argument("--selection-eval-samples", type=int, default=2)
p_ptq.add_argument("--selection-prompt-count", type=int, default=4)
p_ptq.add_argument("--selection-max-tokens", type=int, default=48)
p_ptq.add_argument("--selection-collapse-weight", type=float, default=2.0)
p_ptq.add_argument("--selection-worst-weight", type=float, default=1.0)
p_ptq.add_argument("--selection-bits-weight", type=float, default=0.25)
p_ptq.add_argument("--eval", action="store_true")
p_ptq.add_argument("--eval-samples", type=int, default=8)
p_ptq.add_argument("--prompt", default=None)
p_ptq.add_argument("--max-tokens", type=int, default=80)
_add_runtime_mode_arg(p_ptq)
p_ptq.set_defaults(func=cmd_quantize_ptq)
p_broad = subparsers.add_parser(
"quantize-broad",
help="Quantize selected components of a broader generative model family",
)
p_broad.add_argument("model", help="HuggingFace model ID or local path")
p_broad.add_argument("--output", "-o", required=True, help="Output directory")
p_broad.add_argument("--device", default="auto")
p_broad.add_argument(
"--dtype",
default="float32",
choices=["float16", "bfloat16", "float32"],
)
p_broad.add_argument(
"--components",
nargs="*",
default=None,
help="Component names to quantize. Defaults to the family-specific broad preset.",
)
p_broad.add_argument(
"--scheme",
default="groupwise",
choices=["groupwise", "tritplane2", "tritplane3"],
help="Broad quantization scheme.",
)
p_broad.add_argument("--group-size", type=int, default=32)
p_broad.add_argument("--n-iter", type=int, default=10)
p_broad.add_argument("--salient-fraction", type=float, default=0.0)
p_broad.add_argument("--rescue-fraction", type=float, default=0.0)
p_broad.add_argument("--allow-all-linear", action="store_true")
p_broad.add_argument("--seq-len", type=int, default=160)
p_broad.add_argument("--calibration-batch-size", type=int, default=2)
p_broad.add_argument("--prompt", default=None)
p_broad.add_argument("--max-tokens", type=int, default=64)
p_broad.add_argument("--eval", action="store_true")
p_broad.add_argument(
"--push-to-hub",
default=None,
metavar="REPO_ID",
help="Push the quantized model to HuggingFace Hub (e.g. username/my-model-ternary).",
)
_add_runtime_mode_arg(p_broad)
p_broad.set_defaults(
func=cmd_quantize_broad,
n_planes=2,
)
p_inspect = subparsers.add_parser(
"inspect-generative",
help="Inspect the generative-family components of a model",
)
p_inspect.add_argument("model", help="HuggingFace model ID or local path")
p_inspect.add_argument("--device", default="auto")
p_inspect.add_argument(
"--dtype",
default="float32",
choices=["float16", "bfloat16", "float32"],
)
p_inspect.add_argument(
"--output",
default=None,
help="Optional JSON output path for the component inventory.",
)
p_inspect.set_defaults(func=cmd_inspect_generative)
p_check = subparsers.add_parser(
"check",
help="Quick compatibility check for a model (no weights downloaded)",
)
p_check.add_argument("model", help="HuggingFace model ID")
p_check.set_defaults(func=cmd_check)
p_eval = subparsers.add_parser("eval", help="Evaluate saved model perplexity")
p_eval.add_argument("model_dir")
p_eval.add_argument("--device", default="auto")
p_eval.add_argument("--seq-len", type=int, default=2048)
p_eval.add_argument("--max-samples", type=int, default=None)
_add_runtime_mode_arg(p_eval)
p_eval.set_defaults(func=cmd_eval)
p_cmp = subparsers.add_parser("compare", help="Compare original vs quantized")
p_cmp.add_argument("original")
p_cmp.add_argument("ternary")
p_cmp.add_argument("--device", default="auto")
p_cmp.add_argument("--seq-len", type=int, default=2048)
p_cmp.add_argument("--max-samples", type=int, default=40)
p_cmp.set_defaults(func=cmd_compare)
p_gen = subparsers.add_parser("generate", help="Generate text with saved model")
p_gen.add_argument("model_dir")
p_gen.add_argument("--prompt", "-p", required=True)
p_gen.add_argument("--max-tokens", type=int, default=256)
p_gen.add_argument("--temperature", type=float, default=0.7)
p_gen.add_argument("--device", default="auto")
p_gen.add_argument(
"--image-path",
default=None,
help="Optional image path for image-text-to-text models. If omitted, a demo image is used.",
)
_add_runtime_mode_arg(p_gen)
p_gen.set_defaults(func=cmd_generate)
p_info = subparsers.add_parser("info", help="Show info about a saved model")
p_info.add_argument("model_dir")
p_info.set_defaults(func=cmd_info)
args = parser.parse_args()
args.func(args)
if __name__ == "__main__":
main()