Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import os | |
| import shutil | |
| from tempfile import TemporaryDirectory | |
| from huggingface_hub import hf_hub_download, HfApi | |
| from safetensors.torch import save_file, load_file | |
| from collections import defaultdict | |
| from typing import Dict, List | |
# --- Helpers adapted from the safetensors conversion logic (`convert.py`) ---
# These internal functions are needed to correctly handle tensors that share
# storage (e.g. tied weights). They are inlined here so the application is
# self-contained.
# Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
| def _is_complete(storage): | |
| # The UserWarning from this line can be ignored; it's expected. | |
| return storage.size() * storage.element_size() == storage.nbytes() | |
| def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]: | |
| tensors = list(state_dict.values()) | |
| storages = {tensor.storage().data_ptr(): [] for tensor in tensors} | |
| for name, tensor in state_dict.items(): | |
| storages[tensor.storage().data_ptr()].append(name) | |
| return [names for names in storages.values() if len(names) > 1] | |
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, List[str]]:
    """Decide which storage-sharing tensor names can be dropped before saving.

    safetensors cannot store tensors that share memory. For each group of
    names that alias one storage (as found by ``_find_shared_tensors``):

    * If some member covers the whole storage (``_is_complete``), the
      lexicographically smallest such name is kept. Other members that are
      the exact same view of it (same start pointer, dtype, shape and stride)
      are scheduled for removal as pure duplicates.
    * Every other member is a partial or differently shaped view. Its data is
      not a copy of the kept tensor, so it is cloned into its own storage
      rather than dropped.
    * If no member covers the whole storage, every member is cloned, so
      nothing is dropped.

    Args:
        state_dict: Mapping of name -> tensor. It is MUTATED in place: members
            that need their own storage are replaced by clones.

    Returns:
        Mapping ``kept_name -> [duplicate names to remove]``. The caller is
        responsible for deleting the listed names from ``state_dict``.
    """
    to_remove = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        complete_names = sorted(
            name for name in shared if _is_complete(state_dict[name])
        )
        if not complete_names:
            # Distinct views (e.g. slices of one buffer): give each its own copy.
            for name in shared:
                state_dict[name] = state_dict[name].clone()
            continue

        keep_name = complete_names[0]
        kept = state_dict[keep_name]
        for name in sorted(shared):
            if name == keep_name:
                continue
            tensor = state_dict[name]
            same_view = (
                tensor.data_ptr() == kept.data_ptr()
                and tensor.dtype == kept.dtype
                and tensor.shape == kept.shape
                and tensor.stride() == kept.stride()
            )
            if same_view:
                # Identical alias: safe to drop and record as pointing at keep_name.
                to_remove[keep_name].append(name)
            else:
                # Different view of the same memory: preserve its data separately.
                state_dict[name] = tensor.clone()
    return to_remove
def check_file_size(sf_filename: str, pt_filename: str, tolerance: float = 0.01):
    """Warn if the converted file is noticeably larger than the original.

    Only growth is reported. A smaller output is normal, for example when
    only the ``state_dict`` part of a larger checkpoint is converted.

    Args:
        sf_filename: Path of the converted ``.safetensors`` file.
        pt_filename: Path of the original PyTorch file.
        tolerance: Allowed relative growth (0.01 means 1%).

    Returns:
        A warning message, or None when the size is within tolerance.
    """
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if pt_size == 0:
        # Avoid ZeroDivisionError; any non-empty output from an empty input is suspicious.
        grew_too_much = sf_size > 0
    else:
        grew_too_much = (sf_size - pt_size) / pt_size > tolerance
    if not grew_too_much:
        return None
    return (
        f"WARNING: The converted file size ({sf_size} bytes) is more than "
        f"{tolerance * 100:g}% larger than the original ({pt_size} bytes)."
    )
def convert_file(pt_filename: str, sf_filename: str, device: str):
    """Convert one PyTorch weights file to ``.safetensors`` and verify the result.

    Args:
        pt_filename: Path of the ``.bin``/``.ckpt`` file to read.
        sf_filename: Destination path. Missing parent directories are created.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The size warning from ``check_file_size``, or None.

    Raises:
        TypeError: If the file does not contain a mapping of name -> tensor.
        RuntimeError: If a tensor read back from the new file differs from the
            one that was written.
    """
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, guarding against code execution from untrusted files.
    loaded = torch.load(pt_filename, map_location=device, weights_only=True)
    # Training checkpoints (e.g. many .ckpt files) nest the weights under "state_dict".
    if isinstance(loaded, dict) and "state_dict" in loaded:
        loaded = loaded["state_dict"]
    if not isinstance(loaded, dict):
        raise TypeError(
            f"Expected a dict of tensors in `{os.path.basename(pt_filename)}`, "
            f"got {type(loaded).__name__}."
        )
    non_tensors = [str(k) for k, v in loaded.items() if not isinstance(v, torch.Tensor)]
    if non_tensors:
        shown = ", ".join(non_tensors[:5]) + (", ..." if len(non_tensors) > 5 else "")
        raise TypeError(f"Non-tensor entries cannot be stored in safetensors: {shown}")

    to_removes = _remove_duplicate_names(loaded)
    # Record each dropped alias -> kept name in the file metadata.
    metadata = {"format": "pt"}
    for kept_name, removed_names in to_removes.items():
        for removed in removed_names:
            if removed not in metadata:
                metadata[removed] = kept_name
            del loaded[removed]
    # safetensors stores raw contiguous buffers.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    out_dir = os.path.dirname(sf_filename)
    if out_dir:  # a bare filename has no directory, and os.makedirs("") would raise
        os.makedirs(out_dir, exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)
    size_warning = check_file_size(sf_filename, pt_filename)

    # Round-trip check: every written tensor must read back unchanged.
    reloaded = load_file(sf_filename)
    for k, tensor in loaded.items():
        pt_tensor = tensor.to("cpu")
        sf_tensor = reloaded[k].to("cpu")
        if torch.equal(pt_tensor, sf_tensor):
            continue
        # torch.equal treats NaN != NaN, so weights containing NaN need an
        # exact comparison that considers NaNs in the same positions equal.
        if (
            (pt_tensor.is_floating_point() or pt_tensor.is_complex())
            and pt_tensor.dtype == sf_tensor.dtype
            and pt_tensor.shape == sf_tensor.shape
            and torch.allclose(pt_tensor, sf_tensor, rtol=0.0, atol=0.0, equal_nan=True)
        ):
            continue
        raise RuntimeError(f"Tensors do not match for key {k}!")
    return size_warning
| # --- Main Gradio App Logic --- | |
def _output_name(filename: str) -> str:
    """Map a repo-relative weights path to a flat, collision-free ``.safetensors`` name.

    Sub-folders are folded into the name, e.g. ``unet/diffusion_pytorch_model.bin``
    becomes ``unet__diffusion_pytorch_model.safetensors``. This prevents
    same-named files in different folders of one repo from overwriting each other.
    """
    stem = os.path.splitext(filename)[0]
    return stem.replace("/", "__") + ".safetensors"


def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
    """Convert every ``.bin``/``.ckpt`` file of a Hub model repo to ``.safetensors``.

    Args:
        model_id: Hugging Face repo id, e.g. ``"user/model"``. Surrounding
            whitespace is ignored.
        revision: Branch, tag or commit. A blank value means the repo's default
            branch.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(paths, log)``. ``paths`` is the list of converted file paths, or None
        if nothing was converted. ``log`` is a Markdown-formatted report.
    """
    from tempfile import mkdtemp  # the module level only imports TemporaryDirectory

    model_id = (model_id or "").strip()
    if not model_id:
        return None, "Error: Model ID cannot be empty."
    # A cleared revision box means "default branch"; the Hub client takes None for that.
    revision = (revision or "").strip() or None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"✅ Detected device: {device.upper()}"]

    try:
        info = HfApi().model_info(repo_id=model_id, revision=revision)
        filenames = [s.rfilename for s in (info.siblings or [])]
    except Exception as e:
        return None, f"❌ Error: Failed to get model info for `{model_id}`.\n\n{e}"

    files_to_convert = [f for f in filenames if f.endswith((".bin", ".ckpt"))]
    if not files_to_convert:
        return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."
    log_messages.append(
        f"🔍 Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}"
    )

    # Outputs must outlive this call (Gradio serves them after we return).
    # Each request gets its own directory so concurrent users never overwrite
    # or download each other's files.
    output_dir = mkdtemp(prefix="safetensors_")
    converted_files = []
    with TemporaryDirectory() as download_dir:
        for filename in progress.tqdm(files_to_convert, desc="Converting files"):
            sf_filename = _output_name(filename)
            try:
                log_messages.append(f"🔄 Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id,
                    filename=filename,
                    revision=revision,
                    cache_dir=download_dir,
                )
                log_messages.append(f"🛠️ Converting `{filename}`...")
                sf_path = os.path.join(output_dir, sf_filename)
                size_warning = convert_file(pt_path, sf_path, device)
                if size_warning:
                    log_messages.append(f"⚠️ {size_warning}")
                converted_files.append(sf_path)
                log_messages.append(f"✅ Successfully converted to `{sf_filename}`")
            except Exception as e:
                # Keep going: one unconvertible file should not abort the others.
                log_messages.append(f"❌ Error processing file `{filename}`: {e}")

    # The log is rendered as Markdown, which collapses single newlines.
    log_text = "\n\n".join(log_messages)
    if not converted_files:
        shutil.rmtree(output_dir, ignore_errors=True)
        return None, log_text + "\n\nFailed to convert any files."

    failed = len(files_to_convert) - len(converted_files)
    if failed:
        summary = (
            f"⚠️ Converted {len(converted_files)} of {len(files_to_convert)} file(s); "
            f"{failed} failed (see errors above)."
        )
    else:
        summary = "🎉 All files processed successfully! Ready for download."
    return converted_files, log_text + "\n\n" + summary
| # --- Create Gradio Interface --- | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Markdown body kept flush-left so it never renders as an indented code block.
    gr.Markdown(
        """
# Model Converter to `.safetensors`

This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories
to the safe and fast `.safetensors` format.

**How to use:**

1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
2. Click the "Convert" button.
3. Wait for the process to complete and download the resulting files.
"""
    )
    with gr.Row():
        model_id = gr.Textbox(
            label="Hugging Face Model ID",
            placeholder="e.g., stable-diffusion-v1-5/stable-diffusion-v1-5",
        )
        revision = gr.Textbox(label="Revision (branch)", value="main")
    convert_button = gr.Button("Convert", variant="primary")

    gr.Markdown("### Result")
    log_output = gr.Markdown(value="Waiting for input...")
    file_output = gr.File(label="Download Converted Files")
    # Markdown is not processed inside a raw HTML block, so use <code> rather than backticks.
    gr.Markdown(
        "<p style='color:grey;font-size:0.8em;'>"
        "<b>Note:</b> A <code>UserWarning: TypedStorage is deprecated</code> message may appear in the logs. "
        "This is normal and does not affect the result."
        "</p>"
    )

    convert_button.click(
        fn=process_model,
        inputs=[model_id, revision],
        outputs=[file_output, log_output],
    )

if __name__ == "__main__":
    # The queue drives gr.Progress updates and long-running jobs. It is opt-in
    # on older Gradio releases and harmless where it is already the default.
    demo.queue().launch()