# app.py — ERNIE-Image dedicated quantizer (Gradio Space)
# Author: rootlocalghost — "Create app.py" (commit 2a924b8, verified)
import os
import gc
import torch
import shutil
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
# Scratch directory for per-shard download/convert work; deleted after a full run.
TEMP_DIR = "temp_processing_dir"
def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Quantize a Hugging Face model repo shard-by-shard and re-upload it.

    Generator used as a streaming Gradio handler: each ``yield`` is one log
    line for the UI. Every file in *source_repo* is downloaded one at a time;
    ``.safetensors`` files inside a selected component folder have their
    floating-point tensors cast to *precision*, all other files are copied
    unchanged to *target_repo*.

    Args:
        token: Hugging Face access token with write permission.
        source_repo: Repo id to read from (e.g. "baidu/ERNIE-Image").
        target_repo: Repo id to write to; created if it does not exist.
        precision: "FP8", "FP16" or "BF16"; any other value disables casting.
        target_components: Folder names whose safetensors files are quantized.

    Yields:
        str: Human-readable progress / error messages.
    """
    # ---- fail-fast input validation (no network access yet) ----
    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    # BUG FIX: previously only the emptiness check used .strip(); the raw,
    # possibly whitespace-padded id was passed to create_repo/upload_file.
    target_repo = target_repo.strip()
    if not target_repo or "your-username" in target_repo:
        yield "❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return

    # Map the UI precision string to a torch dtype; None means "no cast".
    target_dtype = {
        "FP8": torch.float8_e4m3fn,
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }.get(precision)

    api = HfApi(token=token)
    yield f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
        return

    yield f"📋 Fetching file list from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {str(e)}"
        return

    os.makedirs(TEMP_DIR, exist_ok=True)
    # BUG FIX: track failures — the old code always ended with a success
    # message even when some files had errored and been skipped.
    failed_files = []
    for file in files:
        yield f"⏳ Processing {file}..."
        try:
            # Download locally, bypassing the symlink cache to save disk space.
            # NOTE(review): local_dir_use_symlinks is deprecated/ignored in
            # recent huggingface_hub releases — confirm the pinned version.
            local_path = hf_hub_download(
                repo_id=source_repo,
                filename=file,
                local_dir=TEMP_DIR,
                local_dir_use_symlinks=False
            )
            # Quantize only safetensors files living in a selected folder.
            in_target_component = any(f"{comp}/" in file for comp in target_components)
            if file.endswith(".safetensors") and in_target_component:
                yield f"🧠 Quantizing {file} to {precision}..."
                tensors = load_file(local_path)
                # Cast only floating-point tensors; ints/bools stay as-is.
                if target_dtype is not None:
                    for key in list(tensors.keys()):
                        if tensors[key].is_floating_point():
                            tensors[key] = tensors[key].to(target_dtype)
                converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
                save_file(tensors, converted_path)
                # Aggressive memory flush to prevent OOM on the large shards
                # (crucial for the 9.3GB transformer shard).
                del tensors
                gc.collect()
                yield f"☁️ Uploading {precision} version of {file}..."
                api.upload_file(
                    path_or_fileobj=converted_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Upload {precision} quantized {file}"
                )
                os.remove(converted_path)
            else:
                yield f"☁️ Copying {file} as-is..."
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Copy {file} from original repo"
                )
            # Clean up the downloaded original before the next shard.
            if os.path.exists(local_path):
                os.remove(local_path)
            gc.collect()
        except Exception as e:
            failed_files.append(file)
            yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."

    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    if failed_files:
        yield f"⚠️ Finished with {len(failed_files)} failed file(s): {', '.join(failed_files)}. All other files were uploaded to {target_repo}."
    else:
        yield f"✅ All files processed and successfully uploaded to {target_repo}!"
# Dynamic UI update for the target-repo textbox.
def update_target_repo(username, source, precision):
    """Suggest a target repo id of the form '<owner>/<model>-<precision>'.

    Falls back to the 'your-username' placeholder when no username is given;
    the model name is whatever follows the last '/' in *source*.
    """
    owner = username.strip() or "your-username"
    model_name = source.rpartition("/")[2]
    return f"{owner}/{model_name}-{precision}"
# Build the Gradio UI. (Indentation reconstructed; mojibake 🚀 emoji repaired.)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 ERNIE-Image Dedicated Quantizer")
    gr.Markdown(
        "Convert the massive **ERNIE-Image** and **ERNIE-Image-Turbo** models to lower precisions (FP8, FP16, BF16).\n\n"
        "**Memory Management:** This tool processes the files shard-by-shard. The largest file is the 9.31 GB transformer shard, "
        "which will peak near 14 GB of RAM during FP8 conversion. The script flushes memory aggressively after each step to prevent crashing the free tier."
    )
    with gr.Row():
        with gr.Column(scale=2):
            hf_token = gr.Textbox(
                label="Hugging Face Token (Write Access Required)",
                type="password",
                placeholder="hf_..."
            )
            hf_username = gr.Textbox(
                label="Your Hugging Face Username",
                placeholder="e.g., rootlocalghost"
            )
            # Locked down to only ERNIE models.
            source_repo = gr.Dropdown(
                choices=[
                    "baidu/ERNIE-Image-Turbo",
                    "baidu/ERNIE-Image"
                ],
                value="baidu/ERNIE-Image-Turbo",
                label="Source Repository",
                allow_custom_value=False
            )
            # 'pe' (Prompt Encoder) is included because it has a 7.14 GB safetensors file.
            target_components = gr.CheckboxGroup(
                choices=["pe", "text_encoder", "transformer", "vae"],
                value=["pe", "text_encoder", "transformer"],
                label="Components to Quantize",
                info="Select which folders should be cast to the new precision. Unselected folders will be copied as-is."
            )
            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16"],
                value="FP8",
                label="Target Precision"
            )
            target_repo = gr.Textbox(
                label="Target Repository (Auto-generated)",
                value="your-username/ERNIE-Image-Turbo-FP8",
                interactive=True
            )
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")
        with gr.Column(scale=3):
            output_log = gr.Textbox(
                label="Operation Logs",
                lines=20,
                interactive=False,
                max_lines=25
            )
    # Re-generate the suggested target repo name whenever any input changes.
    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(
            fn=update_target_repo,
            inputs=inputs_to_watch,
            outputs=[target_repo]
        )
    # convert_and_upload is a generator, so the log textbox streams its yields.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log]
    )

if __name__ == "__main__":
    demo.launch()