Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import tempfile | |
| from huggingface_hub import HfApi, create_repo, list_repo_files, upload_file, hf_hub_download | |
| from safetensors_converter import convert_file, is_supported_file | |
| from typing import Optional | |
def safe_upload_file(*args, **kwargs):
    """Call ``upload_file`` and re-raise rate-limit failures with a recognisable message.

    All arguments are forwarded unchanged to ``huggingface_hub.upload_file``.

    Returns:
        Whatever ``upload_file`` returns (previously discarded).

    Raises:
        Exception: with a message starting "Rate limit exceeded: " when the Hub
            answered HTTP 429. Callers detect this by that message prefix.
        Any other exception from ``upload_file`` is re-raised untouched.
    """
    try:
        return upload_file(*args, **kwargs)
    except Exception as e:
        # Prefer the real HTTP status; only fall back to sniffing the message when
        # the exception carries no response. A bare "429" substring test would
        # also match file names such as "checkpoint-429.safetensors".
        status = getattr(getattr(e, "response", None), "status_code", None)
        if status == 429 or (status is None and "429" in str(e)):
            raise Exception(f"Rate limit exceeded: {e}") from e
        raise
def get_files_to_convert(repo_id, vdir, token):
    """Return the convertible files of ``repo_id`` selected by ``vdir``.

    ``vdir`` may be:
      * empty / None -> every file in the repo accepted by ``is_supported_file``;
      * the exact path of a file in the repo -> that file, if supported;
      * a subdirectory -> every supported file underneath it.

    Surrounding whitespace and leading/trailing slashes in ``vdir`` are ignored,
    so "models", "models/" and " /models/ " are equivalent.

    Args:
        repo_id: Hub repo id to list.
        vdir: Optional subdirectory or file path inside the repo.
        token: Hub token used for listing.

    Returns:
        List of repo-relative file paths, in the order the Hub listed them.
    """
    all_files = list_repo_files(repo_id=repo_id, token=token)
    target = (vdir or "").strip().strip("/")

    if not target:
        return [path for path in all_files if is_supported_file(path)]

    # Exact match: the user pointed at a single file (any supported extension).
    if target in all_files:
        return [target] if is_supported_file(target) else []

    # Otherwise treat it as a directory. Matching on "<dir>/" rather than a bare
    # prefix keeps "models" from also selecting files under "models_old/".
    prefix = f"{target}/"
    return [path for path in all_files if path.startswith(prefix) and is_supported_file(path)]
def generate_output_repo_name(profile: Optional[gr.OAuthProfile], input_repo: str, user_output_repo: str) -> str:
    """Resolve the output repo id from the user's input and login profile.

    Rules when ``profile`` is present (inputs are whitespace-trimmed and
    stripped of leading/trailing slashes first):
      * ``user_output_repo`` contains "/"  -> used as-is ("org/name");
      * ``user_output_repo`` is a bare name -> "<username>/<name>";
      * ``user_output_repo`` is empty       -> "<username>/<last segment of input_repo>";
      * nothing usable at all               -> "<username>/convertpt".

    Without a profile, ``user_output_repo`` is returned unchanged.
    """
    if not profile:
        return user_output_repo

    username = profile.username

    requested = (user_output_repo or "").strip().strip("/")
    if requested:
        # Full "org/name" ids are respected; bare names go under the user's namespace.
        return requested if "/" in requested else f"{username}/{requested}"

    # Derive the name from the input repo ("org/repo" -> "repo", "gpt2" -> "gpt2").
    # Stripping slashes first avoids an empty name for inputs like "org/repo/".
    source_name = (input_repo or "").strip().strip("/").split("/")[-1]
    return f"{username}/{source_name or 'convertpt'}"
def _is_rate_limited(exc):
    """Return True when ``exc`` signals that the Hub rate limit was hit.

    Recognises the "Rate limit exceeded" wrapper raised by ``safe_upload_file``
    and HTTP errors whose attached response has status code 429 (e.g. from
    ``hf_hub_download``).
    """
    if "Rate limit exceeded" in str(exc):
        return True
    response = getattr(exc, "response", None)
    return getattr(response, "status_code", None) == 429


def _download_convert_upload(input_repo, input_file_path, output_repo_name, output_rel_path, token):
    """Download one input file, convert it to safetensors, and upload the result."""
    # A fresh temp dir per file: the downloaded checkpoint and the converted
    # output are deleted as soon as this file is done.
    with tempfile.TemporaryDirectory() as temp_dir:
        input_local_path = hf_hub_download(
            repo_id=input_repo,
            filename=input_file_path,
            token=token,
            cache_dir=temp_dir,
        )
        output_local_path = os.path.join(temp_dir, "converted.safetensors")
        convert_file(input_local_path, output_local_path)
        safe_upload_file(
            path_or_fileobj=output_local_path,
            path_in_repo=output_rel_path,
            repo_id=output_repo_name,
            repo_type="model",
            token=token,
        )


def convert_repo(profile: Optional[gr.OAuthProfile], oauth_token: Optional[gr.OAuthToken], input_repo, vdir, output_repo_name):
    """Convert supported checkpoint files of ``input_repo`` to safetensors and upload them.

    ``profile`` and ``oauth_token`` are supplied by Gradio from the login session.
    They are not listed in the click handler's ``inputs`` and are None when the
    user is not logged in.

    Args:
        profile: Logged-in user's profile, or None.
        oauth_token: Logged-in user's OAuth token, or None.
        input_repo: Source model repo id.
        vdir: Optional subdirectory or single file inside ``input_repo``.
        output_repo_name: Target repo id, bare name, or empty
            (resolved by ``generate_output_repo_name``).

    Returns:
        ``(progress_text, error_text)`` for the Progress and Errors textboxes.
        Files whose ``.safetensors`` counterpart already exists in the output
        repo are skipped. A per-file failure is recorded and the run continues.
        A rate-limit hit stops the run.
    """
    # NOTE(review): the status prefixes below ("β", "π", "βοΈ") look like
    # mis-encoded emoji in the source; confirm the file's original encoding.
    if not profile or not oauth_token:
        return "β Please login first!", ""

    input_repo = (input_repo or "").strip()
    if not input_repo:
        return "Error: please provide an input repository.", ""

    token = oauth_token.token
    # Autofill partial details.
    output_repo_name = generate_output_repo_name(profile, input_repo, output_repo_name)

    progress_log = []
    error_log = []

    def log(message):
        progress_log.append(message)
        print(message)

    def progress_text(header):
        # Show the summary first, then the step-by-step log.
        return "\n".join([header, "", *progress_log])

    success_count = 0
    skipped_count = 0
    log("Starting conversion...")
    try:
        # List the input first, so a mistyped input repo fails before an empty
        # output repo gets created.
        input_files = get_files_to_convert(input_repo, vdir, token)
        log(f"Found {len(input_files)} convertible files in input repo")
        if not input_files:
            return progress_text("No convertible files found; nothing to do."), "No errors"

        create_repo(
            repo_id=output_repo_name,
            repo_type="model",
            private=False,
            exist_ok=True,
            token=token,
        )

        # Check which safetensors already exist in the OUTPUT repo.
        existing_files = list_repo_files(output_repo_name, token=token)
        existing_safetensors = {f for f in existing_files if f.endswith('.safetensors')}
        log(f"Found {len(existing_safetensors)} existing .safetensors files in output repo")

        for input_file_path in input_files:
            output_rel_path = os.path.splitext(input_file_path)[0] + '.safetensors'
            if output_rel_path in existing_safetensors:
                log(f"βοΈ Skipping: {output_rel_path} (already in output repo)")
                skipped_count += 1
                continue

            log(f"π Converting: {input_file_path}")
            try:
                _download_convert_upload(input_repo, input_file_path, output_repo_name, output_rel_path, token)
            except Exception as e:
                if _is_rate_limited(e):  # Quota exhausted: retrying other files is pointless.
                    raise
                error_msg = f"Failed to convert: {input_file_path} | {e}"
                error_log.append(error_msg)
                log(f"β {error_msg}")
                continue

            success_count += 1
            log(f"β Converted & uploaded: {output_rel_path}")
    except Exception as e:
        error_log.append(str(e))
        return progress_text(f"β Error: {str(e)}"), "\n".join(error_log)

    result = (
        "π Conversion complete!\n"
        f"- Successfully converted: {success_count}\n"
        f"- Skipped (already exist): {skipped_count}\n"
        f"- Failed: {len(error_log)}\n"
        f"- Output repo: https://huggingface.co/{output_repo_name}\n"
    )
    return progress_text(result), "\n".join(error_log) if error_log else "No errors"
# Custom styles passed to gr.Blocks(css=...):
#   #login     -> the login button created with elem_id="login" (stretched full width)
#   .error-log -> the "Errors" textbox created with elem_classes=["error-log"]
#                 (height-capped and scrollable, with a light red background)
css = '''
#login {
    width: 100% !important;
    margin: 0 auto;
}
.error-log {
    max-height: 300px;
    overflow-y: auto;
    background-color: #f8d7da;
    padding: 10px;
    border-radius: 5px;
    margin-top: 10px;
}
'''
# Gradio UI: login button, source/destination inputs, a Convert button and
# two output textboxes that receive convert_repo's (progress, errors) result.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Safetensors Converter")
    gr.LoginButton(elem_id="login")

    # Source, optional path filter and destination, laid out side by side.
    with gr.Row():
        input_repo = gr.Textbox(placeholder="username/repo-name", label="Input Repository")
        subdir = gr.Textbox(
            info="Leave empty for entire repo, or specify subdirectory/file",
            placeholder="models/ or specific/file.pth",
            label="Subdirectory or File (optional)",
        )
        output_repo = gr.Textbox(
            info="Defaults to your_username/input_repo or your_username/out_repo",
            placeholder="username/converted-repo",
            label="Output Repository",
        )

    convert_btn = gr.Button(variant="primary", value="Convert")
    output_log = gr.Textbox(lines=10, label="Progress")
    error_log = gr.Textbox(elem_classes=["error-log"], lines=5, label="Errors")

    # convert_repo's profile/oauth_token parameters come from the login session,
    # so only the three textboxes are wired in explicitly.
    convert_btn.click(convert_repo, [input_repo, subdir, output_repo], [output_log, error_log])
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()