import os
import tempfile

import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModel, AutoTokenizer
from huggingface_hub import HfApi, snapshot_download, list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
)

# === Load Hugging Face token from environment ===
# Fail fast at startup: every Hub call below requires a valid token.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("❌ Missing HF_TOKEN environment variable. Please set it before running the app.")

# === Quantization configuration maps ===
# Short suffix used when auto-generating the quantized repo name.
MAP_QUANT_TYPE_TO_NAME = {
    "Int4WeightOnly": "int4wo",
    "GemliteUIntXWeightOnly": "intxwo-gemlite",
    "Int8WeightOnly": "int8wo",
    "Int8DynamicActivationInt8Weight": "int8da8w8",
    "Float8WeightOnly": "float8wo",
    "Float8DynamicActivationFloat8Weight": "float8da8w8",
    "autoquant": "autoquant",
}

# Maps the UI choice to the torchao config class ("autoquant" is handled
# specially in quantize_model and has no entry here).
MAP_QUANT_TYPE_TO_CONFIG = {
    "Int4WeightOnly": Int4WeightOnlyConfig,
    "GemliteUIntXWeightOnly": GemliteUIntXWeightOnlyConfig,
    "Int8WeightOnly": Int8WeightOnlyConfig,
    "Int8DynamicActivationInt8Weight": Int8DynamicActivationInt8WeightConfig,
    "Float8WeightOnly": Float8WeightOnlyConfig,
    "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
}


# === Helper functions ===
def get_username():
    """Return the Hub username for HF_TOKEN, or "anonymous" on any failure."""
    try:
        api = HfApi(token=HF_TOKEN)
        info = api.whoami()
        return info["name"]
    except Exception:
        # Treated as "not authenticated" by quantize_and_save.
        return "anonymous"


def _build_repo_name(username, quantization_type, group_size, model_name, quantized_model_name):
    """Build the target repo id, honoring a custom name when provided.

    Group size is only meaningful for grouped int4/uintx schemes, so it is
    appended to the auto-generated name only for those types.
    """
    if quantized_model_name:
        return f"{username}/{quantized_model_name}"
    base = model_name.split("/")[-1]
    suffix = MAP_QUANT_TYPE_TO_NAME[quantization_type]
    if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and group_size is not None:
        return f"{username}/{base}-ao-{suffix}-gs{group_size}"
    return f"{username}/{base}-ao-{suffix}"


def check_model_exists(username, quantization_type, group_size, model_name, quantized_model_name):
    """Check if a model exists in the user's Hugging Face repository.

    Returns a human-readable message when the target repo already exists or
    when the listing fails; returns None when the name is free.
    """
    try:
        models = list_models(author=username, token=HF_TOKEN)
        model_names = [model.id for model in models]
        repo_name = _build_repo_name(
            username, quantization_type, group_size, model_name, quantized_model_name
        )
        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        return None
    except Exception as e:
        return f"Error checking model existence: {str(e)}"


def create_model_card(model_name, quantization_type, group_size):
    """Build a README for the quantized repo.

    Prepends a YAML header plus a quantization summary; appends the source
    model's original README when it can be downloaded (best-effort).
    """
    try:
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=HF_TOKEN
        )
        readme_path = os.path.join(model_path, "README.md")
        original_readme = ""
        if os.path.exists(readme_path):
            with open(readme_path, "r", encoding="utf-8") as f:
                original_readme = f.read()
    except Exception:
        # Missing/unreadable README is not fatal; ship the card without it.
        original_readme = ""

    yaml_header = f"""---
base_model:
- {model_name}
tags:
- torchao-my-repo
---
# {model_name} (Quantized)

## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}
"""
    if original_readme:
        yaml_header += "\n\n# 📄 Original Model Info\n\n" + original_readme
    return yaml_header


def quantize_model(model_name, quantization_type, group_size=128, progress=gr.Progress()):
    """Load `model_name` from the Hub and quantize it with torchao on CPU.

    Returns the quantized model object (not yet pushed anywhere).
    """
    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")

    if quantization_type == "GemliteUIntXWeightOnly":
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](group_size=group_size)
    elif quantization_type == "Int4WeightOnly":
        # CPU-only deployment: int4 needs the CPU-specific tensor layout.
        from torchao.dtypes import Int4CPULayout

        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
            group_size=group_size, layout=Int4CPULayout()
        )
    elif quantization_type == "autoquant":
        # TorchAoConfig accepts the literal string "autoquant".
        quant_config = "autoquant"
    else:
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()

    quantization_config = TorchAoConfig(quant_config)
    progress(0.10, desc="Quantizing model")
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=HF_TOKEN,
    )
    progress(0.45, desc="Quantization completed")
    return model


def save_model(model, model_name, quantization_type, group_size=128,
               quantized_model_name=None, public=True, progress=gr.Progress()):
    """Save the quantized model + tokenizer + model card and push to the Hub.

    Returns a markdown success message linking to the new repo.
    """
    username = get_username()
    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")

    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        tokenizer.save_pretrained(tmpdirname)
        # torchao-quantized tensors are not safetensors-serializable,
        # hence safe_serialization=False.
        model.save_pretrained(tmpdirname, safe_serialization=False)

        repo_name = _build_repo_name(
            username, quantization_type, group_size, model_name, quantized_model_name
        )

        progress(0.70, desc="Creating model card")
        model_card = create_model_card(model_name, quantization_type, group_size)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)

        api = HfApi(token=HF_TOKEN)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(folder_path=tmpdirname, repo_id=repo_name, repo_type="model")
        progress(1.00, desc="Done")

    # NOTE(review): the original result strings were garbled in the source;
    # reconstructed here as plain markdown with the same headline/content.
    repo_link = f"[{repo_name}](https://huggingface.co/{repo_name})"
    return f"## 🎉 Quantization Completed\n\nModel pushed to: {repo_link}"


def quantize_and_save(model_name, quantization_type, group_size, quantized_model_name, public):
    """End-to-end pipeline driven by the UI button.

    Validates auth and the target repo name, then quantizes and pushes.
    Always returns a markdown string for the output panel.
    """
    username = get_username()
    if not username or username == "anonymous":
        return "## ❌ Authentication Error\n\nInvalid or missing HF_TOKEN."

    # group_size arrives as free text from the UI; blank or non-integer
    # input falls back to None (type-dependent default naming/config).
    if group_size and group_size.strip():
        try:
            group_size = int(group_size)
        except ValueError:
            group_size = None
    else:
        group_size = None

    exists_message = check_model_exists(
        username, quantization_type, group_size, model_name, quantized_model_name
    )
    if exists_message:
        return f"## ⚠️ Model Already Exists\n\n{exists_message}"

    try:
        quantized_model = quantize_model(model_name, quantization_type, group_size)
        return save_model(
            quantized_model, model_name, quantization_type, group_size, quantized_model_name, public
        )
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"## ❌ Error\n\n{str(e)}"


# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("# 🤗 TorchAO Quantizer (Token Mode) 🔥")
    gr.Markdown("Uses your environment HF_TOKEN — no login required.")

    with gr.Row():
        model_name = HuggingfaceHubSearch(
            label="🔍 Hub Model ID", placeholder="Search a model", search_type="model"
        )
        quantization_type = gr.Dropdown(
            choices=list(MAP_QUANT_TYPE_TO_NAME.keys()),
            value="Int8WeightOnly",
            label="Quantization Type",
        )

    group_size = gr.Textbox(label="Group Size (optional)", value="128")
    quantized_model_name = gr.Textbox(label="Custom Model Name", value="")
    public = gr.Checkbox(label="Make Public", value=True)
    output_link = gr.Markdown()
    quantize_button = gr.Button("🚀 Quantize and Push")

    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
        outputs=output_link,
    )

demo.launch(share=True)