import os
import gc
import torch
import shutil
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
# Scratch directory for downloaded and converted files; deleted at the end of a run.
TEMP_DIR = "temp_processing_dir"
def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Copy ``source_repo`` to ``target_repo`` file-by-file, quantizing selected parts.

    Safetensors files that live under one of the ``target_components`` folders are
    loaded, cast to the requested precision, and re-uploaded; every other file is
    copied unchanged. Files are processed one at a time to keep peak RAM/disk low.

    Args:
        token: Hugging Face access token with write permission.
        source_repo: Repo id to read files from (e.g. ``Tongyi-MAI/Z-Image-Turbo``).
        target_repo: Repo id to create (if needed) and upload into.
        precision: ``"FP8"``, ``"FP16"`` or ``"BF16"``; any other value disables casting.
        target_components: Folder names (e.g. ``["transformer"]``) whose
            ``.safetensors`` files should be quantized.

    Yields:
        Human-readable status strings for the UI log. The generator returns early
        after yielding an error message when validation or repo access fails.
    """
    # --- input validation --------------------------------------------------
    if not token:
        yield "β Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo.strip() or "your-username" in target_repo:
        yield "β Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
        return
    if not target_components:
        yield "β Error: Please select at least one component to quantize."
        return

    # Map the precision label to a torch dtype; unknown labels mean "no cast".
    target_dtype = {
        "FP8": torch.float8_e4m3fn,
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }.get(precision)

    api = HfApi(token=token)
    yield f"π Connecting to Hugging Face and verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"β Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
        return

    yield f"π Fetching file list from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"β Error fetching files: {str(e)}"
        return

    os.makedirs(TEMP_DIR, exist_ok=True)
    for file in files:
        yield f"β³ Processing {file}..."
        local_path = None
        try:
            # Download into a plain directory (no symlink cache) to save disk space.
            # NOTE(review): `local_dir_use_symlinks` is deprecated in newer
            # huggingface_hub releases (copying is the default there) — harmless.
            local_path = hf_hub_download(
                repo_id=source_repo,
                filename=file,
                local_dir=TEMP_DIR,
                local_dir_use_symlinks=False
            )
            # Quantize only safetensors files under one of the selected folders.
            in_target_component = any(f"{comp}/" in file for comp in target_components)
            if file.endswith(".safetensors") and in_target_component:
                yield f"π§ Quantizing {file} to {precision}..."
                tensors = load_file(local_path)
                # Cast only floating-point tensors; integer/bool tensors stay untouched.
                if target_dtype is not None:
                    for k in list(tensors.keys()):
                        if tensors[k].is_floating_point():
                            tensors[k] = tensors[k].to(target_dtype)
                converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
                save_file(tensors, converted_path)
                # Free the tensors before uploading to prevent OOM on small machines.
                del tensors
                gc.collect()
                yield f"βοΈ Uploading {precision} version of {file}..."
                api.upload_file(
                    path_or_fileobj=converted_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Upload {precision} quantized {file}"
                )
                os.remove(converted_path)
            else:
                yield f"βοΈ Copying {file} as-is..."
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Copy {file} from original repo"
                )
        except Exception as e:
            yield f"β οΈ Error processing {file}: {str(e)}\nSkipping to next file..."
        finally:
            # Always drop the downloaded copy, even when processing failed mid-way,
            # so a long run with errors cannot fill the disk.
            if local_path and os.path.exists(local_path):
                os.remove(local_path)
            gc.collect()

    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    yield f"β All files processed and successfully uploaded to {target_repo}!"
def update_target_repo(username, source, precision):
    """Suggest a target repo id of the form ``<owner>/<model>-<precision>``.

    Falls back to the ``your-username`` placeholder when the username field is
    blank; picks the Turbo or Base model name based on the source repo string.
    """
    owner = username.strip() or "your-username"
    model = "Z-Image-Turbo" if "Turbo" in source else "Z-Image-Base"
    return f"{owner}/{model}-{precision}"
# Build the Gradio UI: a two-column layout (inputs on the left, log on the right),
# live auto-suggestion of the target repo name, and a streaming start button.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Z-Image Quantizer & Uploader")
    gr.Markdown(
        "Convert the **Z-Image** or **Z-Image-Turbo** models to lower precisions (FP8, FP16, BF16) and push them directly to your own Hugging Face account.\n\n"
        "**How it works:** This tool sequentially downloads, quantizes the selected files, and uploads everything. "
        "It is designed to run safely on free Spaces (16GB RAM) by processing files one at a time."
    )
    with gr.Row():
        # Left column: credentials, source/target selection, and run button.
        with gr.Column(scale=2):
            hf_token = gr.Textbox(
                label="Hugging Face Token (Write Access Required)",
                type="password",
                placeholder="hf_..."
            )
            hf_username = gr.Textbox(
                label="Your Hugging Face Username",
                placeholder="e.g., rootlocalghost"
            )
            source_repo = gr.Dropdown(
                choices=["Tongyi-MAI/Z-Image", "Tongyi-MAI/Z-Image-Turbo"],
                value="Tongyi-MAI/Z-Image-Turbo",
                label="Source Repository"
            )
            # Added checkbox group for granular component control
            target_components = gr.CheckboxGroup(
                choices=["text_encoder", "transformer"],
                value=["text_encoder", "transformer"],
                label="Components to Quantize",
                info="Select which parts of the model to convert. Unselected parts will be copied as-is."
            )
            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16"],
                value="FP8",
                label="Quantization Precision"
            )
            # Auto-filled from username/source/precision, but kept editable.
            target_repo = gr.Textbox(
                label="Target Repository (Auto-generated)",
                value="your-username/Z-Image-Turbo-FP8",
                interactive=True
            )
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")
        # Right column: streaming log output from the generator function.
        with gr.Column(scale=3):
            output_log = gr.Textbox(
                label="Operation Logs",
                lines=17,
                interactive=False,
                max_lines=20
            )
    # Automatically update the target repo name when inputs change
    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(
            fn=update_target_repo,
            inputs=inputs_to_watch,
            outputs=[target_repo]
        )
    # convert_and_upload is a generator, so each yield streams into the log box.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log]
    )
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()