# Model-Quantizer / app.py
# Last commit 62e6098 ("Update app.py") by rootlocalghost — 9.09 kB.
import os
import gc
import torch
import shutil
import uuid
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
# Architecture profiles for INT8 mode: each maps a UI label to tensor-key
# substrings that are protected from quantization (kept as bf16 instead).
# Matching is a plain substring test (`pattern in key`), not a true prefix,
# so e.g. "norm" protects every key that contains "norm" anywhere.
ARCH_PROFILES = {
    "FLUX / Generic Rectified Flow": [
        "norm",
        "ln_",
        "embed",
        "time_in",
        "vector_in",
        "guidance_in",
        "txt_in",
        "img_in",
    ],
    "Z-Image / DiT Core": [
        "t_embedder",
        "cap_embedder",
        "all_x_embedder",
        "all_final_layer",
        "rope_embedder",
        "embed_tokens",
        "norm",
        "ln_",
        "shared",
    ],
    "Stable Diffusion (SDXL/SD3)": [
        "time_embed",
        "label_emb",
        "norm",
        "ln_",
        "out.",
    ],
}
def _in_selected_component(repo_path, components):
    """Return True if any directory segment of *repo_path* is in *components*.

    Whole path segments are compared (not substrings), so selecting ``vae``
    does not also select a sibling folder such as ``tiny_vae/``.
    """
    directories = repo_path.split("/")[:-1]
    return any(part in components for part in directories)


def _quantize_int8_rowwise(weight):
    """Quantize a 2-D floating-point weight to symmetric per-row INT8.

    Returns ``(q, scale)``: ``q`` is int8 with the shape of *weight*, ``scale``
    is bf16 with shape ``(rows, 1)``, and ``q * scale`` approximates *weight*.

    The arithmetic runs in float32 (fp8 inputs are upcast before any math, and
    bf16/fp16 division is too coarse to round reliably to the nearest integer).
    The scale is rounded to bf16 *before* quantizing, so the stored scale is
    exactly the one the integers were computed against.
    """
    w = weight.to(torch.float32)
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    # Keep all-zero rows from dividing by zero.
    scale = scale.clamp(min=1e-8).to(torch.bfloat16)
    q = torch.round(w / scale.to(torch.float32)).clamp(-127, 127).to(torch.int8)
    return q, scale


def _convert_tensors(tensors, is_int8, target_dtype, exclude_patterns):
    """Return a new state dict converted for the requested precision.

    Non-floating-point tensors are always passed through unchanged.

    FP8/FP16/BF16 (``is_int8`` False): every floating-point tensor is cast
    to *target_dtype*.

    INT8: every 2-D ``<name>.weight`` whose key contains none of
    *exclude_patterns* is replaced by ``<name>.weight_int8`` (int8) and
    ``<name>.weight_scale`` (bf16, one scale per row). All other
    floating-point tensors are cast to bf16.
    """
    converted = {}
    for key, tensor in tensors.items():
        if not tensor.is_floating_point():
            converted[key] = tensor
        elif not is_int8:
            converted[key] = tensor.to(target_dtype)
        elif (
            key.endswith(".weight")
            and tensor.ndim == 2
            and not any(pattern in key for pattern in exclude_patterns)
        ):
            base = key[: -len(".weight")]
            q, scale = _quantize_int8_rowwise(tensor)
            converted[f"{base}.weight_int8"] = q
            converted[f"{base}.weight_scale"] = scale
        else:
            converted[key] = tensor.to(torch.bfloat16)
    return converted


def _convert_file(src_path, dst_path, is_int8, target_dtype, exclude_patterns):
    """Load *src_path*, convert its tensors, and save the result to *dst_path*.

    Kept separate so the large tensor dicts are released as soon as this
    returns or raises, instead of lingering in the caller's frame.
    """
    tensors = load_file(src_path)
    converted = _convert_tensors(tensors, is_int8, target_dtype, exclude_patterns)
    # Drop the source dict first so tensors that were replaced by converted
    # copies can be freed before serialization.
    del tensors
    save_file(converted, dst_path)


def convert_and_upload(token, source_repo, target_repo, precision, target_components, arch_profile):
    """Copy a Hub repo into *target_repo*, converting selected safetensors files.

    Generator used as a streaming Gradio handler: each yielded string is a
    status message for the log textbox.

    Args:
        token: Hugging Face token with write access (used for download and upload).
        source_repo: Repo id to read from, e.g. ``"org/model"``.
        target_repo: Repo id to write to; created as a public repo if missing.
        precision: One of ``"FP8"``, ``"FP16"``, ``"BF16"``, ``"INT8"``.
        target_components: Folder names whose ``.safetensors`` files are
            converted; all other files are uploaded unchanged.
        arch_profile: Key into ``ARCH_PROFILES``; in INT8 mode its substrings
            keep matching tensors out of quantization.

    ``.safetensors`` files at the repo root are skipped (not copied). A failure
    on one file is reported and counted, then processing continues. The per-run
    download directory is wiped after every file and again when the generator
    finishes or is closed early.
    """
    token = (token or "").strip()
    source_repo = (source_repo or "").strip()
    target_repo = (target_repo or "").strip()

    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo or "/" not in target_repo:
        yield "❌ Error: Target Repository must be in format 'username/repo-name'."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return
    if not source_repo:
        # Checked before create_repo so a bad input never creates an empty repo.
        yield "❌ Error: Please provide a Source Repository."
        return

    is_int8 = precision == "INT8"
    if precision == "FP8":
        target_dtype = torch.float8_e4m3fn
    elif precision == "FP16":
        target_dtype = torch.float16
    elif precision == "BF16":
        target_dtype = torch.bfloat16
    elif is_int8:
        target_dtype = None  # INT8 chooses dtypes per tensor in _convert_tensors.
    else:
        yield f"❌ Error: Unsupported precision '{precision}'."
        return

    api = HfApi(token=token)
    yield f"🔄 Verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error creating repo: {str(e)}"
        return

    yield f"📋 Fetching files from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {str(e)}"
        return

    # Unique per run, so concurrent jobs never share downloads or outputs.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    converted_path = os.path.join(cache_dir, "converted.safetensors")
    exclude_patterns = ARCH_PROFILES.get(arch_profile, [])
    selected = set(target_components)
    success_count, error_count = 0, 0

    try:
        for file in files:
            if "/" not in file and file.endswith(".safetensors"):
                yield f"🗑️ Auto-skipping massive root model: {file}..."
                continue
            yield f"⏳ Processing {file}..."
            try:
                os.makedirs(cache_dir, exist_ok=True)
                local_path = hf_hub_download(
                    repo_id=source_repo, filename=file, cache_dir=cache_dir, token=token
                )
                if file.endswith(".safetensors") and _in_selected_component(file, selected):
                    yield f"🧠 Quantizing {file} to {precision}..."
                    _convert_file(local_path, converted_path, is_int8, target_dtype, exclude_patterns)
                    gc.collect()
                    yield f"☁️ Uploading {precision} version of {file}..."
                    api.upload_file(path_or_fileobj=converted_path, path_in_repo=file, repo_id=target_repo)
                else:
                    yield f"☁️ Copying {file} as-is..."
                    api.upload_file(path_or_fileobj=local_path, path_in_repo=file, repo_id=target_repo)
                success_count += 1
            except Exception as e:
                error_count += 1
                yield f"⚠️ Error processing {file}: {str(e)}\nSkipping..."
            finally:
                # Only the current file's download and its converted copy are
                # ever on disk at once.
                shutil.rmtree(cache_dir, ignore_errors=True)
                gc.collect()
    finally:
        # Also runs if the generator is closed before the loop completes.
        shutil.rmtree(cache_dir, ignore_errors=True)

    yield f"✅ Finished! Processed: {success_count} | Errors: {error_count}."
# --- UI LOGIC ---
def generate_target_repo(source, precision):
    """Suggest a target repo id of the form ``your-username/<model>-<precision>``.

    *source* may be a repo id (``"org/model"``) or a bare model name. Leading
    and trailing whitespace and trailing slashes (common when pasting an id or
    URL fragment) are ignored so they do not leak into the suggested name.
    """
    cleaned = (source or "").strip().rstrip("/")
    model_name = cleaned.rsplit("/", 1)[-1]
    return f"your-username/{model_name}-{precision}"
def toggle_int8_warning(precision):
    """Return a Gradio update that shows the INT8 notice only when INT8 is selected."""
    int8_selected = precision == "INT8"
    return gr.update(visible=int8_selected)
# --- GUI ---
# No theme is passed to gr.Blocks(); the theme is supplied to demo.launch()
# in the __main__ guard at the bottom of the file.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# ⚡ Universal Model Quantizer Hub
Convert massive diffusion and transformer models directly on the Hugging Face hub.
Engineered with aggressive cache-clearing to prevent storage crashes on free-tier Spaces.
"""
    )
    with gr.Row():
        # Left column: inputs, organised as tabs.
        with gr.Column(scale=5):
            with gr.Tabs():
                with gr.TabItem("1. Authentication & Source"):
                    hf_token = gr.Textbox(label="HF Access Token (Write)", type="password", placeholder="hf_...")
                    source_repo = gr.Textbox(
                        label="Source Repository",
                        placeholder="e.g., black-forest-labs/FLUX.1-dev",
                        info="Paste any Hugging Face model repository ID.",
                    )
                    gr.Markdown("### Popular Presets")
                    with gr.Row():
                        preset_flux = gr.Button("FLUX.2-klein-9B", size="sm")
                        preset_zimage = gr.Button("Z-Image-Turbo", size="sm")
                        preset_sd3 = gr.Button("SD3.5-Large", size="sm")
                with gr.TabItem("2. Quantization Rules"):
                    arch_profile = gr.Radio(
                        choices=list(ARCH_PROFILES.keys()),
                        value="FLUX / Generic Rectified Flow",
                        label="Architecture Profile",
                        info="Crucial for INT8: Selects which layers to protect from precision loss.",
                    )
                    target_components = gr.CheckboxGroup(
                        choices=["transformer", "text_encoder", "text_encoder_2", "vae"],
                        value=["transformer"],
                        label="Folders to Quantize",
                        info="Unselected folders will be copied to the new repo unchanged.",
                    )
                with gr.TabItem("3. Output Settings"):
                    precision = gr.Dropdown(
                        choices=["FP8", "FP16", "BF16", "INT8"],
                        value="INT8",
                        label="Target Precision",
                    )
                    # Starts visible because the default precision above is INT8.
                    int8_warning = gr.Markdown(
                        "⚠️ **INT8 Selected:** Keys will be split into `weight_int8` and `weight_scale`. "
                        "Requires custom XPU/CUDA native linear classes to execute.",
                        visible=True,
                    )
                    target_repo = gr.Textbox(
                        label="Target Repository",
                        placeholder="your-username/model-name",
                        interactive=True,
                    )
            # NOTE(review): the original indentation was lost; this button is
            # placed below the tabs in the left column — confirm intended placement.
            start_btn = gr.Button("🚀 Start Cloud Quantization", variant="primary", size="lg")
        # Right column: streamed status output.
        with gr.Column(scale=4):
            output_log = gr.Textbox(
                label="Terminal Output",
                lines=24,
                interactive=False,
                max_lines=30,
            )

    # Presets fill in a source repo together with its matching architecture profile.
    preset_flux.click(lambda: ("black-forest-labs/FLUX.2-klein-9B", "FLUX / Generic Rectified Flow"), outputs=[source_repo, arch_profile])
    # Previously pointed at the placeholder "your-username/Z-Image-Turbo", which cannot be fetched.
    preset_zimage.click(lambda: ("Tongyi-MAI/Z-Image-Turbo", "Z-Image / DiT Core"), outputs=[source_repo, arch_profile])
    preset_sd3.click(lambda: ("stabilityai/stable-diffusion-3.5-large", "Stable Diffusion (SDXL/SD3)"), outputs=[source_repo, arch_profile])

    # Keep the suggested target repo id in sync with the source and precision.
    source_repo.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
    precision.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
    precision.change(fn=toggle_int8_warning, inputs=[precision], outputs=[int8_warning])

    # convert_and_upload is a generator; each yielded message replaces the log text.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components, arch_profile],
        outputs=[output_log],
    )
if __name__ == "__main__":
    # The theme is supplied here, at launch time; gr.Blocks() is built without one.
    base_theme = gr.themes.Base(primary_hue="blue", neutral_hue="slate")
    demo.launch(theme=base_theme)