# Snapshot of a Hugging Face Space source file (Space status: sleeping; file size: 7,412 bytes; commit 63f29fb).
import gradio as gr
import os
import tempfile
from huggingface_hub import HfApi, create_repo, list_repo_files, upload_file, hf_hub_download
from safetensors_converter import convert_file, is_supported_file
from typing import Optional
def safe_upload_file(*args, **kwargs):
    """Upload a file to the Hub, reporting HTTP 429 failures as rate-limit errors.

    All positional and keyword arguments are forwarded unchanged to
    ``upload_file``.

    Returns:
        Whatever ``upload_file`` returns, so callers can inspect the result.

    Raises:
        Exception: With a message starting with ``"Rate limit exceeded"`` when
            the text of the underlying error contains ``"429"``. The original
            error is attached as ``__cause__``.
        Exception: Any other upload error is re-raised untouched.
    """
    try:
        return upload_file(*args, **kwargs)
    except Exception as e:
        # The message prefix is the contract: callers recognise rate limiting
        # by searching the exception text for "Rate limit exceeded".
        if "429" in str(e):
            raise Exception(f"Rate limit exceeded: {str(e)}") from e
        raise
def get_files_to_convert(repo_id, vdir, token):
    """List the convertible files of a repo, optionally limited to a path.

    Args:
        repo_id: Repository to inspect.
        vdir: Optional path inside the repo. It may name a single file or a
            subdirectory; empty / ``None`` selects the whole repo. Surrounding
            whitespace and leading/trailing slashes are ignored.
        token: Access token passed to ``list_repo_files``.

    Returns:
        The repo-relative paths accepted by ``is_supported_file``, in the
        order ``list_repo_files`` reported them.
    """
    all_files = list_repo_files(repo_id=repo_id, token=token)

    target = (vdir or "").strip().strip("/")
    if not target:
        # No path given: every supported file in the repo qualifies.
        return [path for path in all_files if is_supported_file(path)]

    # Match on a whole path component. A bare ``startswith(target)`` would let
    # "models" also select "models_old/...", which the user did not ask for.
    dir_prefix = target + "/"
    return [
        path
        for path in all_files
        if (path == target or path.startswith(dir_prefix))
        and is_supported_file(path)
    ]
def generate_output_repo_name(profile: Optional[gr.OAuthProfile], input_repo: str, user_output_repo: str) -> str:
    """Work out the id of the repo that will receive the converted files.

    Resolution order when a profile is available:

    1. ``user_output_repo`` containing a ``/`` is used as-is.
    2. A bare ``user_output_repo`` is placed under the user's namespace.
    3. Otherwise the last path component of ``input_repo`` (which must be of
       the form ``owner/name``) is placed under the user's namespace.
    4. Otherwise ``<username>/convertpt``.

    Surrounding whitespace is ignored in both inputs, and stray leading or
    trailing slashes in ``input_repo`` are dropped, so an empty repo name is
    never produced.

    Args:
        profile: Logged-in user's profile; only ``profile.username`` is read.
        input_repo: Id of the repo being converted.
        user_output_repo: Output repo as typed by the user; may be empty.

    Returns:
        The output repo id. Without a profile, ``user_output_repo`` is
        returned unchanged.
    """
    if not profile:
        return user_output_repo

    username = profile.username

    requested = (user_output_repo or "").strip()
    if requested:
        # A full "owner/name" id is respected; a bare name goes under the user.
        return requested if "/" in requested else f"{username}/{requested}"

    # Nothing requested: reuse the source repo's name ("org/repo" -> "repo").
    source = (input_repo or "").strip().strip("/")
    if "/" in source:
        return f"{username}/{source.rsplit('/', 1)[-1]}"

    # Fallback
    return f"{username}/convertpt"
def convert_repo(profile: Optional[gr.OAuthProfile], oauth_token: gr.OAuthToken, input_repo, vdir, output_repo_name):
    """Convert the supported files of a repo to safetensors and upload them.

    Files whose ``.safetensors`` counterpart already exists in the output repo
    are skipped. A failure on one file is recorded and the run continues,
    except for rate-limit errors, which abort the run.

    Args:
        profile: Logged-in user's profile, or ``None`` when not logged in.
        oauth_token: Logged-in user's token; ``oauth_token.token`` is used for
            every Hub call.
        input_repo: Id of the repo to read from.
        vdir: Optional subdirectory or file inside ``input_repo``.
        output_repo_name: Output repo as typed by the user; may be empty, in
            which case ``generate_output_repo_name`` fills it in.

    Returns:
        A ``(summary, errors)`` tuple of strings: a human-readable summary (or
        an error message) and the per-file errors joined by newlines
        (``"No errors"`` when there were none).
    """
    if not profile or not oauth_token:
        return "❌ Please login first!", ""

    # Reject a missing input before anything is created on the Hub; otherwise
    # the output repo would be created and the run would fail right after.
    input_repo = (input_repo or "").strip()
    if not input_repo:
        return "❌ Please provide an input repository!", ""

    # Autofill partial details.
    output_repo_name = generate_output_repo_name(profile, input_repo, output_repo_name)

    progress_log = []
    error_log = []

    def log(message):
        progress_log.append(message)
        print(message)

    log("Starting conversion...")
    try:
        # Create output repo
        create_repo(
            repo_id=output_repo_name,
            repo_type="model",
            private=False,
            exist_ok=True,
            token=oauth_token.token,
        )

        # Check what safetensors already exist in OUTPUT repo
        existing_files = list_repo_files(output_repo_name, token=oauth_token.token)
        existing_safetensors = {f for f in existing_files if f.endswith('.safetensors')}
        log(f"Found {len(existing_safetensors)} existing .safetensors files in output repo")

        input_files = get_files_to_convert(input_repo, vdir, oauth_token.token)
        log(f"Found {len(input_files)} convertible files in input repo")

        success_count = 0
        skipped_count = 0
        for input_file_path in input_files:
            # Convert input path to output safetensors path
            output_rel_path = os.path.splitext(input_file_path)[0] + '.safetensors'

            # Check if this safetensors file already exists in OUTPUT repo
            if output_rel_path in existing_safetensors:
                log(f"⏭️ Skipping: {output_rel_path} (already in output repo)")
                skipped_count += 1
                continue

            # Download into a throwaway directory so the download and the
            # converted file are both removed once this file is done.
            with tempfile.TemporaryDirectory() as temp_dir:
                input_local_path = hf_hub_download(
                    repo_id=input_repo,
                    filename=input_file_path,
                    token=oauth_token.token,
                    cache_dir=temp_dir,
                )

                # Convert the file
                output_local_path = os.path.join(temp_dir, "converted.safetensors")
                log(f"🔄 Converting: {input_file_path}")
                try:
                    convert_file(input_local_path, output_local_path)
                    # Upload to OUTPUT repo
                    safe_upload_file(
                        path_or_fileobj=output_local_path,
                        path_in_repo=output_rel_path,
                        repo_id=output_repo_name,
                        repo_type="model",
                        token=oauth_token.token,
                    )
                    success_count += 1
                    log(f"✅ Converted & uploaded: {output_rel_path}")
                except Exception as e:
                    if "Rate limit exceeded" in str(e):  # Upload exhausted, must wait.
                        raise
                    # Interpolate the exception so the error log shows the cause.
                    error_msg = f"Failed to convert: {input_file_path} | {e}"
                    error_log.append(error_msg)
                    log(f"❌ {error_msg}")

        result = (
            "🎉 Conversion complete!\n"
            f"- Successfully converted: {success_count}\n"
            f"- Skipped (already exist): {skipped_count}\n"
            f"- Failed: {len(error_log)}\n"
            f"- Output repo: https://huggingface.co/{output_repo_name}\n"
        )
        return result, "\n".join(error_log) if error_log else "No errors"
    except Exception as e:
        # Top-level boundary: report the failure in the UI instead of raising.
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, str(e)
# Custom stylesheet passed to gr.Blocks(css=css).
# - "#login" targets the element created with elem_id="login" and makes it
#   span the full width.
# - ".error-log" targets elements created with the "error-log" class: height
#   capped at 300px with vertical scrolling, tinted background, padding and
#   rounded corners.
css = '''
#login {
width: 100% !important;
margin: 0 auto;
}
.error-log {
max-height: 300px;
overflow-y: auto;
background-color: #f8d7da;
padding: 10px;
border-radius: 5px;
margin-top: 10px;
}
'''
# Build the UI: a title, a login button, three text inputs, a convert button
# and two output boxes (progress summary and per-file errors).
# NOTE(review): the nesting below is reconstructed because the source lost its
# indentation; confirm that all three text inputs belong inside the same row.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Safetensors Converter")
    # elem_id="login" ties this button to the "#login" rule in `css`.
    gr.LoginButton(elem_id="login")
    with gr.Row():
        input_repo = gr.Textbox(
            label="Input Repository",
            placeholder="username/repo-name"
        )
        subdir = gr.Textbox(
            label="Subdirectory or File (optional)",
            placeholder="models/ or specific/file.pth",
            info="Leave empty for entire repo, or specify subdirectory/file"
        )
        output_repo = gr.Textbox(
            label="Output Repository",
            placeholder="username/converted-repo",
            info="Defaults to your_username/input_repo or your_username/out_repo"
        )
    convert_btn = gr.Button("Convert", variant="primary")
    output_log = gr.Textbox(label="Progress", lines=10)
    # The "error-log" class ties this box to the ".error-log" rule in `css`.
    error_log = gr.Textbox(label="Errors", lines=5, elem_classes=["error-log"])
    # Only three inputs are wired here, while convert_repo takes five
    # parameters; presumably Gradio supplies `profile` and `oauth_token` from
    # the login session based on their type annotations. Confirm against the
    # Gradio OAuth documentation.
    convert_btn.click(
        fn=convert_repo,
        inputs=[input_repo, subdir, output_repo],
        outputs=[output_log, error_log]
    )
if __name__ == "__main__":
    # Launch the app only when this file is run as a script, not on import.
    demo.launch()