rootlocalghost commited on
Commit
041d8b6
Β·
verified Β·
1 Parent(s): 7deb3dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import shutil
5
+ import gradio as gr
6
+ from huggingface_hub import HfApi, hf_hub_download
7
+ from safetensors.torch import load_file, save_file
8
+
9
+ TEMP_DIR = "temp_processing_dir"
10
+
11
+ def convert_and_upload(token, source_repo, target_repo, precision, target_components):
12
+ if not token:
13
+ yield "❌ Error: Please provide a valid Hugging Face Write Token."
14
+ return
15
+ if not target_repo.strip() or "your-username" in target_repo:
16
+ yield "❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
17
+ return
18
+ if not target_components:
19
+ yield "❌ Error: Please select at least one component to quantize."
20
+ return
21
+
22
+ # Map precision string to PyTorch dtype
23
+ if precision == "FP8":
24
+ target_dtype = torch.float8_e4m3fn
25
+ elif precision == "FP16":
26
+ target_dtype = torch.float16
27
+ elif precision == "BF16":
28
+ target_dtype = torch.bfloat16
29
+ else:
30
+ target_dtype = None
31
+
32
+ api = HfApi(token=token)
33
+ yield f"πŸ”„ Connecting to Hugging Face and verifying target repo: {target_repo}..."
34
+
35
+ try:
36
+ api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
37
+ except Exception as e:
38
+ yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
39
+ return
40
+
41
+ yield f"πŸ“‹ Fetching file list from {source_repo}..."
42
+ try:
43
+ files = api.list_repo_files(source_repo)
44
+ except Exception as e:
45
+ yield f"❌ Error fetching files: {str(e)}"
46
+ return
47
+
48
+ os.makedirs(TEMP_DIR, exist_ok=True)
49
+
50
+ for file in files:
51
+ # AUTO-DELETE/SKIP LOGIC: Detect large .safetensors files at the root level (no slashes in path)
52
+ is_root_safetensor = "/" not in file and file.endswith(".safetensors")
53
+
54
+ if is_root_safetensor:
55
+ yield f"πŸ—‘οΈ Auto-skipping massive root model: {file}..."
56
+ try:
57
+ # If pushing to an existing repo, explicitly delete the large root file if it exists there
58
+ api.delete_file(path_in_repo=file, repo_id=target_repo, token=token, commit_message=f"Auto-deleted massive root file {file}")
59
+ yield f"βœ… Ensured {file} is removed from target repository."
60
+ except Exception:
61
+ pass # File doesn't exist in target repo yet, which is fine
62
+ continue
63
+
64
+ yield f"⏳ Processing {file}..."
65
+
66
+ try:
67
+ # Download file locally, bypassing symlink cache to save disk space
68
+ local_path = hf_hub_download(
69
+ repo_id=source_repo,
70
+ filename=file,
71
+ local_dir=TEMP_DIR,
72
+ local_dir_use_symlinks=False
73
+ )
74
+
75
+ # Check if this file belongs to one of the user-selected components (e.g., text_encoder, transformer)
76
+ in_target_component = any(f"{comp}/" in file for comp in target_components)
77
+
78
+ # Intercept and quantize only if it's a safetensors file in a selected folder
79
+ if file.endswith(".safetensors") and in_target_component:
80
+ yield f"🧠 Quantizing {file} to {precision}..."
81
+
82
+ tensors = load_file(local_path)
83
+
84
+ # Cast floating point tensors to the selected precision
85
+ if target_dtype:
86
+ keys = list(tensors.keys())
87
+ for k in keys:
88
+ if tensors[k].is_floating_point():
89
+ tensors[k] = tensors[k].to(target_dtype)
90
+
91
+ converted_path = os.path.join(TEMP_DIR, "converted.safetensors")
92
+ save_file(tensors, converted_path)
93
+
94
+ # Aggressive memory flush to prevent OOM
95
+ del tensors
96
+ gc.collect()
97
+
98
+ yield f"☁️ Uploading {precision} version of {file}..."
99
+ api.upload_file(
100
+ path_or_fileobj=converted_path,
101
+ path_in_repo=file,
102
+ repo_id=target_repo,
103
+ commit_message=f"Upload {precision} quantized {file}"
104
+ )
105
+
106
+ os.remove(converted_path)
107
+
108
+ else:
109
+ yield f"☁️ Copying {file} as-is..."
110
+ api.upload_file(
111
+ path_or_fileobj=local_path,
112
+ path_in_repo=file,
113
+ repo_id=target_repo,
114
+ commit_message=f"Copy {file} from original repo"
115
+ )
116
+
117
+ # Cleanup original downloaded file
118
+ if os.path.exists(local_path):
119
+ os.remove(local_path)
120
+
121
+ gc.collect()
122
+
123
+ except Exception as e:
124
+ yield f"⚠️ Error processing {file}: {str(e)}\nSkipping to next file..."
125
+
126
+ if os.path.exists(TEMP_DIR):
127
+ shutil.rmtree(TEMP_DIR)
128
+
129
+ yield f"βœ… All files processed and successfully uploaded to {target_repo}!"
130
+
131
+ # Dynamic UI Update for Target Repo Name
132
+ def update_target_repo(username, source, precision):
133
+ user_prefix = username.strip() if username.strip() else "your-username"
134
+ model_name = source.split("/")[-1] if "/" in source else source
135
+ return f"{user_prefix}/{model_name}-{precision}"
136
+
137
+ # Build the Gradio UI
138
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
139
+ gr.Markdown("# πŸš€ Auto-Purging Model Quantizer & Uploader")
140
+ gr.Markdown(
141
+ "Convert sharded Diffusers models (like FLUX, LongCat, Z-Image) to lower precisions (FP8, FP16, BF16).\n\n"
142
+ "**Auto-Delete Feature:** This tool is strictly designed to handle sharded folders. It will **automatically ignore and delete** any massive "
143
+ "`.safetensors` files located at the root of the repository to ensure your 16GB RAM limit is never breached and your target repository stays clean."
144
+ )
145
+
146
+ with gr.Row():
147
+ with gr.Column(scale=2):
148
+ hf_token = gr.Textbox(
149
+ label="Hugging Face Token (Write Access Required)",
150
+ type="password",
151
+ placeholder="hf_..."
152
+ )
153
+ hf_username = gr.Textbox(
154
+ label="Your Hugging Face Username",
155
+ placeholder="e.g., rootlocalghost"
156
+ )
157
+ source_repo = gr.Dropdown(
158
+ choices=[
159
+ "black-forest-labs/FLUX.2-klein-9B",
160
+ "black-forest-labs/FLUX.2-klein-4B",
161
+ "Tongyi-MAI/Z-Image-Turbo",
162
+ "meituan-longcat/LongCat-Image-Edit-Turbo"
163
+ ],
164
+ value="black-forest-labs/FLUX.2-klein-9B",
165
+ label="Source Repository",
166
+ allow_custom_value=True
167
+ )
168
+
169
+ target_components = gr.CheckboxGroup(
170
+ choices=["text_encoder", "transformer", "vae"],
171
+ value=["text_encoder", "transformer"],
172
+ label="Components to Quantize (Folders)",
173
+ info="Select which folders should be cast to the new precision. Everything else is copied as-is."
174
+ )
175
+
176
+ precision = gr.Dropdown(
177
+ choices=["FP8", "FP16", "BF16"],
178
+ value="FP8",
179
+ label="Target Precision"
180
+ )
181
+ target_repo = gr.Textbox(
182
+ label="Target Repository (Auto-generated)",
183
+ value="your-username/FLUX.2-klein-9B-FP8",
184
+ interactive=True
185
+ )
186
+ start_btn = gr.Button("Start Quantization & Upload", variant="primary")
187
+
188
+ with gr.Column(scale=3):
189
+ output_log = gr.Textbox(
190
+ label="Operation Logs",
191
+ lines=20,
192
+ interactive=False,
193
+ max_lines=25
194
+ )
195
+
196
+ # Automatically update the target repo name when inputs change
197
+ inputs_to_watch = [hf_username, source_repo, precision]
198
+ for inp in inputs_to_watch:
199
+ inp.change(
200
+ fn=update_target_repo,
201
+ inputs=inputs_to_watch,
202
+ outputs=[target_repo]
203
+ )
204
+
205
+ start_btn.click(
206
+ fn=convert_and_upload,
207
+ inputs=[hf_token, source_repo, target_repo, precision, target_components],
208
+ outputs=[output_log]
209
+ )
210
+
211
+ if __name__ == "__main__":
212
+ demo.launch()