import os
import tempfile
from typing import Optional

import gradio as gr
from huggingface_hub import HfApi, create_repo, list_repo_files, upload_file, hf_hub_download

from safetensors_converter import convert_file, is_supported_file

# Extensions that mark the "Subdirectory or File" input as a single file path
# rather than a directory prefix.
SINGLE_FILE_EXTENSIONS = ('.pth', '.pt', '.bin', '.ckpt')


class RateLimitError(Exception):
    """Raised when the Hub rejects a request with HTTP 429 (rate limited)."""


def safe_upload_file(*args, **kwargs):
    """Call ``huggingface_hub.upload_file``, turning 429 errors into ``RateLimitError``.

    All arguments are forwarded unchanged to ``upload_file``.

    Raises:
        RateLimitError: if the upload error message mentions status 429.
        Exception: any other error from ``upload_file`` is re-raised unchanged.
    """
    try:
        upload_file(*args, **kwargs)
    except Exception as e:
        # String-based detection: the status code is looked for in the message.
        if "429" in str(e):
            raise RateLimitError(f"Rate limit exceeded: {str(e)}") from e
        raise


def _is_rate_limit_error(exc):
    """Return True if ``exc`` signals Hub rate limiting (HTTP 429)."""
    if isinstance(exc, RateLimitError):
        return True
    # HTTP errors from huggingface_hub carry the failed response; check its status.
    response = getattr(exc, "response", None)
    return getattr(response, "status_code", None) == 429


def get_files_to_convert(repo_id, vdir, token):
    """Return the convertible files in ``repo_id`` selected by ``vdir``.

    Args:
        repo_id: Hub model repository to list (e.g. "org/repo").
        vdir: Optional selector. A path ending in one of SINGLE_FILE_EXTENSIONS
            selects that single file; any other non-empty value selects every
            supported file under that directory (or that exact path); an empty
            value or None selects every supported file in the repo. Surrounding
            whitespace and slashes are ignored.
        token: Hub access token used for listing.

    Returns:
        Repo-relative paths accepted by ``is_supported_file``, in the order
        returned by ``list_repo_files``.
    """
    all_files = list_repo_files(repo_id=repo_id, token=token)
    selector = (vdir or "").strip().strip("/")

    # No selector: convert every supported file in the repo.
    if not selector:
        return [f for f in all_files if is_supported_file(f)]

    # Specific file: only return it if it exists and is supported.
    if selector.endswith(SINGLE_FILE_EXTENSIONS):
        if selector in all_files and is_supported_file(selector):
            return [selector]
        return []

    # Directory: match on a "dir/" prefix so "models" does not also pick up
    # "models_old/..."; an exact match still allows naming a single file whose
    # extension is not in SINGLE_FILE_EXTENSIONS.
    prefix = selector + "/"
    return [
        f for f in all_files
        if (f == selector or f.startswith(prefix)) and is_supported_file(f)
    ]


def generate_output_repo_name(profile: Optional[gr.OAuthProfile], input_repo: str, user_output_repo: str) -> str:
    """Work out the output repo id from user input and the logged-in profile.

    Resolution order, when a profile is present:
      1. ``user_output_repo`` containing "/" is used as-is.
      2. A bare ``user_output_repo`` name is prefixed with the username.
      3. Otherwise the input repo's name is reused under the username.
      4. Fallback: "<username>/convertpt".
    Without a profile, ``user_output_repo`` is returned unchanged.
    """
    if not profile:
        return user_output_repo

    username = profile.username
    user_output_repo = (user_output_repo or "").strip()
    input_repo = (input_repo or "").strip().rstrip("/")

    if user_output_repo and '/' in user_output_repo:
        return user_output_repo
    if user_output_repo:
        return f"{username}/{user_output_repo}"
    if input_repo and '/' in input_repo:
        # "org/repo" -> "<username>/repo"
        return f"{username}/{input_repo.split('/')[-1]}"
    return f"{username}/convertpt"


def _convert_and_upload(input_repo, input_file_path, output_repo_name, output_rel_path, token):
    """Download one file, convert it to safetensors, and upload it to the output repo.

    The download and the converted file live in a temporary directory that is
    removed when this function returns.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        input_local_path = hf_hub_download(
            repo_id=input_repo,
            filename=input_file_path,
            token=token,
            cache_dir=temp_dir,
        )
        output_local_path = os.path.join(temp_dir, "converted.safetensors")
        convert_file(input_local_path, output_local_path)
        safe_upload_file(
            path_or_fileobj=output_local_path,
            path_in_repo=output_rel_path,
            repo_id=output_repo_name,
            repo_type="model",
            token=token,
        )


def convert_repo(profile: Optional[gr.OAuthProfile], oauth_token: Optional[gr.OAuthToken], input_repo, vdir, output_repo_name):
    """Convert supported checkpoint files from ``input_repo`` into safetensors.

    ``profile`` and ``oauth_token`` are injected by Gradio from the login
    session (both are None when the user is not logged in). Converted files are
    uploaded to the (created-if-missing, public) output repo at the same
    relative path with a ``.safetensors`` extension. Files whose target already
    exists in the output repo are skipped. A failure on one file is logged and
    the batch continues; a rate-limit error aborts the batch.

    Returns:
        (progress_text, error_text) for the two output textboxes.
    """
    if not profile or not oauth_token:
        return "❌ Please login first!", ""

    input_repo = (input_repo or "").strip().strip("/")
    if not input_repo:
        # Validate before create_repo so no output repo is created for nothing.
        return "❌ Please provide an input repository.", ""

    # Autofill partial details.
    output_repo_name = generate_output_repo_name(profile, input_repo, output_repo_name)
    token = oauth_token.token

    progress_log = []
    error_log = []

    def log(message):
        progress_log.append(message)
        print(message)

    log("Starting conversion...")
    try:
        create_repo(
            repo_id=output_repo_name,
            repo_type="model",
            private=False,
            exist_ok=True,
            token=token,
        )

        # Check what safetensors already exist in OUTPUT repo
        existing_files = list_repo_files(output_repo_name, token=token)
        existing_safetensors = {f for f in existing_files if f.endswith('.safetensors')}
        log(f"Found {len(existing_safetensors)} existing .safetensors files in output repo")

        input_files = get_files_to_convert(input_repo, vdir, token)
        log(f"Found {len(input_files)} convertible files in input repo")

        success_count = 0
        skipped_count = 0
        for input_file_path in input_files:
            output_rel_path = os.path.splitext(input_file_path)[0] + '.safetensors'

            if output_rel_path in existing_safetensors:
                log(f"⏭️ Skipping: {output_rel_path} (already in output repo)")
                skipped_count += 1
                continue

            log(f"🔄 Converting: {input_file_path}")
            try:
                _convert_and_upload(input_repo, input_file_path, output_repo_name, output_rel_path, token)
            except Exception as e:
                if _is_rate_limit_error(e):
                    # Quota exhausted: further requests would fail too, so abort.
                    raise
                error_msg = f"Failed to convert: {input_file_path} | {e}"
                error_log.append(error_msg)
                log(f"❌ {error_msg}")
                continue

            success_count += 1
            # Record the upload so a later input mapping to the same target
            # (e.g. model.bin and model.pt) does not silently overwrite it.
            existing_safetensors.add(output_rel_path)
            log(f"✅ Converted & uploaded: {output_rel_path}")

        result = (
            "🎉 Conversion complete!\n"
            f"- Successfully converted: {success_count}\n"
            f"- Skipped (already exist): {skipped_count}\n"
            f"- Failed: {len(error_log)}\n"
            f"- Output repo: https://huggingface.co/{output_repo_name}\n"
        )
        result += "\nLog:\n" + "\n".join(progress_log)
        return result, "\n".join(error_log) if error_log else "No errors"

    except Exception as e:
        # Keep partial progress visible so the user knows what already succeeded.
        error_msg = f"❌ Error: {str(e)}"
        return error_msg + "\n\nLog:\n" + "\n".join(progress_log), str(e)


css = '''
#login {
    width: 100% !important;
    margin: 0 auto;
}
.error-log {
    max-height: 300px;
    overflow-y: auto;
    background-color: #f8d7da;
    padding: 10px;
    border-radius: 5px;
    margin-top: 10px;
}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Safetensors Converter")
    gr.LoginButton(elem_id="login")

    with gr.Row():
        input_repo = gr.Textbox(
            label="Input Repository",
            placeholder="username/repo-name"
        )
        subdir = gr.Textbox(
            label="Subdirectory or File (optional)",
            placeholder="models/ or specific/file.pth",
            info="Leave empty for entire repo, or specify subdirectory/file"
        )
        output_repo = gr.Textbox(
            label="Output Repository",
            placeholder="username/converted-repo",
            info="Defaults to your_username/input_repo or your_username/out_repo"
        )

    convert_btn = gr.Button("Convert", variant="primary")
    output_log = gr.Textbox(label="Progress", lines=10)
    error_log = gr.Textbox(label="Errors", lines=5, elem_classes=["error-log"])

    # profile/oauth_token are injected by Gradio from their type hints, so only
    # the three textboxes are listed as inputs.
    convert_btn.click(
        fn=convert_repo,
        inputs=[input_repo, subdir, output_repo],
        outputs=[output_log, error_log]
    )

if __name__ == "__main__":
    demo.launch()