Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| from transformers import TorchAoConfig, AutoModel, AutoTokenizer | |
| import tempfile | |
| from huggingface_hub import HfApi, snapshot_download, list_models | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from packaging import version | |
| import os | |
| from torchao.quantization import ( | |
| Int4WeightOnlyConfig, | |
| Int8WeightOnlyConfig, | |
| Int8DynamicActivationInt8WeightConfig, | |
| Float8WeightOnlyConfig, | |
| Float8DynamicActivationFloat8WeightConfig, | |
| GemliteUIntXWeightOnlyConfig, | |
| ) | |
# === Load Hugging Face token from environment ===
# The token is read once at import time and reused by every Hub call in this
# module. Start-up is aborted when the variable is unset or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("β Missing HF_TOKEN environment variable. Please set it before running the app.")
# === Quantization configuration maps ===
# Quantization type (as shown in the UI) -> short tag used when building the
# default repository name ("<model>-ao-<tag>").
# Key order matters: the dropdown lists its choices in this order.
MAP_QUANT_TYPE_TO_NAME = {
    "Int4WeightOnly": "int4wo",
    "GemliteUIntXWeightOnly": "intxwo-gemlite",
    "Int8WeightOnly": "int8wo",
    "Int8DynamicActivationInt8Weight": "int8da8w8",
    "Float8WeightOnly": "float8wo",
    "Float8DynamicActivationFloat8Weight": "float8da8w8",
    "autoquant": "autoquant",
}
# Quantization type -> torchao config class instantiated by quantize_model().
# "autoquant" is intentionally absent: quantize_model() passes the plain
# string "autoquant" to TorchAoConfig instead of a config object.
MAP_QUANT_TYPE_TO_CONFIG = {
    "Int4WeightOnly": Int4WeightOnlyConfig,
    "GemliteUIntXWeightOnly": GemliteUIntXWeightOnlyConfig,
    "Int8WeightOnly": Int8WeightOnlyConfig,
    "Int8DynamicActivationInt8Weight": Int8DynamicActivationInt8WeightConfig,
    "Float8WeightOnly": Float8WeightOnlyConfig,
    "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
}
# === Helper functions ===
def get_username():
    """Return the Hub account name that HF_TOKEN authenticates as.

    Returns:
        The ``"name"`` field of the ``whoami`` response, or the sentinel
        string ``"anonymous"`` if anything goes wrong while asking the Hub.
    """
    try:
        return HfApi(token=HF_TOKEN).whoami()["name"]
    except Exception:
        # Best effort by design: callers compare the result to "anonymous"
        # rather than handling an exception.
        return "anonymous"
def check_model_exists(username, quantization_type, group_size, model_name, quantized_model_name):
    """Check if a model exists in the user's Hugging Face repository.

    The target repo id is ``<username>/<quantized_model_name>`` when a custom
    name is given. Otherwise it is derived from the last path component of
    ``model_name`` plus ``-ao-<tag>``, with a ``-gs<group_size>`` suffix for
    the two group-size-aware quantization types.

    Returns:
        ``None`` when the repo id is not among the user's models, a message
        saying it already exists when it is, or an
        ``"Error checking model existence: ..."`` message if any step raises.
    """
    try:
        # Fetch the user's repo ids first, before the target name is derived.
        owned_ids = {entry.id for entry in list_models(author=username, token=HF_TOKEN)}

        if quantized_model_name:
            target = f"{username}/{quantized_model_name}"
        else:
            base = model_name.split("/")[-1]
            target = f"{username}/{base}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
            uses_group_size = quantization_type in ("Int4WeightOnly", "GemliteUIntXWeightOnly")
            if uses_group_size and group_size is not None:
                target += f"-gs{group_size}"

        if target in owned_ids:
            return f"Model '{target}' already exists in your repository."
        return None
    except Exception as exc:
        return f"Error checking model existence: {exc}"
def create_model_card(model_name, quantization_type, group_size):
    """Build the README text for the quantized repository.

    The card starts with a YAML front-matter block naming ``model_name`` as
    the base model, followed by a short "Quantization Details" section. If
    the base model's README.md can be downloaded, it is appended under an
    "Original Model Info" heading; any failure to fetch it is ignored.

    Returns:
        The full model card as a single string.
    """
    upstream_readme = ""
    try:
        # Only README.md is requested, so no model weights are downloaded.
        snapshot_dir = snapshot_download(
            repo_id=model_name,
            allow_patterns=["README.md"],
            repo_type="model",
            token=HF_TOKEN,
        )
        candidate = os.path.join(snapshot_dir, "README.md")
        if os.path.exists(candidate):
            with open(candidate, "r", encoding="utf-8") as handle:
                upstream_readme = handle.read()
    except Exception:
        upstream_readme = ""

    # NOTE(review): the template is reproduced line for line from the copy
    # under review; blank lines may have been lost there - confirm.
    card_lines = [
        "---",
        "base_model:",
        f"- {model_name}",
        "tags:",
        "- torchao-my-repo",
        "---",
        f"# {model_name} (Quantized)",
        "## Quantization Details",
        f"- **Quantization Type**: {quantization_type}",
        f"- **Group Size**: {group_size}",
        "",  # yields the trailing newline of the template
    ]
    card = "\n".join(card_lines)

    if upstream_readme:
        card += "\n\n# π Original Model Info\n\n" + upstream_readme
    return card
def quantize_model(model_name, quantization_type, group_size=128, progress=gr.Progress()):
    """Load ``model_name`` on CPU with a torchao quantization config applied.

    Args:
        model_name: Hub model id passed to ``AutoModel.from_pretrained``.
        quantization_type: ``"autoquant"`` or a key of
            ``MAP_QUANT_TYPE_TO_CONFIG``.
        group_size: Forwarded only to the Int4 and Gemlite configs.
        progress: Gradio progress callback; called at 0, 0.10 and 0.45.

    Returns:
        The model object returned by ``AutoModel.from_pretrained``.

    Raises:
        KeyError: if ``quantization_type`` is neither ``"autoquant"`` nor a
            known config key.
    """
    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")

    if quantization_type == "autoquant":
        # No config class for autoquant: TorchAoConfig receives the string.
        quant_config = "autoquant"
    else:
        config_cls = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]
        if quantization_type == "GemliteUIntXWeightOnly":
            quant_config = config_cls(group_size=group_size)
        elif quantization_type == "Int4WeightOnly":
            # Imported lazily so the layout is only required for this type.
            from torchao.dtypes import Int4CPULayout

            quant_config = config_cls(group_size=group_size, layout=Int4CPULayout())
        else:
            quant_config = config_cls()

    quantization_config = TorchAoConfig(quant_config)

    progress(0.10, desc="Quantizing model")
    load_kwargs = dict(
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=HF_TOKEN,
    )
    quantized = AutoModel.from_pretrained(model_name, **load_kwargs)
    progress(0.45, desc="Quantization completed")
    return quantized
def save_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, progress=gr.Progress()):
    """Serialize the quantized model and push it to the user's Hub account.

    The tokenizer of ``model_name``, the model weights and a generated
    README.md are written to a temporary directory, which is then uploaded
    to ``<username>/<repo>``. The repo is created if it does not exist.

    Args:
        model: Quantized model exposing ``save_pretrained``.
        model_name: Hub id of the base model (tokenizer and model-card source).
        quantization_type: Key of ``MAP_QUANT_TYPE_TO_NAME``; used for the
            default repo name and the model card.
        group_size: Appended as ``-gs<group_size>`` to the default repo name
            for the Int4 and Gemlite types when not ``None``.
        quantized_model_name: Optional custom repo name; overrides the default.
        public: When false the repo is created as private.
        progress: Gradio progress callback; called from 0.50 up to 1.00.

    Returns:
        An HTML snippet linking to the pushed repository.
    """
    username = get_username()
    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")

    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        tokenizer.save_pretrained(tmpdirname)
        # safe_serialization=False is kept from the original; presumably the
        # torchao-quantized weights cannot be written as safetensors - confirm
        # against the Transformers torchao documentation.
        model.save_pretrained(tmpdirname, safe_serialization=False)

        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
            if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and (group_size is not None):
                repo_name += f"-gs{group_size}"

        progress(0.70, desc="Creating model card")
        model_card = create_model_card(model_name, quantization_type, group_size)
        # The card holds non-ASCII text and the upstream README is read as
        # UTF-8, so write it as UTF-8 explicitly instead of relying on the
        # platform's default encoding.
        with open(os.path.join(tmpdirname, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

        api = HfApi(token=HF_TOKEN)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(folder_path=tmpdirname, repo_id=repo_name, repo_type="model")

    progress(1.00, desc="Done")
    # Built from explicit lines so the emitted HTML does not depend on the
    # source indentation of a multi-line literal.
    repo_link = (
        "\n"
        '<div class="repo-link">\n'
        "<h3>π Repository Link</h3>\n"
        f'<p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a></p>\n'
        "</div>\n"
    )
    return f"<h1>π Quantization Completed</h1><br/>{repo_link}"
def quantize_and_save(model_name, quantization_type, group_size, quantized_model_name, public):
    """Gradio click handler: validate inputs, quantize the model and push it.

    Args:
        model_name: Hub id of the model to quantize.
        quantization_type: Selected quantization type.
        group_size: Text from the group-size box; blank or non-numeric text
            is treated as "no group size" (``None``).
        quantized_model_name: Optional custom repo name.
        public: Whether the created repo should be public.

    Returns:
        An HTML string: the success snippet from ``save_model``, or an
        error/warning box describing why nothing was pushed. Exceptions from
        quantization and upload are reported in the box, not raised.
    """
    username = get_username()
    if not username or username == "anonymous":
        return "<div class='error-box'><h3>β Authentication Error</h3><p>Invalid or missing HF_TOKEN.</p></div>"

    # Without this guard an empty selection surfaces later as an opaque
    # AttributeError/loader message.
    if not model_name:
        return "<div class='error-box'><h3>β Error</h3><p>Please select a model to quantize.</p></div>"

    if group_size and group_size.strip():
        try:
            group_size = int(group_size)
        except ValueError:
            group_size = None
    else:
        group_size = None

    exists_message = check_model_exists(username, quantization_type, group_size, model_name, quantized_model_name)
    if exists_message:
        # check_model_exists reports its own failures through the same return
        # channel; those are errors, not "already exists" warnings.
        if exists_message.startswith("Error checking model existence"):
            return f"<div class='error-box'><h3>β Error</h3><p>{exists_message}</p></div>"
        return f"<div class='warning-box'><h3>β οΈ Model Already Exists</h3><p>{exists_message}</p></div>"

    try:
        quantized_model = quantize_model(model_name, quantization_type, group_size)
        return save_model(quantized_model, model_name, quantization_type, group_size, quantized_model_name, public)
    except Exception as e:
        # Top-level UI boundary: show the failure to the user instead of
        # letting the handler crash.
        return f"<div class='error-box'><h3>β Error</h3><p>{str(e)}</p></div>"
| # === Gradio UI === | |
# Builds the Blocks app: a Hub model search box, the quantization options,
# and a button whose click runs quantize_and_save and renders the returned
# HTML in the output Markdown component.
# NOTE(review): indentation was lost in this copy, so which components sit
# inside gr.Row() cannot be determined here - confirm against the original.
# NOTE(review): several labels start with what look like mis-decoded emoji;
# they are left untouched because they are runtime strings.
| with gr.Blocks() as demo: | |
| gr.Markdown("# π€ TorchAO Quantizer (Token Mode) π₯") | |
| gr.Markdown("Uses your environment HF_TOKEN β no login required.") | |
| with gr.Row(): | |
| model_name = HuggingfaceHubSearch(label="π Hub Model ID", placeholder="Search a model", search_type="model") | |
# Dropdown choices follow the key order of MAP_QUANT_TYPE_TO_NAME.
| quantization_type = gr.Dropdown( | |
| choices=list(MAP_QUANT_TYPE_TO_NAME.keys()), value="Int8WeightOnly", label="Quantization Type" | |
| ) | |
# Group size is entered as text; quantize_and_save converts it to an int.
| group_size = gr.Textbox(label="Group Size (optional)", value="128") | |
| quantized_model_name = gr.Textbox(label="Custom Model Name", value="") | |
| public = gr.Checkbox(label="Make Public", value=True) | |
| output_link = gr.Markdown() | |
| quantize_button = gr.Button("π Quantize and Push") | |
# The inputs list is positional: it must match quantize_and_save's
# parameter order.
| quantize_button.click( | |
| fn=quantize_and_save, | |
| inputs=[model_name, quantization_type, group_size, quantized_model_name, public], | |
| outputs=output_link, | |
| ) | |
# Launches the app as a module-level statement, requesting a share link.
| demo.launch(share=True) | |