# pt2st / app.py
# Symbiomatrix's picture
# Create app.py
# 63f29fb verified
import gradio as gr
import os
import tempfile
from huggingface_hub import HfApi, create_repo, list_repo_files, upload_file, hf_hub_download
from safetensors_converter import convert_file, is_supported_file
from typing import Optional
def safe_upload_file(*args, **kwargs):
    """Upload a file via ``upload_file``, flagging rate-limit failures.

    All positional and keyword arguments are forwarded unchanged to
    ``upload_file`` and its return value is passed back to the caller.

    Raises:
        Exception: With a message starting with ``"Rate limit exceeded"``
            when the upload failed with HTTP 429. The original exception is
            attached as ``__cause__``.
        Exception: Any other upload failure is re-raised untouched.
    """
    try:
        return upload_file(*args, **kwargs)
    except Exception as e:
        # Prefer the real HTTP status when the exception carries a response;
        # a bare substring test can false-positive on request IDs or file
        # names that merely contain "429". Fall back to the substring only
        # when no status code is available.
        status = getattr(getattr(e, "response", None), "status_code", None)
        if status is not None:
            rate_limited = status == 429
        else:
            rate_limited = "429" in str(e)
        if rate_limited:
            raise Exception(f"Rate limit exceeded: {str(e)}") from e
        raise
def get_files_to_convert(repo_id, vdir, token):
    """Return the repo files that should be converted.

    Args:
        repo_id: Repository whose file listing is inspected.
        vdir: Optional filter. A path ending in ``.pth``/``.pt``/``.bin``/
            ``.ckpt`` selects that single file; any other non-empty value is
            treated as a subdirectory; empty/None selects the whole repo.
        token: Access token forwarded to ``list_repo_files``.

    Returns:
        List of repo-relative paths accepted by ``is_supported_file``. Empty
        when nothing matches (including a specific file that is not in the
        repo).
    """
    all_files = list_repo_files(repo_id=repo_id, token=token)
    # Tolerate stray whitespace pasted into the textbox.
    path = (vdir or "").strip()

    # No path given: every supported file in the repo.
    if not path:
        return [f for f in all_files if is_supported_file(f)]

    # A path with a known checkpoint extension names one specific file.
    if path.endswith(('.pth', '.pt', '.bin', '.ckpt')):
        if path in all_files and is_supported_file(path):
            return [path]
        return []

    # Otherwise treat the path as a directory. Match on a directory boundary
    # so that "models" does not also pick up "models_old/...". An exact match
    # is kept so a file path with an unlisted extension still resolves.
    prefix = path.rstrip('/') + '/'
    return [
        f for f in all_files
        if (f == path or f.startswith(prefix)) and is_supported_file(f)
    ]
def generate_output_repo_name(profile: Optional[gr.OAuthProfile], input_repo: str, user_output_repo: str) -> str:
    """Build the output repo id from the user's input and profile.

    Resolution order (when a profile is available):
      1. ``user_output_repo`` containing ``/`` is used as-is.
      2. A bare ``user_output_repo`` is prefixed with the profile's username.
      3. Otherwise the last path segment of ``input_repo`` (which must
         contain ``/``) is reused under the username.
      4. Fallback: ``<username>/convertpt``.

    Args:
        profile: Object exposing ``username``; when falsy, ``user_output_repo``
            is returned unchanged.
        input_repo: Source repo id, e.g. ``"org/repo"``.
        user_output_repo: Output repo as typed by the user (may be empty).

    Returns:
        The resolved output repo id.
    """
    if not profile:
        return user_output_repo
    username = profile.username

    # Surrounding whitespace is never part of a repo id; a whitespace-only
    # value counts as "nothing provided".
    requested = (user_output_repo or "").strip()
    if requested:
        if '/' in requested:
            return requested
        return f"{username}/{requested}"

    # Strip slashes at the ends so "org/repo/" yields "repo" rather than an
    # empty repo name (which would produce the invalid id "<username>/").
    source = (input_repo or "").strip().strip('/')
    if '/' in source:
        repo_name = source.rsplit('/', 1)[-1]
        return f"{username}/{repo_name}"

    # Fallback
    return f"{username}/convertpt"
def convert_repo(profile: Optional[gr.OAuthProfile], oauth_token: Optional[gr.OAuthToken], input_repo, vdir, output_repo_name):
    """Convert checkpoint files of ``input_repo`` to safetensors and upload them.

    Args:
        profile: Profile of the logged-in user, or None when not logged in.
        oauth_token: OAuth token of the logged-in user, or None when not
            logged in. Annotated Optional so the login guard below can
            return its message instead of the call being rejected earlier.
        input_repo: Repo id the checkpoints are read from.
        vdir: Optional subdirectory or file path limiting what is converted.
        output_repo_name: Target repo; completed via
            ``generate_output_repo_name``.

    Returns:
        A ``(summary, errors)`` tuple of strings. On a fatal error the
        summary is ``"❌ Error: ..."`` and ``errors`` is the exception text.
    """
    if not profile or not oauth_token:
        return "❌ Please login first!", ""
    # Autofill partial details.
    output_repo_name = generate_output_repo_name(profile, input_repo, output_repo_name)
    token = oauth_token.token
    # progress_log is only collected (and echoed to stdout); the returned
    # summary is built from the counters and error_log.
    progress_log = []
    error_log = []

    def log(message):
        progress_log.append(message)
        print(message)

    log("Starting conversion...")
    try:
        # Create output repo (public; reused if it already exists)
        create_repo(
            repo_id=output_repo_name,
            repo_type="model",
            private=False,
            exist_ok=True,
            token=token
        )
        # Check what safetensors already exist in OUTPUT repo
        existing_files = list_repo_files(output_repo_name, token=token)
        existing_safetensors = {f for f in existing_files if f.endswith('.safetensors')}
        log(f"Found {len(existing_safetensors)} existing .safetensors files in output repo")
        input_files = get_files_to_convert(input_repo, vdir, token)
        log(f"Found {len(input_files)} convertible files in input repo")
        success_count = 0
        skipped_count = 0
        for input_file_path in input_files:
            # Convert input path to output safetensors path
            output_rel_path = os.path.splitext(input_file_path)[0] + '.safetensors'
            # Check if this safetensors file already exists in OUTPUT repo
            if output_rel_path in existing_safetensors:
                log(f"⏭️ Skipping: {output_rel_path} (already in output repo)")
                skipped_count += 1
                continue
            # Download and convert inside a temp dir so both the downloaded
            # and the converted file are removed once the upload is done.
            with tempfile.TemporaryDirectory() as temp_dir:
                input_local_path = hf_hub_download(
                    repo_id=input_repo,
                    filename=input_file_path,
                    token=token,
                    cache_dir=temp_dir
                )
                # Convert the file
                output_local_path = os.path.join(temp_dir, "converted.safetensors")
                log(f"🔄 Converting: {input_file_path}")
                try:
                    convert_file(input_local_path, output_local_path)
                    # Upload to OUTPUT repo
                    safe_upload_file(
                        path_or_fileobj=output_local_path,
                        path_in_repo=output_rel_path,
                        repo_id=output_repo_name,
                        repo_type="model",
                        token=token
                    )
                    success_count += 1
                    log(f"✅ Converted & uploaded: {output_rel_path}")
                except Exception as e:
                    if "Rate limit exceeded" in str(e):  # Upload exhausted, must wait.
                        raise
                    # Per-file failure: record the actual exception text and
                    # carry on with the remaining files.
                    error_msg = f"Failed to convert: {input_file_path} | {e}"
                    error_log.append(error_msg)
                    log(f"❌ {error_msg}")
        result = (
            "🎉 Conversion complete!\n"
            f"- Successfully converted: {success_count}\n"
            f"- Skipped (already exist): {skipped_count}\n"
            f"- Failed: {len(error_log)}\n"
            f"- Output repo: https://huggingface.co/{output_repo_name}\n"
        )
        return result, ("\n".join(error_log) if error_log else "No errors")
    except Exception as e:
        # Fatal errors (including rate limiting) end the run; report them in
        # both output fields.
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, str(e)
# Stylesheet handed to gr.Blocks: stretches the element with id "login" to
# full width and renders elements with class "error-log" as a scrollable,
# red-tinted panel.
css = """
#login {
width: 100% !important;
margin: 0 auto;
}
.error-log {
max-height: 300px;
overflow-y: auto;
background-color: #f8d7da;
padding: 10px;
border-radius: 5px;
margin-top: 10px;
}
"""
# Build the Gradio UI: login button, three text inputs, a convert button and
# two read-only text areas for the summary and the error list.
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown("# Safetensors Converter")
    gr.LoginButton(elem_id="login")
    # NOTE(review): the source's indentation was lost; the Row is assumed to
    # hold exactly the three input textboxes — confirm against the original.
    with gr.Row():
        input_repo = gr.Textbox(label="Input Repository", placeholder="username/repo-name")
        subdir = gr.Textbox(
            label="Subdirectory or File (optional)",
            placeholder="models/ or specific/file.pth",
            info="Leave empty for entire repo, or specify subdirectory/file",
        )
        output_repo = gr.Textbox(
            label="Output Repository",
            placeholder="username/converted-repo",
            info="Defaults to your_username/input_repo or your_username/out_repo",
        )
    convert_btn = gr.Button("Convert", variant="primary")
    output_log = gr.Textbox(label="Progress", lines=10)
    error_log = gr.Textbox(label="Errors", lines=5, elem_classes=["error-log"])
    # Only the three textboxes are listed as inputs; convert_repo's profile
    # and oauth_token parameters are presumably supplied by Gradio's OAuth
    # injection based on their type hints (see the Gradio OAuth docs).
    convert_btn.click(
        fn=convert_repo,
        inputs=[input_repo, subdir, output_repo],
        outputs=[output_log, error_log],
    )

if __name__ == "__main__":
    demo.launch()