import gc
import os
import shutil

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file

# Scratch directory for per-shard download/convert work; deleted after a full run.
TEMP_DIR = "temp_processing_dir"

# UI precision label -> torch dtype used when casting floating-point tensors.
_PRECISION_DTYPES = {
    "FP8": torch.float8_e4m3fn,
    "FP16": torch.float16,
    "BF16": torch.bfloat16,
}


def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Convert a Hugging Face repo to a lower precision, shard-by-shard, and re-upload.

    Each safetensors file that lives under one of the selected component folders is
    downloaded, has its floating-point tensors cast to the chosen dtype, and is
    uploaded to ``target_repo`` under the same path. Every other file is copied
    verbatim. Files are processed one at a time and memory is flushed aggressively
    so the largest shard does not OOM the host.

    Args:
        token: Hugging Face token with write access.
        source_repo: Repo id to read from (e.g. "baidu/ERNIE-Image-Turbo").
        target_repo: Repo id to create/write to.
        precision: One of "FP8", "FP16", "BF16"; anything else disables casting.
        target_components: Top-level folder names whose safetensors get quantized.

    Yields:
        Human-readable status strings consumed by the Gradio log textbox.
    """
    # --- Input validation (each failure yields a message and stops the generator) ---
    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo.strip() or "your-username" in target_repo:
        yield "❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return

    # Map precision string to PyTorch dtype; None means "copy without casting".
    target_dtype = _PRECISION_DTYPES.get(precision)

    api = HfApi(token=token)
    yield f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
        return

    yield f"📋 Fetching file list from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {str(e)}"
        return

    os.makedirs(TEMP_DIR, exist_ok=True)
    selected = set(target_components)  # O(1) membership tests below

    for file in files:
        yield f"⏳ Processing {file}..."
        try:
            # Download locally, bypassing the symlinked cache to save disk space.
            # NOTE(review): local_dir_use_symlinks is deprecated in recent
            # huggingface_hub releases (local_dir alone now implies real files);
            # kept for compatibility with older versions — confirm installed version.
            local_path = hf_hub_download(
                repo_id=source_repo,
                filename=file,
                local_dir=TEMP_DIR,
                local_dir_use_symlinks=False,
            )

            # BUGFIX: match selected components against whole directory names in
            # the repo path instead of a raw substring test — the original
            # `f"{comp}/" in file` would let "pe" match a folder like "pipe/".
            in_target_component = any(part in selected for part in file.split("/")[:-1])

            # Quantize only safetensors files that live in a selected folder.
            if file.endswith(".safetensors") and in_target_component:
                yield f"🧠 Quantizing {file} to {precision}..."
                tensors = load_file(local_path)

                # Cast only floating-point tensors; integer/bool tensors
                # (e.g. token ids, buffers) must keep their dtype.
                if target_dtype:
                    for key in list(tensors.keys()):
                        if tensors[key].is_floating_point():
                            tensors[key] = tensors[key].to(target_dtype)

                converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
                save_file(tensors, converted_path)

                # Aggressive memory flush to prevent OOM on the biggest shard.
                del tensors
                gc.collect()

                yield f"☁️ Uploading {precision} version of {file}..."
                api.upload_file(
                    path_or_fileobj=converted_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Upload {precision} quantized {file}",
                )
                os.remove(converted_path)
            else:
                yield f"☁️ Copying {file} as-is..."
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Copy {file} from original repo",
                )

            # Drop the downloaded original before starting the next shard.
            if os.path.exists(local_path):
                os.remove(local_path)
            gc.collect()
        except Exception as e:
            # Best-effort per-file handling: report and continue with the rest.
            yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."

    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)
    yield f"✅ All files processed and successfully uploaded to {target_repo}!"


def update_target_repo(username, source, precision):
    """Build the auto-generated target repo id: '<user>/<model>-<precision>'.

    Falls back to the literal "your-username" placeholder (which the upload
    handler rejects) when no username has been entered yet.
    """
    user_prefix = username.strip() if username.strip() else "your-username"
    # Keep only the model name from a "namespace/model" source id.
    model_name = source.split("/")[-1] if "/" in source else source
    return f"{user_prefix}/{model_name}-{precision}"


# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 ERNIE-Image Dedicated Quantizer")
    gr.Markdown(
        "Convert the massive **ERNIE-Image** and **ERNIE-Image-Turbo** models to lower precisions (FP8, FP16, BF16).\n\n"
        "**Memory Management:** This tool processes the files shard-by-shard. The largest file is the 9.31 GB transformer shard, "
        "which will peak near 14 GB of RAM during FP8 conversion. The script flushes memory aggressively after each step to prevent crashing the free tier."
    )

    with gr.Row():
        with gr.Column(scale=2):
            hf_token = gr.Textbox(
                label="Hugging Face Token (Write Access Required)",
                type="password",
                placeholder="hf_...",
            )
            hf_username = gr.Textbox(
                label="Your Hugging Face Username",
                placeholder="e.g., rootlocalghost",
            )
            # Locked down to only ERNIE models.
            source_repo = gr.Dropdown(
                choices=[
                    "baidu/ERNIE-Image-Turbo",
                    "baidu/ERNIE-Image",
                ],
                value="baidu/ERNIE-Image-Turbo",
                label="Source Repository",
                allow_custom_value=False,
            )
            # 'pe' (Prompt Encoder) is included because it has a 7.14 GB safetensors file.
            target_components = gr.CheckboxGroup(
                choices=["pe", "text_encoder", "transformer", "vae"],
                value=["pe", "text_encoder", "transformer"],
                label="Components to Quantize",
                info="Select which folders should be cast to the new precision. Unselected folders will be copied as-is.",
            )
            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16"],
                value="FP8",
                label="Target Precision",
            )
            target_repo = gr.Textbox(
                label="Target Repository (Auto-generated)",
                value="your-username/ERNIE-Image-Turbo-FP8",
                interactive=True,
            )
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")
        with gr.Column(scale=3):
            output_log = gr.Textbox(
                label="Operation Logs",
                lines=20,
                interactive=False,
                max_lines=25,
            )

    # Keep the target repo name in sync whenever any naming input changes.
    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(
            fn=update_target_repo,
            inputs=inputs_to_watch,
            outputs=[target_repo],
        )

    # convert_and_upload is a generator, so the log box streams status updates.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log],
    )

if __name__ == "__main__":
    demo.launch()