# TorchAO-Quant — app.py
# Author: rahul7star · commit f26ebbe (verified): "Update app.py"
import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModel, AutoTokenizer
import tempfile
from huggingface_hub import HfApi, snapshot_download, list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
import os
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Float8WeightOnlyConfig,
Float8DynamicActivationFloat8WeightConfig,
GemliteUIntXWeightOnlyConfig,
)
# === Load Hugging Face token from environment ===
# Every Hub call below (snapshot_download, list_models, create_repo, upload_folder)
# authenticates with this token, so the app refuses to start without it.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("❌ Missing HF_TOKEN environment variable. Please set it before running the app.")
# === Quantization configuration maps ===
# Short suffix appended to auto-generated quantized-repo names (e.g. "model-ao-int8wo").
MAP_QUANT_TYPE_TO_NAME = {
    "Int4WeightOnly": "int4wo",
    "GemliteUIntXWeightOnly": "intxwo-gemlite",
    "Int8WeightOnly": "int8wo",
    "Int8DynamicActivationInt8Weight": "int8da8w8",
    "Float8WeightOnly": "float8wo",
    "Float8DynamicActivationFloat8Weight": "float8da8w8",
    "autoquant": "autoquant",
}
# UI choice -> TorchAO config class. "autoquant" is intentionally absent: it is
# passed to TorchAoConfig as the literal string instead of a config instance.
MAP_QUANT_TYPE_TO_CONFIG = {
    "Int4WeightOnly": Int4WeightOnlyConfig,
    "GemliteUIntXWeightOnly": GemliteUIntXWeightOnlyConfig,
    "Int8WeightOnly": Int8WeightOnlyConfig,
    "Int8DynamicActivationInt8Weight": Int8DynamicActivationInt8WeightConfig,
    "Float8WeightOnly": Float8WeightOnlyConfig,
    "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
}
# === Helper functions ===
def get_username():
    """Return the Hub username associated with HF_TOKEN, or "anonymous" on any failure."""
    try:
        return HfApi(token=HF_TOKEN).whoami()["name"]
    except Exception:
        # Bad/expired token, network failure, etc. — treated as "not logged in".
        return "anonymous"
def check_model_exists(username, quantization_type, group_size, model_name, quantized_model_name):
    """Check if a model exists in the user's Hugging Face repository.

    Returns a human-readable message when the target repo already exists or
    when the lookup fails; returns None when the name is free.
    """
    try:
        existing = {m.id for m in list_models(author=username, token=HF_TOKEN)}
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            base = model_name.split('/')[-1]
            suffix = MAP_QUANT_TYPE_TO_NAME[quantization_type]
            # Group-wise schemes encode the group size in the repo name.
            grouped = quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
            if grouped and group_size is not None:
                repo_name = f"{username}/{base}-ao-{suffix}-gs{group_size}"
            else:
                repo_name = f"{username}/{base}-ao-{suffix}"
        if repo_name in existing:
            return f"Model '{repo_name}' already exists in your repository."
        return None
    except Exception as e:
        return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, quantization_type, group_size):
    """Build the README for the quantized repo.

    Prepends a YAML front-matter header (base_model + tag) and a quantization
    summary; appends the original model's README when it can be fetched.
    """
    original_readme = ""
    try:
        card_dir = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=HF_TOKEN
        )
        readme_file = os.path.join(card_dir, "README.md")
        if os.path.exists(readme_file):
            with open(readme_file, "r", encoding="utf-8") as fh:
                original_readme = fh.read()
    except Exception:
        # Best effort: a missing/unreadable upstream README just means no appendix.
        original_readme = ""
    card = f"""---
base_model:
- {model_name}
tags:
- torchao-my-repo
---
# {model_name} (Quantized)
## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}
"""
    if original_readme:
        card += "\n\n# πŸ“„ Original Model Info\n\n" + original_readme
    return card
def quantize_model(model_name, quantization_type, group_size=128, progress=gr.Progress()):
    """Load `model_name` from the Hub with a TorchAO quantization config applied.

    Args:
        model_name: Hub repo id of the model to quantize.
        quantization_type: Key into MAP_QUANT_TYPE_TO_CONFIG, or "autoquant".
        group_size: Group size for the group-wise schemes (Int4 / Gemlite).
            Callers may pass None (blank UI textbox); we fall back to 128
            because those configs require an int and would otherwise raise.
        progress: Gradio progress tracker (Gradio injects a live one per call).

    Returns:
        The quantized transformers model (on CPU).
    """
    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")
    if group_size is None:
        # Bug fix: quantize_and_save passes None when the textbox is cleared,
        # which crashed Int4WeightOnlyConfig/GemliteUIntXWeightOnlyConfig below.
        group_size = 128
    if quantization_type == "GemliteUIntXWeightOnly":
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](group_size=group_size)
    elif quantization_type == "Int4WeightOnly":
        # CPU-specific int4 packing layout; imported lazily since it is only needed here.
        from torchao.dtypes import Int4CPULayout
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](group_size=group_size, layout=Int4CPULayout())
    elif quantization_type == "autoquant":
        # TorchAoConfig accepts the literal string "autoquant" as a mode selector.
        quant_config = "autoquant"
    else:
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()
    quantization_config = TorchAoConfig(quant_config)
    progress(0.10, desc="Quantizing model")
    # Quantization happens inside from_pretrained when quantization_config is set.
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=HF_TOKEN,
    )
    progress(0.45, desc="Quantization completed")
    return model
def save_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, progress=gr.Progress()):
    """Serialize the quantized model + tokenizer and push them to the Hub.

    Args:
        model: Quantized model returned by quantize_model().
        model_name: Original Hub repo id (source of tokenizer and model card).
        quantization_type: Key into MAP_QUANT_TYPE_TO_NAME.
        group_size: Group size used; included in the generated repo name for
            the group-wise schemes when not None.
        quantized_model_name: Optional custom repo name (without the username).
        public: If False the created repo is private.
        progress: Gradio progress tracker.

    Returns:
        An HTML snippet linking to the pushed repository.
    """
    username = get_username()
    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")
    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        tokenizer.save_pretrained(tmpdirname)
        # safe_serialization=False: TorchAO tensor subclasses are not
        # safetensors-compatible, so fall back to torch pickle format.
        model.save_pretrained(tmpdirname, safe_serialization=False)
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and (group_size is not None):
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
        progress(0.70, desc="Creating model card")
        model_card = create_model_card(model_name, quantization_type, group_size)
        # Bug fix: write as UTF-8 explicitly. The card contains emoji and the
        # upstream README (read as UTF-8); the platform default encoding
        # (e.g. cp1252) would raise UnicodeEncodeError here.
        with open(os.path.join(tmpdirname, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)
        api = HfApi(token=HF_TOKEN)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(folder_path=tmpdirname, repo_id=repo_name, repo_type="model")
    progress(1.00, desc="Done")
    repo_link = f"""
<div class="repo-link">
<h3>πŸ”— Repository Link</h3>
<p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a></p>
</div>
"""
    return f"<h1>πŸŽ‰ Quantization Completed</h1><br/>{repo_link}"
def quantize_and_save(model_name, quantization_type, group_size, quantized_model_name, public):
    """Gradio callback: validate inputs, quantize, push, and return an HTML status string."""
    username = get_username()
    if not username or username == "anonymous":
        return "<div class='error-box'><h3>❌ Authentication Error</h3><p>Invalid or missing HF_TOKEN.</p></div>"
    # The group size arrives as free text from the UI; non-numeric/blank -> None.
    try:
        parsed_gs = int(group_size) if group_size and group_size.strip() else None
    except ValueError:
        parsed_gs = None
    exists_message = check_model_exists(username, quantization_type, parsed_gs, model_name, quantized_model_name)
    if exists_message:
        return f"<div class='warning-box'><h3>⚠️ Model Already Exists</h3><p>{exists_message}</p></div>"
    try:
        quantized = quantize_model(model_name, quantization_type, parsed_gs)
        return save_model(quantized, model_name, quantization_type, parsed_gs, quantized_model_name, public)
    except Exception as e:
        return f"<div class='error-box'><h3>❌ Error</h3><p>{str(e)}</p></div>"
# === Gradio UI ===
# Single-page app: pick a Hub model, choose a TorchAO scheme, push the result.
with gr.Blocks() as demo:
    gr.Markdown("# πŸ€— TorchAO Quantizer (Token Mode) πŸ”₯")
    gr.Markdown("Uses your environment HF_TOKEN β€” no login required.")
    with gr.Row():
        # Searchable model picker backed by the HF Hub.
        model_name = HuggingfaceHubSearch(label="πŸ” Hub Model ID", placeholder="Search a model", search_type="model")
        quantization_type = gr.Dropdown(
            choices=list(MAP_QUANT_TYPE_TO_NAME.keys()), value="Int8WeightOnly", label="Quantization Type"
        )
    # Free-text; parsed to int (or None) inside quantize_and_save.
    group_size = gr.Textbox(label="Group Size (optional)", value="128")
    quantized_model_name = gr.Textbox(label="Custom Model Name", value="")
    public = gr.Checkbox(label="Make Public", value=True)
    # Receives the HTML status/result snippet returned by quantize_and_save.
    output_link = gr.Markdown()
    quantize_button = gr.Button("πŸš€ Quantize and Push")
    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
        outputs=output_link,
    )
demo.launch(share=True)