File size: 7,483 Bytes
2a924b8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | import os
import gc
import torch
import shutil
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
TEMP_DIR = "temp_processing_dir"
def convert_and_upload(token, source_repo, target_repo, precision, target_components):
if not token:
yield "β Error: Please provide a valid Hugging Face Write Token."
return
if not target_repo.strip() or "your-username" in target_repo:
yield "β Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
return
if not target_components:
yield "β Error: Please select at least one component to quantize."
return
# Map precision string to PyTorch dtype
if precision == "FP8":
target_dtype = torch.float8_e4m3fn
elif precision == "FP16":
target_dtype = torch.float16
elif precision == "BF16":
target_dtype = torch.bfloat16
else:
target_dtype = None
api = HfApi(token=token)
yield f"π Connecting to Hugging Face and verifying target repo: {target_repo}..."
try:
api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
except Exception as e:
yield f"β Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
return
yield f"π Fetching file list from {source_repo}..."
try:
files = api.list_repo_files(source_repo)
except Exception as e:
yield f"β Error fetching files: {str(e)}"
return
os.makedirs(TEMP_DIR, exist_ok=True)
for file in files:
yield f"β³ Processing {file}..."
try:
# Download file locally, bypassing symlink cache to save disk space
local_path = hf_hub_download(
repo_id=source_repo,
filename=file,
local_dir=TEMP_DIR,
local_dir_use_symlinks=False
)
# Check if this file belongs to one of the user-selected components
in_target_component = any(f"{comp}/" in file for comp in target_components)
# Intercept and quantize only if it's a safetensors file in a selected folder
if file.endswith(".safetensors") and in_target_component:
yield f"π§ Quantizing {file} to {precision}..."
tensors = load_file(local_path)
# Cast floating point tensors to the selected precision
if target_dtype:
keys = list(tensors.keys())
for k in keys:
if tensors[k].is_floating_point():
tensors[k] = tensors[k].to(target_dtype)
converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
save_file(tensors, converted_path)
# Aggressive memory flush to prevent OOM (Crucial for the 9.3GB transformer shard)
del tensors
gc.collect()
yield f"βοΈ Uploading {precision} version of {file}..."
api.upload_file(
path_or_fileobj=converted_path,
path_in_repo=file,
repo_id=target_repo,
commit_message=f"Upload {precision} quantized {file}"
)
os.remove(converted_path)
else:
yield f"βοΈ Copying {file} as-is..."
api.upload_file(
path_or_fileobj=local_path,
path_in_repo=file,
repo_id=target_repo,
commit_message=f"Copy {file} from original repo"
)
# Cleanup original downloaded file
if os.path.exists(local_path):
os.remove(local_path)
gc.collect()
except Exception as e:
yield f"β οΈ Error processing {file}: {str(e)}\nSkipping to next file..."
if os.path.exists(TEMP_DIR):
shutil.rmtree(TEMP_DIR)
yield f"β
All files processed and successfully uploaded to {target_repo}!"
# Dynamic UI Update for Target Repo Name
def update_target_repo(username, source, precision):
    """Suggest a target repo id of the form ``<user>/<model>-<precision>``.

    Args:
        username: Text from the username box; blank or whitespace-only input
            falls back to the ``your-username`` placeholder.
        source: Source repo id; only the part after the last ``/`` is kept.
        precision: Precision label appended as a suffix (e.g. ``"FP8"``).

    Returns:
        The suggested repository id string.
    """
    owner = username.strip() or "your-username"
    # rsplit yields the whole string unchanged when it contains no "/".
    model = source.rsplit("/", 1)[-1]
    return f"{owner}/{model}-{precision}"
# Build the Gradio UI
# Components are created inside nested Row/Column contexts, so the order of the
# statements below determines the page layout.
# NOTE(review): several emoji prefixes in the UI/log strings look mis-encoded
# (e.g. "π"); confirm the file is saved as UTF-8.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π ERNIE-Image Dedicated Quantizer")
    gr.Markdown(
        "Convert the massive **ERNIE-Image** and **ERNIE-Image-Turbo** models to lower precisions (FP8, FP16, BF16).\n\n"
        "**Memory Management:** This tool processes the files shard-by-shard. The largest file is the 9.31 GB transformer shard, "
        "which will peak near 14 GB of RAM during FP8 conversion. The script flushes memory aggressively after each step to prevent crashing the free tier."
    )
    with gr.Row():
        # Left column: credentials and conversion settings.
        with gr.Column(scale=2):
            hf_token = gr.Textbox(
                label="Hugging Face Token (Write Access Required)",
                type="password",
                placeholder="hf_..."
            )
            # Only feeds the auto-generated target repo name (update_target_repo).
            hf_username = gr.Textbox(
                label="Your Hugging Face Username",
                placeholder="e.g., rootlocalghost"
            )
            # Locked down to only ERNIE models
            source_repo = gr.Dropdown(
                choices=[
                    "baidu/ERNIE-Image-Turbo",
                    "baidu/ERNIE-Image"
                ],
                value="baidu/ERNIE-Image-Turbo",
                label="Source Repository",
                allow_custom_value=False
            )
            # Included 'pe' (Prompt Encoder) because it has a 7.14 GB safetensors file
            # 'vae' is offered but not selected by default, so it is copied as-is
            # unless the user opts in.
            target_components = gr.CheckboxGroup(
                choices=["pe", "text_encoder", "transformer", "vae"],
                value=["pe", "text_encoder", "transformer"],
                label="Components to Quantize",
                info="Select which folders should be cast to the new precision. Unselected folders will be copied as-is."
            )
            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16"],
                value="FP8",
                label="Target Precision"
            )
            # Initial value equals update_target_repo("", "baidu/ERNIE-Image-Turbo", "FP8").
            # It stays editable, but any change to a watched input below overwrites it.
            target_repo = gr.Textbox(
                label="Target Repository (Auto-generated)",
                value="your-username/ERNIE-Image-Turbo-FP8",
                interactive=True
            )
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")
        # Right column: read-only log showing progress messages.
        with gr.Column(scale=3):
            output_log = gr.Textbox(
                label="Operation Logs",
                lines=20,
                interactive=False,
                max_lines=25
            )
    # Automatically update the target repo name when inputs change
    # Every listener passes all three watched values, so a change to any one of
    # them rebuilds the full name.
    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(
            fn=update_target_repo,
            inputs=inputs_to_watch,
            outputs=[target_repo]
        )
    # convert_and_upload is a generator; each message it yields is written to output_log.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log]
    )
# Launch the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()