File size: 7,213 Bytes
98c9be7
 
 
 
 
 
 
 
 
 
c85e14f
98c9be7
 
 
81541ca
 
281d59c
c85e14f
 
 
281d59c
 
 
 
 
 
 
 
 
 
98c9be7
 
281d59c
98c9be7
 
281d59c
98c9be7
 
 
 
281d59c
98c9be7
281d59c
98c9be7
 
 
 
 
 
 
 
 
 
81541ca
98c9be7
281d59c
98c9be7
 
 
 
 
c85e14f
 
 
 
 
281d59c
98c9be7
 
 
281d59c
 
 
 
 
 
98c9be7
 
 
 
81541ca
98c9be7
 
 
281d59c
98c9be7
 
 
281d59c
 
98c9be7
 
 
 
 
 
 
 
 
281d59c
98c9be7
 
 
281d59c
98c9be7
 
 
 
 
 
 
 
 
 
 
281d59c
98c9be7
281d59c
81541ca
 
281d59c
81541ca
281d59c
 
98c9be7
281d59c
98c9be7
81541ca
c85e14f
81541ca
98c9be7
 
 
 
81541ca
 
 
 
 
 
 
 
 
281d59c
 
 
 
 
c85e14f
 
 
 
 
 
 
 
 
281d59c
 
 
 
 
 
81541ca
 
 
98c9be7
 
 
 
 
 
281d59c
98c9be7
 
 
 
281d59c
81541ca
 
 
 
 
 
 
281d59c
98c9be7
 
c85e14f
98c9be7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import gc
import torch
import shutil
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file

TEMP_DIR = "temp_processing_dir"

def _remove_file_quietly(path):
    """Delete *path* if it is set and exists; never raise on cleanup failure."""
    if not path:
        return
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError:
        # Best-effort cleanup: a leftover temp file must not abort the run.
        pass


def _cast_floating_tensors(tensors, target_dtype):
    """Cast every floating-point tensor in *tensors* to *target_dtype*, in place."""
    for key in list(tensors):
        if tensors[key].is_floating_point():
            tensors[key] = tensors[key].to(target_dtype)


def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Copy a Hub repo file by file, re-casting selected safetensors on the way.

    This is a generator: every yielded string is a status message meant for
    the UI log. Files are handled strictly one at a time (download, optional
    cast, upload, delete) so only one file is on disk / in RAM at once.

    Args:
        token: Hugging Face access token used for all Hub calls.
        source_repo: Repo id to read files from.
        target_repo: Repo id to create (if needed) and upload into.
        precision: One of "FP8", "FP16", "BF16"; any other value leaves
            tensors uncast (files are still re-saved and uploaded).
        target_components: Folder names (e.g. "transformer") whose
            ``.safetensors`` files get cast; everything else is copied as-is.

    Yields:
        Human-readable progress, error, and final summary messages.
    """
    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo.strip() or "your-username" in target_repo:
        yield "❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return

    # Map precision string to a torch dtype. The attribute is resolved lazily
    # so an unsupported dtype only matters when it is actually requested.
    dtype_attr = {
        "FP8": "float8_e4m3fn",
        "FP16": "float16",
        "BF16": "bfloat16",
    }.get(precision)
    target_dtype = getattr(torch, dtype_attr) if dtype_attr else None

    api = HfApi(token=token)
    yield f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}..."

    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
        return

    yield f"📋 Fetching file list from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {str(e)}"
        return

    failed_files = []

    # The outer try/finally guarantees the scratch directory is removed even
    # if the consumer stops iterating early (generator close) or an
    # unexpected error escapes the per-file handler.
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)

        for file in files:
            yield f"⏳ Processing {file}..."

            local_path = None
            converted_path = None
            tensors = None

            try:
                # Download file locally, bypassing symlink cache to save disk space.
                # NOTE(review): local_dir_use_symlinks is reported as deprecated in
                # recent huggingface_hub releases — confirm against the pinned version.
                local_path = hf_hub_download(
                    repo_id=source_repo,
                    filename=file,
                    local_dir=TEMP_DIR,
                    local_dir_use_symlinks=False
                )

                # Match the component as a whole path segment so that e.g.
                # "my_transformer/x" is not mistaken for "transformer/x".
                in_target_component = any(
                    f"/{comp}/" in f"/{file}" for comp in target_components
                )

                # Intercept and quantize only if it's a safetensors file in a selected folder
                if file.endswith(".safetensors") and in_target_component:
                    yield f"🧠 Quantizing {file} to {precision}..."

                    tensors = load_file(local_path)

                    # Cast floating point tensors to the selected precision
                    if target_dtype is not None:
                        _cast_floating_tensors(tensors, target_dtype)

                    converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
                    save_file(tensors, converted_path)

                    # Wipe tensors from RAM before the upload to prevent OOM
                    tensors = None
                    gc.collect()

                    yield f"☁️ Uploading {precision} version of {file}..."
                    api.upload_file(
                        path_or_fileobj=converted_path,
                        path_in_repo=file,
                        repo_id=target_repo,
                        commit_message=f"Upload {precision} quantized {file}"
                    )
                else:
                    yield f"☁️ Copying {file} as-is..."
                    api.upload_file(
                        path_or_fileobj=local_path,
                        path_in_repo=file,
                        repo_id=target_repo,
                        commit_message=f"Copy {file} from original repo"
                    )

            except Exception as e:
                # Deliberate best-effort: record the failure and keep going.
                failed_files.append(file)
                yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."

            finally:
                # Runs on success AND failure, so a failed file no longer
                # leaves its download / converted copy on disk or in RAM.
                tensors = None
                _remove_file_quietly(converted_path)
                _remove_file_quietly(local_path)
                gc.collect()
    finally:
        shutil.rmtree(TEMP_DIR, ignore_errors=True)

    if failed_files:
        yield (
            f"⚠️ Finished with errors: {len(failed_files)} of {len(files)} file(s) "
            f"could not be processed and are missing from {target_repo}: "
            f"{', '.join(failed_files)}"
        )
    else:
        yield f"✅ All files processed and successfully uploaded to {target_repo}!"

# Dynamic UI Update for Target Repo Name
def update_target_repo(username, source, precision):
    """Build the suggested target repo id ``<user>/<model>-<precision>``.

    Args:
        username: Hub username typed by the user; blank or missing falls
            back to the ``your-username`` placeholder.
        source: Source repo id; any value containing "Turbo" maps to the
            ``Z-Image-Turbo`` model name, everything else to ``Z-Image-Base``.
        precision: Precision label appended verbatim as the suffix.

    Returns:
        The repo id string shown in the target-repository textbox.
    """
    # A cleared UI field may arrive as None rather than "" (assumption about
    # the Gradio payload — harmless if it never happens); coerce instead of
    # crashing on .strip() / "in".
    user_prefix = (username or "").strip() or "your-username"
    model_name = "Z-Image-Turbo" if "Turbo" in (source or "") else "Z-Image-Base"
    return f"{user_prefix}/{model_name}-{precision}"

# Build the Gradio UI
# Layout: a left column with all inputs and the start button, and a wider
# right column holding the read-only log textbox that receives the status
# strings yielded by convert_and_upload.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Z-Image Quantizer & Uploader")
    gr.Markdown(
        "Convert the **Z-Image** or **Z-Image-Turbo** models to lower precisions (FP8, FP16, BF16) and push them directly to your own Hugging Face account.\n\n"
        "**How it works:** This tool sequentially downloads, quantizes the selected files, and uploads everything. "
        "It is designed to run safely on free Spaces (16GB RAM) by processing files one at a time."
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Masked input: the token is passed straight to convert_and_upload.
            hf_token = gr.Textbox(
                label="Hugging Face Token (Write Access Required)", 
                type="password",
                placeholder="hf_..."
            )
            hf_username = gr.Textbox(
                label="Your Hugging Face Username", 
                placeholder="e.g., rootlocalghost"
            )
            source_repo = gr.Dropdown(
                choices=["Tongyi-MAI/Z-Image", "Tongyi-MAI/Z-Image-Turbo"], 
                value="Tongyi-MAI/Z-Image-Turbo", 
                label="Source Repository"
            )
            
            # Checkbox group for granular component control; the values are
            # matched against folder names inside the source repo.
            target_components = gr.CheckboxGroup(
                choices=["text_encoder", "transformer"],
                value=["text_encoder", "transformer"],
                label="Components to Quantize",
                info="Select which parts of the model to convert. Unselected parts will be copied as-is."
            )
            
            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16"], 
                value="FP8", 
                label="Quantization Precision"
            )
            # Pre-filled suggestion; stays editable so the user can override
            # the auto-generated name. The "your-username" placeholder is
            # rejected by convert_and_upload's validation.
            target_repo = gr.Textbox(
                label="Target Repository (Auto-generated)", 
                value="your-username/Z-Image-Turbo-FP8",
                interactive=True
            )
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")
        
        with gr.Column(scale=3):
            output_log = gr.Textbox(
                label="Operation Logs", 
                lines=17, 
                interactive=False,
                max_lines=20
            )

    # Automatically update the target repo name when inputs change.
    # Every watched input triggers the same handler with all three values.
    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(
            fn=update_target_repo, 
            inputs=inputs_to_watch, 
            outputs=[target_repo]
        )

    # convert_and_upload is a generator, so each yielded message is pushed
    # to the log textbox as it is produced.
    start_btn.click(
        fn=convert_and_upload, 
        inputs=[hf_token, source_repo, target_repo, precision, target_components], 
        outputs=[output_log]
    )

if __name__ == "__main__":
    demo.launch()