File size: 7,412 Bytes
63f29fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import gradio as gr
import os
import tempfile
from huggingface_hub import HfApi, create_repo, list_repo_files, upload_file, hf_hub_download
from safetensors_converter import convert_file, is_supported_file
from typing import Optional

def safe_upload_file(*args, **kwargs):
    """Upload a file via ``upload_file`` and flag rate-limit failures.

    All positional and keyword arguments are forwarded unchanged to
    ``huggingface_hub.upload_file``.

    Returns:
        Whatever ``upload_file`` returns.

    Raises:
        Exception: with a message starting with ``"Rate limit exceeded: "``
            when the underlying error text contains ``"429"``. The original
            error is attached as ``__cause__``.
        Exception: any other upload error is re-raised untouched.
    """
    try:
        return upload_file(*args, **kwargs)
    except Exception as e:
        # Detection is a plain substring match on the error text, so any
        # message containing "429" is treated as a rate-limit response.
        if "429" in str(e):
            raise Exception(f"Rate limit exceeded: {e}") from e
        raise

def get_files_to_convert(repo_id, vdir: Optional[str], token):
    """Select the files of a repo that should be converted.

    Args:
        repo_id: Repository to list, passed to ``list_repo_files``.
        vdir: Optional path inside the repo. It may name a single file, a
            subdirectory, or be empty/``None`` to select the whole repo.
            Surrounding whitespace and leading/trailing slashes are ignored.
        token: Access token passed to ``list_repo_files``.

    Returns:
        A list of repo-relative paths for which ``is_supported_file`` is true,
        in the order returned by ``list_repo_files``.
    """
    all_files = list_repo_files(repo_id=repo_id, token=token)

    # The value comes from a free-text box: tolerate "models", "models/",
    # "/models" and accidental padding alike.
    target = (vdir or "").strip().strip("/")

    # No path given: every supported file in the repo is a candidate.
    if not target:
        return [path for path in all_files if is_supported_file(path)]

    # An exact hit in the listing means the user named one specific file.
    if target in all_files:
        return [target] if is_supported_file(target) else []

    # Otherwise treat it as a directory. The trailing slash keeps "models"
    # from also matching siblings such as "models_old/...".
    prefix = target + "/"
    return [
        path
        for path in all_files
        if path.startswith(prefix) and is_supported_file(path)
    ]

def generate_output_repo_name(profile: Optional[gr.OAuthProfile], input_repo: str, user_output_repo: str) -> str:
    """Work out the destination repo id from what the user typed and who they are.

    Resolution order:
      1. No profile: the user's text is returned untouched.
      2. User text containing a '/': taken as a full repo id and returned as-is.
      3. User text without a '/': placed under the profile's username.
      4. No user text: the last path segment of ``input_repo`` (when it
         contains a '/') under the profile's username.
      5. Otherwise: ``<username>/convertpt``.
    """
    if not profile:
        return user_output_repo

    owner = profile.username

    if user_output_repo:
        # A namespaced name is used verbatim; a bare name goes under the
        # caller's own namespace.
        already_qualified = '/' in user_output_repo
        return user_output_repo if already_qualified else f"{owner}/{user_output_repo}"

    # Nothing typed: reuse the input repo's final segment (e.g. "org/repo"
    # -> "repo"), or fall back to a fixed default name.
    has_namespace = bool(input_repo) and '/' in input_repo
    target = input_repo.split('/')[-1] if has_namespace else "convertpt"
    return f"{owner}/{target}"

def convert_repo(profile: Optional[gr.OAuthProfile], oauth_token: Optional[gr.OAuthToken], input_repo, vdir, output_repo_name):
    """Convert the supported files of one repo to safetensors and upload them.

    Args:
        profile: Profile of the logged-in user, or ``None`` when not logged in.
        oauth_token: OAuth token of the logged-in user, or ``None``. Declared
            Optional so the "please login" message below can actually be
            returned instead of the framework rejecting the call up front.
        input_repo: Repo id to read files from.
        vdir: Optional subdirectory or file path inside ``input_repo``.
        output_repo_name: Destination repo; may be empty, a bare name, or a
            full ``owner/name`` id (see ``generate_output_repo_name``).

    Returns:
        A ``(status, errors)`` tuple of strings: a summary (or an error
        banner) and the per-file failure messages joined by newlines
        (``"No errors"`` when there were none).
    """
    if not profile or not oauth_token:
        return "❌ Please login first!", ""

    # Autofill partial details.
    output_repo_name = generate_output_repo_name(profile, input_repo, output_repo_name)
    token = oauth_token.token

    error_log = []

    def log(message):
        # Progress is only written to the server console; the UI receives the
        # final summary and the error list.
        print(message)

    log("Starting conversion...")

    try:
        # Create output repo (no-op if it already exists).
        create_repo(
            repo_id=output_repo_name,
            repo_type="model",
            private=False,
            exist_ok=True,
            token=token
        )

        # Check what safetensors already exist in OUTPUT repo so reruns can
        # resume instead of redoing finished files.
        existing_files = list_repo_files(output_repo_name, token=token)
        existing_safetensors = {f for f in existing_files if f.endswith('.safetensors')}
        log(f"Found {len(existing_safetensors)} existing .safetensors files in output repo")

        input_files = get_files_to_convert(input_repo, vdir, token)
        log(f"Found {len(input_files)} convertible files in input repo")

        success_count = 0
        skipped_count = 0

        for input_file_path in input_files:
            # Same relative path in the output repo, with the extension swapped.
            output_rel_path = os.path.splitext(input_file_path)[0] + '.safetensors'

            if output_rel_path in existing_safetensors:
                log(f"⏭️ Skipping: {output_rel_path} (already in output repo)")
                skipped_count += 1
                continue

            # A fresh temp dir per file keeps disk usage bounded to one
            # download plus one converted file at a time.
            with tempfile.TemporaryDirectory() as temp_dir:
                input_local_path = hf_hub_download(
                    repo_id=input_repo,
                    filename=input_file_path,
                    token=token,
                    cache_dir=temp_dir
                )

                output_local_path = os.path.join(temp_dir, "converted.safetensors")

                log(f"🔄 Converting: {input_file_path}")
                try:
                    convert_file(input_local_path, output_local_path)
                    # Upload to OUTPUT repo
                    safe_upload_file(
                        path_or_fileobj=output_local_path,
                        path_in_repo=output_rel_path,
                        repo_id=output_repo_name,
                        repo_type="model",
                        token=token
                    )
                    success_count += 1
                    log(f"✅ Converted & uploaded: {output_rel_path}")
                except Exception as e:
                    # Upload quota exhausted: further uploads would fail too,
                    # so abort the whole run rather than logging per file.
                    if "Rate limit exceeded" in str(e):
                        raise
                    error_msg = f"Failed to convert: {input_file_path} | {e}"
                    error_log.append(error_msg)
                    log(f"❌ {error_msg}")

        result = (
            "🎉 Conversion complete!\n"
            f"- Successfully converted: {success_count}\n"
            f"- Skipped (already exist): {skipped_count}\n"
            f"- Failed: {len(error_log)}\n"
            f"- Output repo: https://huggingface.co/{output_repo_name}\n"
        )
        return result, "\n".join(error_log) if error_log else "No errors"

    except Exception as e:
        # Top-level boundary for the UI callback: surface the failure as text
        # instead of letting it propagate.
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, str(e)

# Custom stylesheet handed to gr.Blocks below.
#   #login     - targets the element with elem_id="login" (the login button).
#   .error-log - targets elements carrying the "error-log" class (the Errors
#                textbox): capped height, scrollable, light red background.
css = '''
#login {
    width: 100% !important;
    margin: 0 auto;
}
.error-log {
    max-height: 300px;
    overflow-y: auto;
    background-color: #f8d7da;
    padding: 10px;
    border-radius: 5px;
    margin-top: 10px;
}
'''

# Build the UI: a login button, three text inputs in one row, a convert
# button, and two output boxes (summary and per-file errors).
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Safetensors Converter")
    
    # elem_id ties this button to the "#login" rule in the css string.
    gr.LoginButton(elem_id="login")
    
    with gr.Row():
        input_repo = gr.Textbox(
            label="Input Repository",
            placeholder="username/repo-name"
        )
        subdir = gr.Textbox(
            label="Subdirectory or File (optional)",
            placeholder="models/ or specific/file.pth",
            info="Leave empty for entire repo, or specify subdirectory/file"
        )
        output_repo = gr.Textbox(
            label="Output Repository", 
            placeholder="username/converted-repo",
            info="Defaults to your_username/input_repo or your_username/out_repo"
        )
    
    convert_btn = gr.Button("Convert", variant="primary")
    output_log = gr.Textbox(label="Progress", lines=10)
    # The "error-log" class hooks this box up to the ".error-log" css rule.
    error_log = gr.Textbox(label="Errors", lines=5, elem_classes=["error-log"])
    
    # Only the three textboxes are listed as inputs, while convert_repo also
    # takes `profile` and `oauth_token`.
    # NOTE(review): those two are expected to be filled in by Gradio from the
    # login session based on the function's type hints - confirm against the
    # Gradio OAuth documentation.
    convert_btn.click(
        fn=convert_repo,
        inputs=[input_repo, subdir, output_repo],
        outputs=[output_log, error_log]
    )

# Start the Gradio app only when this file is executed directly, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()