Spaces:
Running
Running
| """ | |
| CLI for ternary quantization of HuggingFace models. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import json | |
| import math | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import torch | |
def cmd_catalog(args):
    """Print the catalog of known model entries, or dump it as JSON."""
    from ternary_quant.toolkit import known_models_to_dict, list_known_models

    entries = list_known_models(status=args.status, family=args.family)

    if args.json:
        # JSON mode: serialize the filtered catalog to a file or to stdout.
        text = json.dumps(
            {
                "status_filter": args.status,
                "family_filter": args.family,
                "models": known_models_to_dict(entries),
            },
            indent=2,
        )
        if args.output:
            destination = Path(args.output)
            destination.parent.mkdir(parents=True, exist_ok=True)
            destination.write_text(text + "\n")
            print(f"Wrote catalog to {destination}")
        else:
            print(text)
        return

    if not entries:
        print("No models matched the requested filters.")
        return

    # Human-readable mode: bucket entries by status, then print each group.
    by_status: dict[str, list] = {}
    for item in entries:
        by_status.setdefault(item.status, []).append(item)
    for status_name, bucket in by_status.items():
        print(status_name.replace("_", " ").title())
        for item in bucket:
            print(
                f" {item.model_id:<40} family={item.family:<18} "
                f"path={item.path:<8} runtime={item.recommended_runtime}"
            )
            print(f" note: {item.note}")
            print(f" artifact: {item.artifact}")
            if args.show_commands and item.quickstart_command:
                print(f" quickstart: {item.quickstart_command}")
        print("")
def cmd_doctor(args):
    """Print environment diagnostics, optionally persisting them as JSON."""
    from ternary_quant.toolkit import build_doctor_report, doctor_report_to_text

    report = build_doctor_report()
    # Serialize once; both the JSON mode and the --output side-channel use it.
    serialized = json.dumps(report, indent=2)

    if args.json:
        if args.output:
            target = Path(args.output)
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(serialized + "\n")
            print(f"Wrote doctor report to {target}")
        else:
            print(serialized)
        return

    # Text mode prints the human-readable report and may also save the JSON copy.
    print(doctor_report_to_text(report))
    if args.output:
        target = Path(args.output)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(serialized + "\n")
        print(f"\nWrote doctor report to {target}")
def cmd_quantize(args):
    """Run the legacy full-ternary pipeline on a HF model and save the result."""
    from ternary_quant.pipeline import QuantizationConfig, quantize_model
    from ternary_quant.storage import save_quantized_model

    config = QuantizationConfig(
        n_iter=args.n_iter,
        use_activation_aware=not args.no_activation_aware,
        block_size=args.block_size,
        n_samples=args.n_samples,
        seq_len=args.seq_len,
        dataset=args.dataset,
        dataset_config=args.dataset_config,
        seed=args.seed,
    )
    if args.skip_modules:
        config.skip_modules = args.skip_modules

    result = quantize_model(
        model_name_or_path=args.model,
        config=config,
        device=args.device,
        dtype=_parse_dtype(args.dtype),
    )
    save_quantized_model(
        ternary_params=result.ternary_params,
        model_name=result.model_name,
        model_config=result.model_config,
        quant_config=result.config,
        output_dir=args.output,
        stats=result.stats,
    )

    if not args.eval:
        return
    # Optional post-quantization perplexity check on the saved artifact.
    print("\nRunning perplexity evaluation...")
    from ternary_quant.eval import evaluate_perplexity
    from ternary_quant.inference import load_ternary_model

    model, tokenizer = load_ternary_model(
        args.output,
        device=args.device,
        runtime_mode=getattr(args, "runtime_mode", "packed"),
    )
    ppl = evaluate_perplexity(model, tokenizer, max_samples=args.eval_samples)
    print(f"Ternary model perplexity: {ppl:.2f}")
def cmd_quantize_small(args):
    """Quantize a small model with the role-aware sparse asymmetric ternary path.

    Pipeline: load tokenizer/config/calibration data, optionally pre-build
    teacher caches (behavior sequences, hidden states, top-k logits), pick a
    planner (either directly or via an auto/collapse-aware selection loop),
    quantize in place, optionally tune low-rank residuals, then save the model
    plus a JSON report and optionally evaluate it.
    """
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    from ternary_quant.data import get_calibration_data
    from ternary_quant.eval import (
        evaluate_perplexity,
        evaluate_prompt_bank,
        get_default_prompt_bank,
    )
    from ternary_quant.inference import generate_text, load_ternary_model
    from ternary_quant.quantizer_small import (
        SmallModelQuantizationConfig,
        build_sensitivity_only_plan,
        build_role_aware_plan,
        config_to_dict,
        plan_to_dict,
        quantize_small_model_inplace,
        summarize_small_model_quantization,
        tune_low_rank_residuals_inplace,
    )
    from ternary_quant.storage import save_quantized_model

    device = _resolve_device(args.device)
    dtype = _parse_dtype(args.dtype)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model_config = AutoConfig.from_pretrained(args.model)
    # Calibration batch lives on the target device for the whole run.
    calibration_data = get_calibration_data(
        args.model,
        tokenizer=tokenizer,
        n_samples=args.n_samples,
        seq_len=args.seq_len,
        dataset_name=args.dataset,
        dataset_config=args.dataset_config,
        seed=args.seed,
    ).to(device)

    def load_base_model():
        # Fresh FP copy of the base model; callers are responsible for
        # deleting it and calling _cleanup_device when done.
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        ).to(device)
        model.eval()
        return model

    def build_behavior_sequences(prompt_bank: dict) -> list[torch.Tensor]:
        # Concatenate each prompt with the teacher's recorded generation so the
        # tuner can match full prompt+continuation sequences.
        sequences = []
        for sample in prompt_bank.get("samples", []):
            prompt_ids = tokenizer(
                sample["prompt"],
                return_tensors="pt",
                truncation=False,
            )["input_ids"][0]
            generated_ids = torch.tensor(
                sample.get("generated_token_ids", []),
                dtype=torch.long,
            )
            full_sequence = torch.cat([prompt_ids, generated_ids], dim=0).unsqueeze(0)
            sequences.append(full_sequence)
        return sequences

    def build_hidden_cache(sequences):
        """Cache the teacher's last-layer hidden states (fp16 on CPU).

        Accepts either a batched tensor (calibration data) or a list of
        variable-length sequences (behavior data); returns the matching shape.
        """
        if sequences is None:
            return None
        if isinstance(sequences, torch.Tensor):
            model = load_base_model()
            outputs = []
            # Batched path: slice the calibration tensor by tune batch size.
            for start in range(0, sequences.shape[0], args.calibration_tune_batch_size):
                batch = sequences[start : start + args.calibration_tune_batch_size]
                hidden = model(batch, output_hidden_states=True).hidden_states[-1]
                # Offload to CPU fp16 to keep device memory free for tuning.
                outputs.append(hidden.detach().cpu().to(torch.float16))
            del model
            _cleanup_device(device)
            return torch.cat(outputs, dim=0)
        # List path: one forward per variable-length behavior sequence.
        outputs = []
        model = load_base_model()
        for seq in sequences:
            hidden = model(seq.to(device), output_hidden_states=True).hidden_states[-1]
            outputs.append(hidden.detach().cpu().to(torch.float16))
        del model
        _cleanup_device(device)
        return outputs

    def build_topk_logit_cache(sequences, top_k: int):
        """Cache the teacher's top-k next-token logits plus normalized entropy.

        Entropy is divided by log(vocab_size) so it lands in [0, 1].
        Returns a dict of tensors for batched input, or a list of such dicts
        for a list of sequences; None when disabled.
        """
        if sequences is None or top_k <= 0:
            return None
        top_k = max(1, int(top_k))
        if isinstance(sequences, torch.Tensor):
            model = load_base_model()
            indices_out = []
            logits_out = []
            entropy_out = []
            for start in range(0, sequences.shape[0], args.calibration_tune_batch_size):
                batch = sequences[start : start + args.calibration_tune_batch_size]
                # Drop the last position: logits predict the *next* token.
                logits = model(batch).logits[:, :-1, :].float()
                values, indices = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1)
                log_probs = torch.log_softmax(logits, dim=-1)
                probs = log_probs.exp()
                entropy = -(probs * log_probs).sum(dim=-1) / math.log(max(logits.shape[-1], 2))
                indices_out.append(indices.detach().cpu().to(torch.int32))
                logits_out.append(values.detach().cpu().to(torch.float16))
                entropy_out.append(entropy.detach().cpu().to(torch.float16))
            del model
            _cleanup_device(device)
            return {
                "indices": torch.cat(indices_out, dim=0),
                "logits": torch.cat(logits_out, dim=0),
                "entropy": torch.cat(entropy_out, dim=0),
            }
        outputs = []
        model = load_base_model()
        for seq in sequences:
            logits = model(seq.to(device)).logits[:, :-1, :].float()
            values, indices = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1)
            log_probs = torch.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            entropy = -(probs * log_probs).sum(dim=-1) / math.log(max(logits.shape[-1], 2))
            outputs.append(
                {
                    "indices": indices.detach().cpu().to(torch.int32),
                    "logits": values.detach().cpu().to(torch.float16),
                    "entropy": entropy.detach().cpu().to(torch.float16),
                }
            )
        del model
        _cleanup_device(device)
        return outputs

    def make_config(
        planner: str,
    ) -> SmallModelQuantizationConfig:
        """Build a quantization config, applying per-planner defaults."""
        target_average_bits = args.target_average_bits
        adaptive_salient = args.adaptive_salient
        role_cost_weights = None
        if planner == "budgeted" and target_average_bits is None:
            # Budgeted planner without an explicit budget: default to 10.5 bits.
            target_average_bits = 10.5
            adaptive_salient = True
        elif planner == "sensitivity_budget":
            if target_average_bits is None:
                target_average_bits = 10.5
                adaptive_salient = True
            # Sensitivity-only planning ignores role structure, so use uniform weights.
            role_cost_weights = _uniform_role_weights()
        elif planner == "practical":
            # Practical planner is unbudgeted by construction.
            target_average_bits = None
        config = SmallModelQuantizationConfig(
            group_size=args.group_size,
            n_iter=args.n_iter,
            salient_fraction=args.salient_fraction,
            min_salient_fraction=args.min_salient_fraction,
            max_salient_fraction=args.max_salient_fraction,
            adaptive_salient=adaptive_salient,
            low_rank_rank=args.low_rank_rank,
            adaptive_low_rank=args.adaptive_low_rank,
            low_rank_chunk_rank=args.low_rank_chunk_rank,
            low_rank_target_average_bits=args.low_rank_target_average_bits,
            low_rank_fit_mode=args.low_rank_fit_mode,
            low_rank_ridge=args.low_rank_ridge,
            low_rank_max_samples=args.low_rank_max_samples,
            n_boundary_layers=args.boundary_layers,
            calibration_batch_size=args.calibration_batch_size,
            quantize_attention_output=args.quantize_attention_output,
            quantize_mlp_output=args.quantize_mlp_output,
            target_average_bits=target_average_bits,
            importance_threshold_scale=getattr(args, "importance_threshold_scale", 0.0),
            # Fall back to the dataclass default weights when not overridden.
            role_cost_weights=role_cost_weights
            if role_cost_weights is not None
            else SmallModelQuantizationConfig().role_cost_weights,
        )
        config.base_config.n_samples = args.n_samples
        config.base_config.seq_len = args.seq_len
        config.base_config.dataset = args.dataset
        config.base_config.dataset_config = args.dataset_config
        config.base_config.seed = args.seed
        return config

    def build_plan(model, config: SmallModelQuantizationConfig, planner: str):
        # Only "sensitivity_budget" uses the sensitivity-only planner; every
        # other planner name goes through the role-aware path.
        if planner == "sensitivity_budget":
            return build_sensitivity_only_plan(model, calibration_data, config)
        return build_role_aware_plan(model, calibration_data, config)

    # --- Optional teacher caches for residual tuning -----------------------
    behavior_sequences = None
    calibration_hidden_states = None
    behavior_hidden_states = None
    calibration_logit_targets = None
    behavior_logit_targets = None
    if args.calibration_tune_steps > 0 and args.behavior_tune_weight > 0.0:
        behavior_prompt_bank = get_default_prompt_bank(
            primary_prompt=args.prompt,
            max_prompts=args.behavior_tune_prompt_count,
        )
        print("Building prompt-bank behavior tuning data...")
        behavior_model = load_base_model()
        behavior_reference = evaluate_prompt_bank(
            behavior_model,
            tokenizer,
            prompts=behavior_prompt_bank,
            max_new_tokens=args.behavior_tune_max_tokens,
        )
        behavior_sequences = build_behavior_sequences(behavior_reference)
        del behavior_model
        _cleanup_device(device)
    if args.calibration_tune_steps > 0 and (
        args.distill_weight > 0.0 or args.behavior_hidden_weight > 0.0
    ):
        print("Building teacher hidden-state caches...")
        calibration_hidden_states = build_hidden_cache(calibration_data)
        if behavior_sequences is not None:
            behavior_hidden_states = build_hidden_cache(behavior_sequences)
    if args.calibration_tune_steps > 0 and (
        args.logit_distill_weight > 0.0
        or args.behavior_logit_weight > 0.0
        or args.entropy_distill_weight > 0.0
        or args.behavior_entropy_weight > 0.0
    ):
        print("Building teacher top-k logit caches...")
        calibration_logit_targets = build_topk_logit_cache(
            calibration_data,
            args.logit_distill_topk,
        )
        if behavior_sequences is not None:
            behavior_logit_targets = build_topk_logit_cache(
                behavior_sequences,
                args.logit_distill_topk,
            )

    # --- Planner selection -------------------------------------------------
    selection = None
    auto_tuned = False
    if args.planner in {"auto", "collapse_auto"}:
        # Auto modes quantize the model once per candidate planner and keep
        # the candidate with the best held-out score.
        candidate_planners = ["practical", "sensitivity_budget"]
        best = None
        total_quant_time = 0.0
        selection_metric = (
            "collapse_aware" if args.planner == "collapse_auto" else "ppl"
        )
        selection = {
            "selection_metric": selection_metric,
            "candidate_scores": {},
        }
        selection_prompt_bank = get_default_prompt_bank(
            primary_prompt=args.prompt,
            max_prompts=args.selection_prompt_count,
        )
        reference_behavior = None
        if selection_metric == "collapse_aware":
            # Collapse-aware selection needs an FP16 baseline to penalize
            # quantization-induced generation collapse against.
            print("Measuring FP16 prompt-bank behavior...")
            reference_model = load_base_model()
            reference_behavior = evaluate_prompt_bank(
                reference_model,
                tokenizer,
                prompts=selection_prompt_bank,
                max_new_tokens=args.selection_max_tokens,
            )
            selection["reference_behavior"] = {
                "avg_collapse_score": reference_behavior["avg_collapse_score"],
                "worst_collapse_score": reference_behavior["worst_collapse_score"],
                "avg_distinct_2": reference_behavior["avg_distinct_2"],
                "avg_repeated_3gram_ratio": reference_behavior[
                    "avg_repeated_3gram_ratio"
                ],
            }
            del reference_model
            _cleanup_device(device)
        for planner in candidate_planners:
            print(f"Evaluating planner candidate: {planner}")
            model = load_base_model()
            config = make_config(planner)
            t0 = time.time()
            plan = build_plan(model, config, planner)
            result = quantize_small_model_inplace(
                model,
                calibration_data=calibration_data,
                config=config,
                plan=plan,
            )
            total_quant_time += time.time() - t0
            summary = summarize_small_model_quantization(result, model)
            tune_stats = None
            if args.calibration_tune_steps > 0:
                # Tune each candidate before scoring so selection reflects the
                # final (tuned) quality; top-level tuning is then skipped.
                tune_stats = tune_low_rank_residuals_inplace(
                    model,
                    result,
                    calibration_data=calibration_data,
                    n_steps=args.calibration_tune_steps,
                    lr=args.calibration_tune_lr,
                    batch_size=args.calibration_tune_batch_size,
                    max_seq_len=args.seq_len,
                    behavior_sequences=behavior_sequences,
                    behavior_weight=args.behavior_tune_weight,
                    calibration_hidden_states=calibration_hidden_states,
                    behavior_hidden_states=behavior_hidden_states,
                    calibration_logit_targets=calibration_logit_targets,
                    behavior_logit_targets=behavior_logit_targets,
                    distill_weight=args.distill_weight,
                    behavior_hidden_weight=args.behavior_hidden_weight,
                    logit_distill_weight=args.logit_distill_weight,
                    behavior_logit_weight=args.behavior_logit_weight,
                    entropy_distill_weight=args.entropy_distill_weight,
                    behavior_entropy_weight=args.behavior_entropy_weight,
                    logit_distill_temperature=args.logit_distill_temperature,
                    seed=args.seed,
                )
                summary = summarize_small_model_quantization(result, model)
            selection_ppl = evaluate_perplexity(
                model,
                tokenizer,
                seq_len=args.seq_len,
                max_samples=args.selection_eval_samples,
            )
            selection_score = float(selection_ppl)
            selection_behavior = None
            if selection_metric == "collapse_aware":
                selection_behavior = evaluate_prompt_bank(
                    model,
                    tokenizer,
                    prompts=selection_prompt_bank,
                    max_new_tokens=args.selection_max_tokens,
                )
                reference_avg = (
                    0.0 if reference_behavior is None else reference_behavior["avg_collapse_score"]
                )
                reference_worst = (
                    reference_avg
                    if reference_behavior is None
                    else reference_behavior["worst_collapse_score"]
                )
                # Only *excess* collapse relative to the FP16 baseline is
                # penalized; a candidate no worse than FP16 scores pure PPL.
                collapse_excess = max(
                    selection_behavior["avg_collapse_score"] - reference_avg,
                    0.0,
                )
                worst_excess = max(
                    selection_behavior["worst_collapse_score"] - reference_worst,
                    0.0,
                )
                selection_score = selection_ppl * (
                    1.0
                    + args.selection_collapse_weight * collapse_excess
                    + args.selection_worst_weight * worst_excess
                )
            selection["candidate_scores"][planner] = {
                "selection_ppl": selection_ppl,
                "selection_score": selection_score,
                "predicted_average_bits": plan.predicted_average_bits,
                "full_model_effective_bits": summary["full_model_effective_bits"],
            }
            if selection_behavior is not None:
                selection["candidate_scores"][planner]["selection_behavior"] = {
                    "avg_collapse_score": selection_behavior["avg_collapse_score"],
                    "worst_collapse_score": selection_behavior["worst_collapse_score"],
                    "avg_distinct_2": selection_behavior["avg_distinct_2"],
                    "avg_repeated_3gram_ratio": selection_behavior[
                        "avg_repeated_3gram_ratio"
                    ],
                }
            if tune_stats is not None:
                selection["candidate_scores"][planner]["calibration_tune"] = tune_stats
            # Lower score wins; free whichever quantized model loses.
            if best is None or selection_score < best["selection_score"]:
                if best is not None:
                    del best["model"]
                    _cleanup_device(device)
                best = {
                    "model": model,
                    "config": config,
                    "plan": plan,
                    "result": result,
                    "summary": summary,
                    "selection_ppl": selection_ppl,
                    "selection_score": selection_score,
                    "selection_behavior": selection_behavior,
                    "planner": planner,
                }
            else:
                del model
                _cleanup_device(device)
        if best is None:
            raise RuntimeError("Auto planner failed to select a candidate.")
        model = best["model"]
        config = best["config"]
        plan = best["plan"]
        result = best["result"]
        summary = best["summary"]
        quant_time = total_quant_time
        selected_name = "RAST-collapse-auto" if args.planner == "collapse_auto" else "RAST-auto"
        result.plan.method_name = selected_name
        summary["method_name"] = selected_name
        # Candidates were already tuned inside the loop, so skip re-tuning below.
        auto_tuned = args.calibration_tune_steps > 0
        selection.update(
            {
                "selected_planner": best["planner"],
                "selection_ppl": best["selection_ppl"],
                "selection_score": best["selection_score"],
            }
        )
        if best["selection_behavior"] is not None:
            selection["selected_behavior"] = {
                "avg_collapse_score": best["selection_behavior"]["avg_collapse_score"],
                "worst_collapse_score": best["selection_behavior"]["worst_collapse_score"],
                "avg_distinct_2": best["selection_behavior"]["avg_distinct_2"],
                "avg_repeated_3gram_ratio": best["selection_behavior"][
                    "avg_repeated_3gram_ratio"
                ],
            }
        print(
            f"Selected planner: {best['planner']} | "
            f"held-out score {best['selection_score']:.2f} | "
            f"PPL {best['selection_ppl']:.2f} | "
            f"full-model bits {summary['full_model_effective_bits']:.2f}"
        )
    else:
        # Single-planner path: plan and quantize exactly once.
        model = load_base_model()
        config = make_config(args.planner)
        print("Building role-aware plan...")
        t0 = time.time()
        plan = build_plan(model, config, args.planner)
        print(
            f"Plan ready in {time.time() - t0:.1f}s | "
            f"Predicted average bits: {plan.predicted_average_bits:.2f}"
        )
        print("Applying role-aware quantization...")
        t1 = time.time()
        result = quantize_small_model_inplace(
            model,
            calibration_data=calibration_data,
            config=config,
            plan=plan,
        )
        quant_time = time.time() - t1
        summary = summarize_small_model_quantization(result, model)

    # --- Optional residual tuning (skipped when the auto loop already tuned) --
    if args.calibration_tune_steps > 0 and not auto_tuned:
        print("Calibrating low-rank residuals...")
        t2 = time.time()
        tune_stats = tune_low_rank_residuals_inplace(
            model,
            result,
            calibration_data=calibration_data,
            n_steps=args.calibration_tune_steps,
            lr=args.calibration_tune_lr,
            batch_size=args.calibration_tune_batch_size,
            max_seq_len=args.seq_len,
            behavior_sequences=behavior_sequences,
            behavior_weight=args.behavior_tune_weight,
            calibration_hidden_states=calibration_hidden_states,
            behavior_hidden_states=behavior_hidden_states,
            calibration_logit_targets=calibration_logit_targets,
            behavior_logit_targets=behavior_logit_targets,
            distill_weight=args.distill_weight,
            behavior_hidden_weight=args.behavior_hidden_weight,
            logit_distill_weight=args.logit_distill_weight,
            behavior_logit_weight=args.behavior_logit_weight,
            entropy_distill_weight=args.entropy_distill_weight,
            behavior_entropy_weight=args.behavior_entropy_weight,
            logit_distill_temperature=args.logit_distill_temperature,
            seed=args.seed,
        )
        quant_time += time.time() - t2
        summary = summarize_small_model_quantization(result, model)
        print(
            f"Calibration tune complete | "
            f"final loss {tune_stats.get('final_loss', float('nan')):.4f} | "
            f"wrapped modules {tune_stats['n_wrapped_modules']}"
        )

    # --- Persist artifact + report -----------------------------------------
    save_quantized_model(
        ternary_params=result.quantized_params,
        model_name=args.model,
        model_config=model_config,
        quant_config=config,
        output_dir=args.output,
        stats=result.stats,
        summary=summary,
        plan=result.plan,
        method_name=result.plan.method_name,
    )
    report = {
        "method": result.plan.method_name,
        "model": args.model,
        "quant_time_sec": quant_time,
        "summary": summary,
        "plan": plan_to_dict(result.plan),
        "config": config_to_dict(config),
    }
    if selection is not None:
        report["selection"] = selection
    report_path = Path(args.output) / "role_aware_report.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)
    print(f"Wrote role-aware report to {report_path}")

    # --- Optional validation of the saved artifact --------------------------
    if args.eval:
        print("\nRunning validation on saved model...")
        quantized_model, tokenizer = load_ternary_model(
            args.output,
            device=device,
            runtime_mode=getattr(args, "runtime_mode", "packed"),
        )
        ppl = evaluate_perplexity(
            quantized_model,
            tokenizer,
            seq_len=args.seq_len,
            max_samples=args.eval_samples,
        )
        print(f"Role-aware quantized perplexity: {ppl:.2f}")
        if args.prompt:
            # Greedy decoding for a reproducible smoke-test generation.
            text = generate_text(
                quantized_model,
                tokenizer,
                prompt=args.prompt,
                max_new_tokens=args.max_tokens,
                do_sample=False,
            )
            print(f"Prompt: {args.prompt}")
            print(f"Output: {text}")
def cmd_quantize_ptq(args):
    """Quantize a small model via a ternary PTQ family or controller.

    With ``--family controller`` every candidate family is quantized and
    scored on held-out perplexity (optionally collapse- and bits-penalized)
    and the best one is kept; otherwise the named family is applied directly.
    The saved artifact is accompanied by a JSON report and optional eval.
    """
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    from ternary_quant.data import get_calibration_data
    from ternary_quant.eval import (
        evaluate_perplexity,
        evaluate_prompt_bank,
        get_default_prompt_bank,
    )
    from ternary_quant.inference import generate_text, load_ternary_model
    from ternary_quant.ptq_families import (
        build_family_config,
        family_config_to_dict,
        get_default_family_candidates,
        quantize_family_inplace,
        summarize_family_quantization,
    )
    from ternary_quant.storage import save_quantized_model

    device = _resolve_device(args.device)
    dtype = _parse_dtype(args.dtype)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model_config = AutoConfig.from_pretrained(args.model)
    # Calibration batch lives on the target device for the whole run.
    calibration_data = get_calibration_data(
        args.model,
        tokenizer=tokenizer,
        n_samples=args.n_samples,
        seq_len=args.seq_len,
        dataset_name=args.dataset,
        dataset_config=args.dataset_config,
        seed=args.seed,
    ).to(device)

    def load_base_model():
        # Fresh FP copy of the base model; callers delete it when done.
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        ).to(device)
        model.eval()
        return model

    def build_config(family_name: str):
        # Family preset plus the CLI overrides shared by all families.
        return build_family_config(
            family_name,
            target_average_bits=args.target_average_bits,
            group_size=args.group_size,
            n_iter=args.n_iter,
            n_boundary_layers=args.boundary_layers,
            calibration_batch_size=args.calibration_batch_size,
            quantize_attention_output=args.quantize_attention_output,
            quantize_mlp_output=args.quantize_mlp_output,
        )

    selection = None
    if args.family == "controller":
        # Controller mode: quantize every candidate family and keep the best.
        candidate_names = (
            args.candidate_families
            if args.candidate_families
            else get_default_family_candidates()
        )
        selection_metric = args.selection_metric
        selection_prompt_bank = get_default_prompt_bank(
            primary_prompt=args.prompt,
            max_prompts=args.selection_prompt_count,
        )
        selection = {
            "selection_metric": selection_metric,
            "candidate_scores": {},
        }
        reference_behavior = None
        if selection_metric == "collapse":
            # Collapse-aware selection needs an FP16 behavior baseline.
            print("Measuring FP16 prompt-bank behavior for controller selection...")
            reference_model = load_base_model()
            reference_behavior = evaluate_prompt_bank(
                reference_model,
                tokenizer,
                prompts=selection_prompt_bank,
                max_new_tokens=args.selection_max_tokens,
            )
            selection["reference_behavior"] = {
                "avg_collapse_score": reference_behavior["avg_collapse_score"],
                "worst_collapse_score": reference_behavior["worst_collapse_score"],
                "avg_distinct_2": reference_behavior["avg_distinct_2"],
                "avg_repeated_3gram_ratio": reference_behavior[
                    "avg_repeated_3gram_ratio"
                ],
            }
            del reference_model
            _cleanup_device(device)
        best = None
        total_quant_time = 0.0
        for family_name in candidate_names:
            print(f"Evaluating ternary PTQ family candidate: {family_name}")
            family_config = build_config(family_name)
            model = load_base_model()
            t0 = time.time()
            result = quantize_family_inplace(
                model,
                calibration_data=calibration_data,
                config=family_config,
            )
            total_quant_time += time.time() - t0
            summary = summarize_family_quantization(result)
            selection_ppl = evaluate_perplexity(
                model,
                tokenizer,
                seq_len=args.seq_len,
                max_samples=args.selection_eval_samples,
            )
            selection_score = float(selection_ppl)
            selection_behavior = None
            if selection_metric == "collapse":
                selection_behavior = evaluate_prompt_bank(
                    model,
                    tokenizer,
                    prompts=selection_prompt_bank,
                    max_new_tokens=args.selection_max_tokens,
                )
                reference_avg = (
                    0.0
                    if reference_behavior is None
                    else reference_behavior["avg_collapse_score"]
                )
                reference_worst = (
                    reference_avg
                    if reference_behavior is None
                    else reference_behavior["worst_collapse_score"]
                )
                # Penalize only collapse *in excess of* the FP16 baseline.
                collapse_excess = max(
                    selection_behavior["avg_collapse_score"] - reference_avg,
                    0.0,
                )
                worst_excess = max(
                    selection_behavior["worst_collapse_score"] - reference_worst,
                    0.0,
                )
                selection_score = selection_ppl * (
                    1.0
                    + args.selection_collapse_weight * collapse_excess
                    + args.selection_worst_weight * worst_excess
                )
            if args.target_average_bits is not None:
                # Multiplicative penalty for exceeding the requested bit budget.
                bits_excess = max(
                    summary["full_model_effective_bits"] - args.target_average_bits,
                    0.0,
                )
                selection_score *= (
                    1.0
                    + args.selection_bits_weight
                    * bits_excess
                    / max(args.target_average_bits, 1e-6)
                )
            selection["candidate_scores"][family_name] = {
                "label": family_config.label,
                "selection_ppl": selection_ppl,
                "selection_score": selection_score,
                "full_model_effective_bits": summary["full_model_effective_bits"],
                "quantized_fraction": summary["quantized_fraction"],
            }
            if selection_behavior is not None:
                selection["candidate_scores"][family_name]["selection_behavior"] = {
                    "avg_collapse_score": selection_behavior["avg_collapse_score"],
                    "worst_collapse_score": selection_behavior["worst_collapse_score"],
                    "avg_distinct_2": selection_behavior["avg_distinct_2"],
                    "avg_repeated_3gram_ratio": selection_behavior[
                        "avg_repeated_3gram_ratio"
                    ],
                }
            # Lower score wins; free whichever quantized model loses.
            if best is None or selection_score < best["selection_score"]:
                if best is not None:
                    del best["model"]
                    _cleanup_device(device)
                best = {
                    "model": model,
                    "family_config": family_config,
                    "result": result,
                    "summary": summary,
                    "selection_ppl": selection_ppl,
                    "selection_score": selection_score,
                    "selection_behavior": selection_behavior,
                    "family_name": family_name,
                }
            else:
                del model
                _cleanup_device(device)
        if best is None:
            raise RuntimeError("Controller failed to select a ternary PTQ family.")
        model = best["model"]
        family_config = best["family_config"]
        result = best["result"]
        summary = best["summary"]
        quant_time = total_quant_time
        result.plan.method_name = "Ternary-PTQ-auto"
        summary["method_name"] = "Ternary-PTQ-auto"
        summary["selected_family_preset"] = best["family_name"]
        selection.update(
            {
                "selected_family_preset": best["family_name"],
                "selected_family_label": family_config.label,
                "selection_ppl": best["selection_ppl"],
                "selection_score": best["selection_score"],
            }
        )
        if best["selection_behavior"] is not None:
            selection["selected_behavior"] = {
                "avg_collapse_score": best["selection_behavior"]["avg_collapse_score"],
                "worst_collapse_score": best["selection_behavior"]["worst_collapse_score"],
                "avg_distinct_2": best["selection_behavior"]["avg_distinct_2"],
                "avg_repeated_3gram_ratio": best["selection_behavior"][
                    "avg_repeated_3gram_ratio"
                ],
            }
        print(
            f"Selected family: {best['family_name']} | "
            f"held-out score {best['selection_score']:.2f} | "
            f"PPL {best['selection_ppl']:.2f} | "
            f"full-model bits {summary['full_model_effective_bits']:.2f}"
        )
    else:
        # Direct mode: apply the single named family.
        family_config = build_config(args.family)
        print(f"Applying ternary PTQ family: {family_config.label}")
        model = load_base_model()
        t0 = time.time()
        result = quantize_family_inplace(
            model,
            calibration_data=calibration_data,
            config=family_config,
        )
        quant_time = time.time() - t0
        summary = summarize_family_quantization(result)

    # --- Persist artifact + report -----------------------------------------
    save_quantized_model(
        ternary_params=result.quantized_params,
        model_name=args.model,
        model_config=model_config,
        quant_config=family_config,
        output_dir=args.output,
        stats=result.stats,
        summary=summary,
        plan=result.plan,
        method_name=result.plan.method_name,
    )
    report = {
        "method": result.plan.method_name,
        "model": args.model,
        "quant_time_sec": quant_time,
        "summary": summary,
        "family_config": family_config_to_dict(family_config),
    }
    if selection is not None:
        report["selection"] = selection
    report_path = Path(args.output) / "ternary_ptq_report.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)
    print(f"Wrote ternary PTQ report to {report_path}")

    # --- Optional validation of the saved artifact --------------------------
    if args.eval:
        print("\nRunning validation on saved model...")
        quantized_model, tokenizer = load_ternary_model(
            args.output,
            device=device,
            runtime_mode=getattr(args, "runtime_mode", "packed"),
        )
        ppl = evaluate_perplexity(
            quantized_model,
            tokenizer,
            seq_len=args.seq_len,
            max_samples=args.eval_samples,
        )
        print(f"Ternary PTQ perplexity: {ppl:.2f}")
        if args.prompt:
            # Greedy decoding for a reproducible smoke-test generation.
            text = generate_text(
                quantized_model,
                tokenizer,
                prompt=args.prompt,
                max_new_tokens=args.max_tokens,
                do_sample=False,
            )
            print(f"Prompt: {args.prompt}")
            print(f"Output: {text}")
def cmd_eval(args):
    """Evaluate perplexity of a saved quantized model."""
    from ternary_quant.eval import evaluate_perplexity
    from ternary_quant.inference import load_ternary_model

    runtime_mode = getattr(args, "runtime_mode", "packed")
    model, tokenizer = load_ternary_model(
        args.model_dir,
        device=args.device,
        runtime_mode=runtime_mode,
    )
    # Lower is better; evaluated on the eval module's default corpus.
    score = evaluate_perplexity(
        model,
        tokenizer,
        seq_len=args.seq_len,
        max_samples=args.max_samples,
    )
    print(f"\nPerplexity: {score:.2f}")
def cmd_compare(args):
    """Compare original and saved quantized model."""
    from ternary_quant.eval import compare_models

    # compare_models prints its own side-by-side report; nothing to return.
    compare_models(
        original_model_name=args.original,
        ternary_model_dir=args.ternary,
        device=args.device,
        seq_len=args.seq_len,
        max_samples=args.max_samples,
    )
def cmd_generate(args):
    """Generate text with a saved quantized model."""
    import numpy as np

    from ternary_quant.generative_adapters import inspect_generative_model
    from ternary_quant.inference import (
        generate_generative_output,
        generate_text,
        load_ternary_model,
    )

    model, asset = load_ternary_model(
        args.model_dir,
        device=args.device,
        runtime_mode=getattr(args, "runtime_mode", "packed"),
    )
    info = inspect_generative_model(
        model,
        model_name=str(getattr(model, "name_or_path", "loaded-model")),
    )

    # Optionally decode an image for vision-language models.
    image = None
    if args.image_path:
        try:
            from PIL import Image
        except Exception as exc:
            raise RuntimeError(
                "Reading --image-path requires Pillow. Install pillow or omit the image."
            ) from exc
        image = np.array(Image.open(args.image_path).convert("RGB"))

    # Dispatch on the detected model family: multimodal vs plain text.
    if info.model_family == "image_text_to_text":
        result = generate_generative_output(
            model,
            asset,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            image=image,
        )
    else:
        # temperature <= 0 means greedy decoding.
        result = generate_text(
            model,
            asset,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            do_sample=args.temperature > 0,
        )
    print(f"\nPrompt: {args.prompt}")
    print(f"Output: {result}")
def cmd_inspect_generative(args):
    """Inspect a generative model and list its quantizable components."""
    from ternary_quant.generative_adapters import (
        generative_model_info_to_dict,
        load_generative_model,
    )

    device = _resolve_device(args.device)
    model, _, info = load_generative_model(
        args.model,
        device=device,
        dtype=_parse_dtype(args.dtype),
    )

    print(f"Model: {info.model_name}")
    print(f"Family: {info.model_family}")
    print(f"Model type: {info.model_type}")
    print(f"Architectures: {', '.join(info.architectures) or 'unknown'}")
    print(f"Default broad components: {', '.join(info.default_quantization_components)}")
    print("\nComponents:")
    for comp in info.components:
        sample = ", ".join(comp.sample_linear_like_names[:4]) or "(no linear modules)"
        print(
            f"  {comp.name:<22} path={comp.path:<32} "
            f"linears={comp.linear_like_count:<4} params={comp.parameter_count:<12}"
        )
        print(f"    sample: {sample}")

    # Optionally persist the inventory as JSON for downstream tooling.
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump(generative_model_info_to_dict(info), f, indent=2)
        print(f"\nWrote component inventory to {out_path}")

    # Free the model before returning; this command only reports metadata.
    del model
    _cleanup_device(device)
def cmd_quantize_broad(args):
    """Quantize selected components of a broad generative model.

    Loads the model, builds calibration batches (text prompts plus a demo
    image for multimodal families), quantizes the chosen components in place,
    saves the artifact, optionally validates it on a prompt bank, and writes
    a JSON report into the output directory.
    """
    from ternary_quant.generative_adapters import (
        BroadQuantizationConfig,
        broad_quant_config_to_dict,
        build_calibration_batches,
        evaluate_broad_prompt_bank,
        generative_model_info_to_dict,
        load_generative_model,
        make_demo_image,
        quantize_components_inplace,
    )
    from ternary_quant.inference import load_ternary_model
    from ternary_quant.storage import save_quantized_model

    device = _resolve_device(args.device)
    dtype = _parse_dtype(args.dtype)
    model, asset, model_info = load_generative_model(
        args.model,
        device=device,
        dtype=dtype,
    )
    # Fall back to the family-specific default component preset.
    components = (
        args.components if args.components else model_info.default_quantization_components
    )
    prompts = [args.prompt] if args.prompt else None
    broad_config = BroadQuantizationConfig(
        components=list(components),
        scheme=args.scheme,
        group_size=args.group_size,
        n_iter=args.n_iter,
        salient_fraction=args.salient_fraction,
        rescue_fraction=args.rescue_fraction,
        # tritplane3 carries a third trit plane; every other scheme uses two.
        n_planes=3 if args.scheme == "tritplane3" else 2,
        allow_all_linear=args.allow_all_linear,
        max_length=args.seq_len,
        calibration_batch_size=args.calibration_batch_size,
        calibration_prompts=list(prompts) if prompts is not None else None,
    )
    demo_image = make_demo_image()
    calibration_batches = build_calibration_batches(
        asset,
        model_info,
        max_length=args.seq_len,
        batch_size=args.calibration_batch_size,
        prompts=prompts,
        demo_images=[demo_image],
    )
    result = quantize_components_inplace(
        model,
        model_info=model_info,
        calibration_batches=calibration_batches,
        config=broad_config,
    )
    save_quantized_model(
        ternary_params=result.quantized_params,
        model_name=args.model,
        model_config=model.config,
        quant_config=broad_config,
        output_dir=args.output,
        stats=result.stats,
        summary=result.summary,
        method_name=result.summary["method_name"],
        model_family=model_info.model_family,
    )
    report = {
        "method": result.summary["method_name"],
        "model": args.model,
        "model_info": generative_model_info_to_dict(model_info),
        "config": broad_quant_config_to_dict(broad_config),
        "summary": result.summary,
    }
    if args.eval:
        quantized_model, quantized_asset = load_ternary_model(
            args.output,
            device=device,
            runtime_mode=getattr(args, "runtime_mode", "packed"),
        )
        # Use the CLI prompt when given; otherwise fall back to a small
        # family-appropriate prompt bank.  (The previous nested fallback that
        # re-checked args.prompt here was unreachable: `prompts` is non-None
        # exactly when args.prompt is set.)
        eval_prompts = prompts
        if eval_prompts is None:
            if model_info.model_family == "image_text_to_text":
                eval_prompts = ["Describe the image in one short sentence."]
            else:
                eval_prompts = [
                    "The capital of France is",
                    "Answer briefly: What is 2 + 2?",
                ]
        validation = evaluate_broad_prompt_bank(
            quantized_model,
            quantized_asset,
            model_info,
            prompts=eval_prompts,
            max_new_tokens=args.max_tokens,
            demo_image=demo_image,
        )
        report["validation"] = validation
        print("\nValidation:")
        print(f"  Avg collapse: {validation['avg_collapse_score']:.3f}")
        print(f"  Primary output: {validation['primary_text']}")
        del quantized_model
        _cleanup_device(device)
    report_path = Path(args.output) / "broad_generative_report.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)
    print(f"Wrote broad generative report to {report_path}")
    # Print a compact summary
    s = result.summary
    print("\nQuantization summary:")
    print(f"  Layers quantized: {s['quantized_modules']}")
    print(f"  Full-model effective bits: {s['full_model_effective_bits']:.2f}")
    print(f"  Compression ratio: {s['compression_ratio']:.2f}×")
    print(f"  Avg reconstruction error: {s['avg_relative_error']:.4f}")
    if getattr(args, "push_to_hub", None):
        _push_to_hub(args.output, args.push_to_hub, args.model, result.summary, broad_config)
| def _push_to_hub(output_dir: str, hub_repo: str, source_model: str, summary: dict, config) -> None: | |
| """Push a quantized model directory to HuggingFace Hub.""" | |
| try: | |
| from huggingface_hub import HfApi | |
| except ImportError: | |
| print("huggingface_hub not installed. Run: pip install huggingface_hub") | |
| return | |
| output_path = Path(output_dir) | |
| # Write a model card | |
| model_card = f"""--- | |
| tags: | |
| - ternary-quant | |
| - quantization | |
| - ternary | |
| base_model: {source_model} | |
| --- | |
| # {hub_repo} | |
| Ternary-quantized version of [{source_model}](https://huggingface.co/{source_model}) | |
| produced with [ternary-quant](https://github.com/Asad-Ismail/ternary-quant). | |
| ## Quantization details | |
| - **Scheme**: {getattr(config, 'scheme', 'unknown')} | |
| - **Components**: {', '.join(getattr(config, 'components', []))} | |
| - **Full-model effective bits**: {summary.get('full_model_effective_bits', '?'):.2f} | |
| - **Compression ratio**: {summary.get('compression_ratio', '?'):.2f}× | |
| - **Avg reconstruction error**: {summary.get('avg_relative_error', '?'):.4f} | |
| ## Usage | |
| ```python | |
| from ternary_quant.inference import load_ternary_model | |
| model, tokenizer = load_ternary_model("{hub_repo}", runtime_mode="cached") | |
| inputs = tokenizer("Hello, world!", return_tensors="pt") | |
| outputs = model.generate(**inputs, max_new_tokens=50) | |
| print(tokenizer.decode(outputs[0], skip_special_tokens=True)) | |
| ``` | |
| Or via CLI: | |
| ```bash | |
| pip install ternary-quant | |
| ternary-quant generate {hub_repo} --prompt "Hello" --runtime-mode cached | |
| ``` | |
| """ | |
| card_path = output_path / "README.md" | |
| card_path.write_text(model_card) | |
| api = HfApi() | |
| print(f"Pushing to {hub_repo}...") | |
| api.upload_folder( | |
| folder_path=str(output_path), | |
| repo_id=hub_repo, | |
| repo_type="model", | |
| ) | |
| print(f"Pushed to https://huggingface.co/{hub_repo}") | |
def cmd_check(args):
    """Quick compatibility check using only the model config (no weights downloaded).

    Fetches just the HF config, reports the detected family and default
    quantization components, and prints a ready-to-paste quantize-broad
    command tailored to the model's architecture class.
    """
    from ternary_quant.generative_adapters import (
        VLM_MODEL_TYPES,
        _default_components_for_family,
        detect_model_family_from_config,
    )
    from transformers import AutoConfig

    print(f"Checking: {args.model}")
    try:
        config = AutoConfig.from_pretrained(args.model)
    except Exception as exc:
        # Broad catch is deliberate: config fetch can fail for many reasons
        # (gated repo, typo, network) and this command is purely advisory.
        print(f"  Could not load config: {exc}")
        print("  → Model may be gated (requires HF token) or not found.")
        return
    model_type = getattr(config, "model_type", "unknown")
    architectures = list(getattr(config, "architectures", None) or [])
    family = detect_model_family_from_config(config)
    default_components = _default_components_for_family(family)
    print(f"  model_type: {model_type}")
    print(f"  architectures: {', '.join(architectures) or 'unknown'}")
    print(f"  family: {family}")
    print(f"  default components to quantize: {', '.join(default_components)}")
    # Pick the suggested command based on coarse architecture class.
    is_vlm = model_type in VLM_MODEL_TYPES
    has_encoder_decoder = bool(getattr(config, "is_encoder_decoder", False))
    if is_vlm:
        print("  → VLM: quantize text_backbone + multimodal_connector")
        print(f"     ternary-quant quantize-broad {args.model} \\")
        print(f"       --output ./$(basename {args.model})-ternary \\")
        print(f"       --components text_backbone multimodal_connector \\")
        print(f"       --scheme tritplane3 --dtype float16")
    elif has_encoder_decoder:
        print("  → Seq2seq / audio: quantize decoder")
        print(f"     ternary-quant quantize-broad {args.model} \\")
        print(f"       --output ./$(basename {args.model})-ternary \\")
        print(f"       --components decoder --scheme tritplane3")
    else:
        print("  → Causal LM: quantize text_backbone")
        print(f"     ternary-quant quantize-broad {args.model} \\")
        print(f"       --output ./$(basename {args.model})-ternary \\")
        print(f"       --components text_backbone --scheme tritplane3")
    print()
    print("  If quantization fails with 'No quantizable linear modules',")
    print("  add --allow-all-linear to quantize all nn.Linear layers.")
def cmd_info(args):
    """Show info about a saved quantized model.

    Reads ``metadata.json`` from ``args.model_dir`` and pretty-prints the
    model identity, sizes, quantization config, optional summary metrics,
    and optional quantization plan. Exits with status 1 when the directory
    holds no quantized model.
    """
    model_dir = Path(args.model_dir)
    meta_path = model_dir / "metadata.json"
    if not meta_path.exists():
        print(f"No quantized model found at {model_dir}")
        sys.exit(1)
    with open(meta_path) as f:
        metadata = json.load(f)
    print(f"Model: {metadata['model_name']}")
    print(f"Model family: {metadata.get('model_family', 'causal_lm')}")
    print(f"Method: {metadata.get('method_name', 'unknown')}")
    print(f"Format family: {metadata.get('format_family', 'legacy')}")
    print(f"Format version: {metadata['format_version']}")
    print(f"Layers quantized: {len(metadata['layer_info'])}")
    print(f"Packed size: {metadata['total_packed_bytes'] / 1e6:.1f} MB")
    print(f"FP16 size: {metadata['total_fp16_bytes'] / 1e6:.1f} MB")
    print(f"Compression: {metadata['compression_ratio']:.1f}x")
    qc = metadata["quant_config"]
    print("\nQuantization config:")
    for key, value in qc.items():
        # base_config is a nested dict; too noisy for this summary view.
        if key == "base_config":
            continue
        print(f"  {key}: {value}")
    if metadata.get("summary"):
        summary = metadata["summary"]
        print("\nSummary:")
        for key in [
            "quantized_fraction",
            "avg_relative_error",
            "avg_effective_bits",
            "full_model_effective_bits",
            "total_sparse_nnz",
        ]:
            if key in summary:
                value = summary[key]
                if isinstance(value, float):
                    # Fractions read better as percentages.
                    if "fraction" in key:
                        print(f"  {key}: {value:.1%}")
                    else:
                        print(f"  {key}: {value:.4f}")
                else:
                    print(f"  {key}: {value}")
    if metadata.get("plan"):
        plan = metadata["plan"]
        print("\nPlan:")
        print(f"  Method: {plan.get('method_name', 'unknown')}")
        print(f"  Target average bits: {plan.get('target_average_bits')}")
        # Guard: older metadata may omit this key; formatting None with
        # ':.2f' raised TypeError in the previous implementation.
        predicted = plan.get("predicted_average_bits")
        if predicted is not None:
            print(f"  Predicted average bits: {predicted:.2f}")
        else:
            print("  Predicted average bits: unknown")
| def _parse_dtype(s: str) -> torch.dtype: | |
| return { | |
| "float16": torch.float16, | |
| "bfloat16": torch.bfloat16, | |
| "float32": torch.float32, | |
| }[s] | |
| def _resolve_device(device: str) -> str: | |
| if device != "auto": | |
| return device | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | |
| return "mps" | |
| return "cpu" | |
| def _add_runtime_mode_arg(parser: argparse.ArgumentParser, *, default: str = "cached") -> None: | |
| parser.add_argument( | |
| "--runtime-mode", | |
| default=default, | |
| choices=["packed", "cached", "native", "metal", "triton", "gemlite"], | |
| help=( | |
| "Inference runtime path for saved quantized layers. " | |
| "'cached': dequantize once at load, fastest on GPU/CPU (recommended). " | |
| "'native': replace layers with nn.Linear, ~1.0× vs FP16. " | |
| "'packed': re-dequantize every forward, minimal live VRAM. " | |
| "'gemlite': NVIDIA GPU only — keeps weights 2-bit packed, good batch throughput. " | |
| "'triton': NVIDIA GPU only — custom Triton kernel, slightly faster than gemlite at batch=1. " | |
| "'metal': Apple Silicon adaptive — Metal kernel with cached fallback." | |
| ), | |
| ) | |
| def _cleanup_device(device: str) -> None: | |
| gc.collect() | |
| if device == "cuda": | |
| torch.cuda.empty_cache() | |
| if device == "mps": | |
| torch.mps.empty_cache() | |
| def _uniform_role_weights() -> dict[str, float]: | |
| return { | |
| "attention_inputs": 1.0, | |
| "attention_output": 1.0, | |
| "mlp_inputs": 1.0, | |
| "mlp_output": 1.0, | |
| } | |
def main():
    """CLI entry point: build the argument parser and dispatch the subcommand.

    Each subcommand is wired to a ``cmd_*`` handler via ``set_defaults(func=...)``;
    the parsed namespace is passed straight to that handler.
    """
    from ternary_quant.ptq_families import FAMILY_PRESETS, get_default_family_candidates

    parser = argparse.ArgumentParser(
        prog="ternary-quant",
        description="Post-training ternary quantization for HuggingFace generative models",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # --- catalog: list the repo's known-good / probe model entries ---
    p_catalog = subparsers.add_parser(
        "catalog",
        help="List validated, probe-only, and special-handling model entries",
    )
    p_catalog.add_argument(
        "--status",
        default="all",
        choices=[
            "all",
            "validated",
            "component_validated",
            "research_validated",
            "probe_only",
            "special_handling",
        ],
    )
    p_catalog.add_argument(
        "--family",
        default="all",
        choices=["all", "causal_lm", "seq2seq_lm", "image_text_to_text"],
    )
    p_catalog.add_argument("--show-commands", action="store_true")
    p_catalog.add_argument("--json", action="store_true")
    p_catalog.add_argument("--output", default=None)
    p_catalog.set_defaults(func=cmd_catalog)

    # --- doctor: environment readiness report ---
    p_doctor = subparsers.add_parser(
        "doctor",
        help="Check environment readiness and runtime recommendations",
    )
    p_doctor.add_argument("--json", action="store_true")
    p_doctor.add_argument("--output", default=None)
    p_doctor.set_defaults(func=cmd_doctor)

    # --- quantize: standard ternary PTQ for causal LMs ---
    p_quant = subparsers.add_parser("quantize", help="Quantize a model to ternary")
    p_quant.add_argument("model", help="HuggingFace model ID or local path")
    p_quant.add_argument("--output", "-o", required=True, help="Output directory")
    p_quant.add_argument("--device", default="auto", help="Device (auto/cuda/cpu/mps)")
    p_quant.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "bfloat16", "float32"],
    )
    p_quant.add_argument("--n-iter", type=int, default=10, help="ITF iterations")
    p_quant.add_argument(
        "--no-activation-aware",
        action="store_true",
        help="Disable activation-aware quantization",
    )
    p_quant.add_argument("--block-size", type=int, default=0, help="Column block size")
    # Calibration-dataset options.
    p_quant.add_argument("--n-samples", type=int, default=128)
    p_quant.add_argument("--seq-len", type=int, default=2048)
    p_quant.add_argument("--dataset", default="wikitext")
    p_quant.add_argument("--dataset-config", default="wikitext-2-raw-v1")
    p_quant.add_argument("--seed", type=int, default=42)
    p_quant.add_argument("--skip-modules", nargs="+", default=None)
    p_quant.add_argument("--eval", action="store_true")
    p_quant.add_argument("--eval-samples", type=int, default=40)
    _add_runtime_mode_arg(p_quant)
    p_quant.set_defaults(func=cmd_quantize)

    # --- quantize-small: role-aware recipe for small models ---
    p_small = subparsers.add_parser(
        "quantize-small",
        help="Role-aware sparse asymmetric ternarization for small models",
    )
    p_small.add_argument("model", help="HuggingFace model ID or local path")
    p_small.add_argument("--output", "-o", required=True, help="Output directory")
    p_small.add_argument("--device", default="auto")
    p_small.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "bfloat16", "float32"],
    )
    # Calibration-dataset options (smaller defaults than `quantize`).
    p_small.add_argument("--n-samples", type=int, default=16)
    p_small.add_argument("--seq-len", type=int, default=256)
    p_small.add_argument("--dataset", default="wikitext")
    p_small.add_argument("--dataset-config", default="wikitext-2-raw-v1")
    p_small.add_argument("--seed", type=int, default=42)
    p_small.add_argument("--group-size", type=int, default=32)
    p_small.add_argument("--n-iter", type=int, default=10)
    p_small.add_argument(
        "--planner",
        default="practical",
        choices=["practical", "budgeted", "sensitivity_budget", "auto", "collapse_auto"],
        help=(
            "Planner variant: fixed role-aware recipe, role-aware bit-budgeted recipe, "
            "sensitivity-only matched-bit baseline, held-out PPL selection, or "
            "held-out prompt-bank collapse-aware selection."
        ),
    )
    # Salient-weight handling.
    p_small.add_argument("--salient-fraction", type=float, default=0.01)
    p_small.add_argument("--min-salient-fraction", type=float, default=0.0025)
    p_small.add_argument("--max-salient-fraction", type=float, default=0.01)
    # Optional low-rank residual correction.
    p_small.add_argument(
        "--low-rank-rank",
        type=int,
        default=0,
        help="Optional per-module low-rank residual rank for quantized modules.",
    )
    p_small.add_argument(
        "--adaptive-low-rank",
        action="store_true",
        help="Allocate low-rank rank adaptively per module using residual spectra.",
    )
    p_small.add_argument(
        "--low-rank-chunk-rank",
        type=int,
        default=16,
        help="Rank chunk used by adaptive low-rank allocation.",
    )
    p_small.add_argument(
        "--low-rank-target-average-bits",
        type=float,
        default=None,
        help="Optional full-model bit target for adaptive low-rank allocation.",
    )
    p_small.add_argument(
        "--low-rank-fit-mode",
        default="activation_regression",
        choices=["weight_svd", "activation_regression"],
        help="How to fit optional low-rank residuals for quantized modules.",
    )
    p_small.add_argument(
        "--low-rank-ridge",
        type=float,
        default=1e-4,
        help="Ridge penalty used for activation-regressed low-rank fitting.",
    )
    p_small.add_argument(
        "--low-rank-max-samples",
        type=int,
        default=4096,
        help="Maximum captured tokens per module when fitting low-rank residuals.",
    )
    # Optional calibration-time fine-tuning / distillation.
    p_small.add_argument(
        "--calibration-tune-steps",
        type=int,
        default=0,
        help="Optional number of calibration-only LM fine-tune steps for low-rank residuals.",
    )
    p_small.add_argument(
        "--calibration-tune-lr",
        type=float,
        default=5e-5,
        help="Learning rate for optional low-rank calibration tuning.",
    )
    p_small.add_argument(
        "--calibration-tune-batch-size",
        type=int,
        default=2,
        help="Batch size for optional low-rank calibration tuning.",
    )
    p_small.add_argument(
        "--behavior-tune-weight",
        type=float,
        default=0.0,
        help=(
            "Optional weight for prompt-bank teacher-sequence tuning during low-rank "
            "calibration. Requires --calibration-tune-steps > 0."
        ),
    )
    p_small.add_argument(
        "--behavior-tune-prompt-count",
        type=int,
        default=4,
        help="Number of prompts to use when building behavior-tuning teacher sequences.",
    )
    p_small.add_argument(
        "--behavior-tune-max-tokens",
        type=int,
        default=48,
        help="Max generated tokens per prompt when building behavior-tuning teacher sequences.",
    )
    p_small.add_argument(
        "--distill-weight",
        type=float,
        default=0.0,
        help="Optional teacher hidden-state distillation weight for calibration tuning.",
    )
    p_small.add_argument(
        "--behavior-hidden-weight",
        type=float,
        default=0.0,
        help="Optional teacher hidden-state distillation weight on prompt-bank sequences.",
    )
    p_small.add_argument(
        "--logit-distill-weight",
        type=float,
        default=0.0,
        help="Optional top-k teacher logit distillation weight for calibration tuning.",
    )
    p_small.add_argument(
        "--behavior-logit-weight",
        type=float,
        default=0.0,
        help="Optional top-k teacher logit distillation weight on prompt-bank sequences.",
    )
    p_small.add_argument(
        "--entropy-distill-weight",
        type=float,
        default=0.0,
        help="Optional teacher entropy-floor regularization weight for calibration tuning.",
    )
    p_small.add_argument(
        "--behavior-entropy-weight",
        type=float,
        default=0.0,
        help="Optional teacher entropy-floor regularization weight on prompt-bank sequences.",
    )
    p_small.add_argument(
        "--logit-distill-topk",
        type=int,
        default=32,
        help="Teacher top-k to cache for logit distillation.",
    )
    p_small.add_argument(
        "--logit-distill-temperature",
        type=float,
        default=2.0,
        help="Temperature for top-k teacher logit distillation.",
    )
    p_small.add_argument(
        "--importance-threshold-scale",
        type=float,
        default=0.0,
        help=(
            "AWQ-inspired per-channel importance thresholding. When > 0 and activations "
            "are used, input channels with high activation magnitude get a lower ternary "
            "threshold (fewer zeros = more signal preserved). 0.0 = uniform (default). "
            "Typical range: 0.25–0.5."
        ),
    )
    p_small.add_argument("--adaptive-salient", action="store_true")
    p_small.add_argument("--boundary-layers", type=int, default=2)
    p_small.add_argument("--calibration-batch-size", type=int, default=4)
    p_small.add_argument("--quantize-attention-output", action="store_true")
    p_small.add_argument("--quantize-mlp-output", action="store_true")
    p_small.add_argument(
        "--target-average-bits",
        type=float,
        default=None,
        help="Optional full-model bit budget for the role-aware allocator.",
    )
    # Evaluation / selection options.
    p_small.add_argument("--eval", action="store_true")
    p_small.add_argument("--eval-samples", type=int, default=8)
    p_small.add_argument("--selection-eval-samples", type=int, default=2)
    p_small.add_argument("--selection-prompt-count", type=int, default=4)
    p_small.add_argument("--selection-max-tokens", type=int, default=48)
    p_small.add_argument("--selection-collapse-weight", type=float, default=2.0)
    p_small.add_argument("--selection-worst-weight", type=float, default=1.0)
    p_small.add_argument("--prompt", default=None)
    p_small.add_argument("--max-tokens", type=int, default=80)
    _add_runtime_mode_arg(p_small)
    p_small.set_defaults(func=cmd_quantize_small)

    # --- quantize-ptq: family-preset PTQ with optional controller selection ---
    p_ptq = subparsers.add_parser(
        "quantize-ptq",
        help="Compare or apply broader ternary PTQ families for small models",
    )
    p_ptq.add_argument("model", help="HuggingFace model ID or local path")
    p_ptq.add_argument("--output", "-o", required=True, help="Output directory")
    p_ptq.add_argument("--device", default="auto")
    p_ptq.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "bfloat16", "float32"],
    )
    p_ptq.add_argument("--n-samples", type=int, default=16)
    p_ptq.add_argument("--seq-len", type=int, default=256)
    p_ptq.add_argument("--dataset", default="wikitext")
    p_ptq.add_argument("--dataset-config", default="wikitext-2-raw-v1")
    p_ptq.add_argument("--seed", type=int, default=42)
    p_ptq.add_argument("--group-size", type=int, default=32)
    p_ptq.add_argument("--n-iter", type=int, default=10)
    p_ptq.add_argument(
        "--family",
        default="controller",
        # "controller" selects the best preset; otherwise a fixed preset name.
        choices=["controller", *sorted(FAMILY_PRESETS)],
        help="PTQ family preset to apply, or controller to select across families.",
    )
    p_ptq.add_argument(
        "--candidate-families",
        nargs="*",
        default=list(get_default_family_candidates()),
        help="Candidate families considered by the controller.",
    )
    p_ptq.add_argument("--boundary-layers", type=int, default=2)
    p_ptq.add_argument("--calibration-batch-size", type=int, default=4)
    p_ptq.add_argument("--quantize-attention-output", action="store_true")
    p_ptq.add_argument("--quantize-mlp-output", action="store_true")
    p_ptq.add_argument(
        "--target-average-bits",
        type=float,
        default=None,
        help="Optional full-model bit target used by budget-aware family presets and selection.",
    )
    p_ptq.add_argument(
        "--selection-metric",
        default="ppl",
        choices=["ppl", "collapse"],
        help="Controller selection objective.",
    )
    p_ptq.add_argument("--selection-eval-samples", type=int, default=2)
    p_ptq.add_argument("--selection-prompt-count", type=int, default=4)
    p_ptq.add_argument("--selection-max-tokens", type=int, default=48)
    p_ptq.add_argument("--selection-collapse-weight", type=float, default=2.0)
    p_ptq.add_argument("--selection-worst-weight", type=float, default=1.0)
    p_ptq.add_argument("--selection-bits-weight", type=float, default=0.25)
    p_ptq.add_argument("--eval", action="store_true")
    p_ptq.add_argument("--eval-samples", type=int, default=8)
    p_ptq.add_argument("--prompt", default=None)
    p_ptq.add_argument("--max-tokens", type=int, default=80)
    _add_runtime_mode_arg(p_ptq)
    p_ptq.set_defaults(func=cmd_quantize_ptq)

    # --- quantize-broad: component-level quantization of broad families ---
    p_broad = subparsers.add_parser(
        "quantize-broad",
        help="Quantize selected components of a broader generative model family",
    )
    p_broad.add_argument("model", help="HuggingFace model ID or local path")
    p_broad.add_argument("--output", "-o", required=True, help="Output directory")
    p_broad.add_argument("--device", default="auto")
    p_broad.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "bfloat16", "float32"],
    )
    p_broad.add_argument(
        "--components",
        nargs="*",
        default=None,
        help="Component names to quantize. Defaults to the family-specific broad preset.",
    )
    p_broad.add_argument(
        "--scheme",
        default="groupwise",
        choices=["groupwise", "tritplane2", "tritplane3"],
        help="Broad quantization scheme.",
    )
    p_broad.add_argument("--group-size", type=int, default=32)
    p_broad.add_argument("--n-iter", type=int, default=10)
    p_broad.add_argument("--salient-fraction", type=float, default=0.0)
    p_broad.add_argument("--rescue-fraction", type=float, default=0.0)
    p_broad.add_argument("--allow-all-linear", action="store_true")
    p_broad.add_argument("--seq-len", type=int, default=160)
    p_broad.add_argument("--calibration-batch-size", type=int, default=2)
    p_broad.add_argument("--prompt", default=None)
    p_broad.add_argument("--max-tokens", type=int, default=64)
    p_broad.add_argument("--eval", action="store_true")
    p_broad.add_argument(
        "--push-to-hub",
        default=None,
        metavar="REPO_ID",
        help="Push the quantized model to HuggingFace Hub (e.g. username/my-model-ternary).",
    )
    _add_runtime_mode_arg(p_broad)
    p_broad.set_defaults(
        func=cmd_quantize_broad,
        # Default trit-plane count; cmd_quantize_broad overrides to 3 for tritplane3.
        n_planes=2,
    )

    # --- inspect-generative: component inventory without quantizing ---
    p_inspect = subparsers.add_parser(
        "inspect-generative",
        help="Inspect the generative-family components of a model",
    )
    p_inspect.add_argument("model", help="HuggingFace model ID or local path")
    p_inspect.add_argument("--device", default="auto")
    p_inspect.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "bfloat16", "float32"],
    )
    p_inspect.add_argument(
        "--output",
        default=None,
        help="Optional JSON output path for the component inventory.",
    )
    p_inspect.set_defaults(func=cmd_inspect_generative)

    # --- check: config-only compatibility probe ---
    p_check = subparsers.add_parser(
        "check",
        help="Quick compatibility check for a model (no weights downloaded)",
    )
    p_check.add_argument("model", help="HuggingFace model ID")
    p_check.set_defaults(func=cmd_check)

    # --- eval / compare / generate / info: operate on saved artifacts ---
    p_eval = subparsers.add_parser("eval", help="Evaluate saved model perplexity")
    p_eval.add_argument("model_dir")
    p_eval.add_argument("--device", default="auto")
    p_eval.add_argument("--seq-len", type=int, default=2048)
    p_eval.add_argument("--max-samples", type=int, default=None)
    _add_runtime_mode_arg(p_eval)
    p_eval.set_defaults(func=cmd_eval)

    p_cmp = subparsers.add_parser("compare", help="Compare original vs quantized")
    p_cmp.add_argument("original")
    p_cmp.add_argument("ternary")
    p_cmp.add_argument("--device", default="auto")
    p_cmp.add_argument("--seq-len", type=int, default=2048)
    p_cmp.add_argument("--max-samples", type=int, default=40)
    p_cmp.set_defaults(func=cmd_compare)

    p_gen = subparsers.add_parser("generate", help="Generate text with saved model")
    p_gen.add_argument("model_dir")
    p_gen.add_argument("--prompt", "-p", required=True)
    p_gen.add_argument("--max-tokens", type=int, default=256)
    p_gen.add_argument("--temperature", type=float, default=0.7)
    p_gen.add_argument("--device", default="auto")
    p_gen.add_argument(
        "--image-path",
        default=None,
        help="Optional image path for image-text-to-text models. If omitted, a demo image is used.",
    )
    _add_runtime_mode_arg(p_gen)
    p_gen.set_defaults(func=cmd_generate)

    p_info = subparsers.add_parser("info", help="Show info about a saved model")
    p_info.add_argument("model_dir")
    p_info.set_defaults(func=cmd_info)

    # Dispatch to the handler registered for the chosen subcommand.
    args = parser.parse_args()
    args.func(args)
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()