Spaces:
Running
Running
File size: 9,091 Bytes
import os
import gc
import torch
import shutil
import uuid
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
# Per-architecture lists of key-name fragments. During INT8 conversion a tensor
# whose key contains any fragment of the selected profile is not quantized.
ARCH_PROFILES = {
    "FLUX / Generic Rectified Flow": [
        "norm",
        "ln_",
        "embed",
        "time_in",
        "vector_in",
        "guidance_in",
        "txt_in",
        "img_in",
    ],
    "Z-Image / DiT Core": [
        "t_embedder",
        "cap_embedder",
        "all_x_embedder",
        "all_final_layer",
        "rope_embedder",
        "embed_tokens",
        "norm",
        "ln_",
        "shared",
    ],
    "Stable Diffusion (SDXL/SD3)": [
        "time_embed",
        "label_emb",
        "norm",
        "ln_",
        "out.",
    ],
}
def _convert_tensors(tensors, is_int8, target_dtype, exclude_prefixes):
    """Return a new state dict with the requested precision applied.

    Non-INT8 modes cast every floating-point tensor to ``target_dtype`` and
    pass other tensors through. INT8 mode splits each eligible ``*.weight``
    matrix into ``*.weight_int8`` plus a per-row ``*.weight_scale`` and casts
    the remaining floating-point tensors to bfloat16.
    """
    if not is_int8:
        return {
            name: tensor.to(target_dtype) if tensor.is_floating_point() else tensor
            for name, tensor in tensors.items()
        }

    converted = {}
    for name, tensor in tensors.items():
        # Only keys whose last component is exactly "weight" are split. A looser
        # substring test also matches "x.weight_scale" / "x.weight_int8", which
        # would collapse onto the same output names and overwrite each other.
        quantizable = (
            name.rpartition(".")[2] == "weight"
            and tensor.dim() == 2
            and tensor.is_floating_point()
            and not any(fragment in name for fragment in exclude_prefixes)
        )
        if not quantizable:
            converted[name] = (
                tensor.to(torch.bfloat16) if tensor.is_floating_point() else tensor
            )
            continue

        # Work in float32: float8 lacks the needed ops, and in float16 the 1e-8
        # floor underflows to zero, turning all-zero rows into 0/0 = NaN.
        values = tensor.to(torch.float32)
        scale = values.abs().max(dim=1, keepdim=True)[0] / 127.0
        # Round the scale to its stored precision first so the integers are
        # computed against exactly the scale a consumer will read back.
        stored_scale = scale.clamp(min=1e-8).to(torch.bfloat16)
        quantized = torch.round(values / stored_scale.to(torch.float32))

        prefix = name[: -len("weight")]
        converted[f"{prefix}weight_int8"] = quantized.clamp(-127, 127).to(torch.int8)
        converted[f"{prefix}weight_scale"] = stored_scale
    return converted


def convert_and_upload(token, source_repo, target_repo, precision, target_components, arch_profile):
    """Copy a Hub model repository to a new repository, converting selected parts.

    Files are handled one at a time: each is downloaded into a per-run cache
    directory, optionally converted, uploaded, and the cache is deleted again
    before the next file. ``.safetensors`` files at the repository root are
    skipped. ``.safetensors`` files inside a selected component folder are
    converted; every other file is uploaded unchanged.

    Args:
        token: Hugging Face access token used for all Hub calls.
        source_repo: Repository ID to read from.
        target_repo: Repository ID to create/write, in ``username/repo-name`` form.
        precision: One of ``"FP8"``, ``"FP16"``, ``"BF16"`` or ``"INT8"``.
        target_components: Folder names whose ``.safetensors`` files are converted.
        arch_profile: Key into ``ARCH_PROFILES``; its fragments mark tensors that
            are left unquantized in INT8 mode. Unknown keys protect nothing.

    Yields:
        Human-readable status messages, one per step. Input and Hub errors are
        reported as messages rather than raised.
    """
    token = (token or "").strip()
    source_repo = (source_repo or "").strip()
    target_repo = (target_repo or "").strip()

    if not token:
        yield "β Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo or "/" not in target_repo:
        yield "β Error: Target Repository must be in format 'username/repo-name'."
        return
    if not target_components:
        yield "β Error: Please select at least one component to quantize."
        return
    if not source_repo:
        yield "β Error: Please provide a Source Repository."
        return
    if precision not in ("FP8", "FP16", "BF16", "INT8"):
        # Without this check an unknown value would reach tensor.to(None).
        yield f"β Error: Unsupported precision '{precision}'."
        return

    # Map precision
    is_int8 = precision == "INT8"
    target_dtype = {
        "FP8": torch.float8_e4m3fn,
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }.get(precision)

    api = HfApi(token=token)
    yield f"π Verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"β Error creating repo: {str(e)}"
        return

    yield f"π Fetching files from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"β Error fetching files: {str(e)}"
        return

    # Unique per run so simultaneous conversions never share scratch space.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    success_count, error_count = 0, 0
    exclude_prefixes = ARCH_PROFILES.get(arch_profile, [])

    for file in files:
        if "/" not in file and file.endswith(".safetensors"):
            yield f"ποΈ Auto-skipping massive root model: {file}..."
            continue

        yield f"β³ Processing {file}..."
        try:
            os.makedirs(cache_dir, exist_ok=True)
            local_path = hf_hub_download(repo_id=source_repo, filename=file, cache_dir=cache_dir, token=token)
            in_target_component = any(f"{comp}/" in file for comp in target_components)

            if file.endswith(".safetensors") and in_target_component:
                yield f"π§ Quantizing {file} to {precision}..."
                tensors = load_file(local_path)
                new_tensors = _convert_tensors(tensors, is_int8, target_dtype, exclude_prefixes)

                # Keep the output inside the per-run cache directory: it cannot
                # collide with another run and is removed by the cleanup below
                # even when the upload fails.
                converted_path = os.path.join(cache_dir, "converted.safetensors")
                save_file(new_tensors, converted_path)
                del tensors, new_tensors
                gc.collect()

                yield f"βοΈ Uploading {precision} version of {file}..."
                api.upload_file(path_or_fileobj=converted_path, path_in_repo=file, repo_id=target_repo)
            else:
                yield f"βοΈ Copying {file} as-is..."
                api.upload_file(path_or_fileobj=local_path, path_in_repo=file, repo_id=target_repo)

            success_count += 1
        except Exception as e:
            error_count += 1
            yield f"β οΈ Error processing {file}: {str(e)}\nSkipping..."
        finally:
            # Single cleanup point for success and failure; a cleanup problem
            # must not turn an uploaded file into a reported error.
            shutil.rmtree(cache_dir, ignore_errors=True)
            gc.collect()

    yield f"✅ Finished! Processed: {success_count} | Errors: {error_count}."
# --- UI LOGIC ---
def generate_target_repo(source, precision):
    """Suggest a target repository ID for the given source and precision.

    The suggestion is ``your-username/<model>-<precision>``, where ``<model>``
    is the last path segment of ``source``. Surrounding whitespace and trailing
    slashes are ignored. An empty string is returned when no model name can be
    derived, so the field shows its placeholder instead of an ID such as
    ``your-username/-INT8``.
    """
    cleaned = (source or "").strip().rstrip("/")
    model_name = cleaned.rsplit("/", 1)[-1]
    if not model_name:
        return ""
    return f"your-username/{model_name}-{precision}"
def toggle_int8_warning(precision):
    """Return a Gradio update that shows the INT8 notice only for "INT8"."""
    show_notice = precision == "INT8"
    return gr.update(visible=show_notice)
# --- GUI ---
# FIXED: Removed the theme argument from gr.Blocks()
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # β‘ Universal Model Quantizer Hub
        Convert massive diffusion and transformer models directly on the Hugging Face hub.
        Engineered with aggressive cache-clearing to prevent storage crashes on free-tier Spaces.
        """
    )
    with gr.Row():
        # Left column: the three configuration tabs and the start button.
        with gr.Column(scale=5):
            with gr.Tabs():
                with gr.TabItem("1. Authentication & Source"):
                    hf_token = gr.Textbox(label="HF Access Token (Write)", type="password", placeholder="hf_...")
                    source_repo = gr.Textbox(
                        label="Source Repository",
                        placeholder="e.g., black-forest-labs/FLUX.1-dev",
                        info="Paste any Hugging Face model repository ID."
                    )
                    gr.Markdown("### Popular Presets")
                    with gr.Row():
                        preset_flux = gr.Button("FLUX.2-klein-9B", size="sm")
                        preset_zimage = gr.Button("Z-Image-Turbo", size="sm")
                        preset_sd3 = gr.Button("SD3.5-Large", size="sm")
                with gr.TabItem("2. Quantization Rules"):
                    arch_profile = gr.Radio(
                        choices=list(ARCH_PROFILES.keys()),
                        value="FLUX / Generic Rectified Flow",
                        label="Architecture Profile",
                        info="Crucial for INT8: Selects which layers to protect from precision loss."
                    )
                    target_components = gr.CheckboxGroup(
                        choices=["transformer", "text_encoder", "text_encoder_2", "vae"],
                        value=["transformer"],
                        label="Folders to Quantize",
                        info="Unselected folders will be copied to the new repo unchanged."
                    )
                with gr.TabItem("3. Output Settings"):
                    precision = gr.Dropdown(
                        choices=["FP8", "FP16", "BF16", "INT8"],
                        value="INT8",
                        label="Target Precision"
                    )
                    # Starts visible because the dropdown defaults to INT8.
                    int8_warning = gr.Markdown(
                        "β οΈ **INT8 Selected:** Keys will be split into `weight_int8` and `weight_scale`. "
                        "Requires custom XPU/CUDA native linear classes to execute.",
                        visible=True
                    )
                    target_repo = gr.Textbox(
                        label="Target Repository",
                        placeholder="your-username/model-name",
                        interactive=True
                    )
            start_btn = gr.Button("π Start Cloud Quantization", variant="primary", size="lg")
        # Right column: status messages yielded by convert_and_upload.
        with gr.Column(scale=4):
            output_log = gr.Textbox(
                label="Terminal Output",
                lines=24,
                interactive=False,
                max_lines=30
            )

    # Presets fill in the source repository and the matching architecture profile.
    preset_flux.click(lambda: ("black-forest-labs/FLUX.2-klein-9B", "FLUX / Generic Rectified Flow"), outputs=[source_repo, arch_profile])
    # The source must be a real repository; a "your-username/..." placeholder
    # cannot be listed or downloaded.
    preset_zimage.click(lambda: ("Tongyi-MAI/Z-Image-Turbo", "Z-Image / DiT Core"), outputs=[source_repo, arch_profile])
    preset_sd3.click(lambda: ("stabilityai/stable-diffusion-3.5-large", "Stable Diffusion (SDXL/SD3)"), outputs=[source_repo, arch_profile])

    # Re-suggest the target repository whenever the source or precision changes.
    source_repo.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
    precision.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
    precision.change(fn=toggle_int8_warning, inputs=[precision], outputs=[int8_warning])

    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components, arch_profile],
        outputs=[output_log]
    )
if __name__ == "__main__":
    # The theme is supplied at launch time rather than to gr.Blocks().
    demo.launch(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"))