rahul7star committed on
Commit
f26ebbe
·
verified ·
1 Parent(s): a24abf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -463
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
4
  import tempfile
5
- from huggingface_hub import HfApi, snapshot_download
6
- from huggingface_hub import list_models
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
  from packaging import version
9
  import os
@@ -16,6 +15,12 @@ from torchao.quantization import (
16
  GemliteUIntXWeightOnlyConfig,
17
  )
18
 
 
 
 
 
 
 
19
  MAP_QUANT_TYPE_TO_NAME = {
20
  "Int4WeightOnly": "int4wo",
21
  "GemliteUIntXWeightOnly": "intxwo-gemlite",
@@ -34,559 +39,175 @@ MAP_QUANT_TYPE_TO_CONFIG = {
34
  "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
35
  }
36
 
37
-
38
- def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
39
- # ^ expect a gr.OAuthProfile object as input to get the user's profile
40
- # if the user is not logged in, profile will be None
41
- if profile is None:
42
- return "Hello !"
43
- return f"Hello {profile.name} !"
 
44
 
45
 
46
- def check_model_exists(
47
- oauth_token: gr.OAuthToken | None,
48
- username,
49
- quantization_type,
50
- group_size,
51
- model_name,
52
- quantized_model_name,
53
- ):
54
  """Check if a model exists in the user's Hugging Face repository."""
55
  try:
56
- models = list_models(author=username, token=oauth_token.token)
57
  model_names = [model.id for model in models]
58
  if quantized_model_name:
59
  repo_name = f"{username}/{quantized_model_name}"
60
  else:
61
- if (
62
- quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
63
- ) and (group_size is not None):
64
  repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
65
  else:
66
  repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
67
  if repo_name in model_names:
68
  return f"Model '{repo_name}' already exists in your repository."
69
  else:
70
- return None # Model does not exist
71
  except Exception as e:
72
- # raise e
73
  return f"Error checking model existence: {str(e)}"
74
 
75
 
76
  def create_model_card(model_name, quantization_type, group_size):
77
- # Try to download the original README
78
- original_readme = ""
79
- original_yaml_header = ""
80
  try:
81
- # Download the README.md file from the original model
82
- model_path = snapshot_download(
83
- repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
84
- )
85
  readme_path = os.path.join(model_path, "README.md")
86
-
87
  if os.path.exists(readme_path):
88
  with open(readme_path, "r", encoding="utf-8") as f:
89
- content = f.read()
90
-
91
- if content.startswith("---"):
92
- parts = content.split("---", 2)
93
- if len(parts) >= 3:
94
- original_yaml_header = parts[1]
95
- original_readme = "---".join(parts[2:])
96
- else:
97
- original_readme = content
98
- else:
99
- original_readme = content
100
- except Exception as e:
101
- print(f"Error reading original README: {str(e)}")
102
  original_readme = ""
103
 
104
- # Create new YAML header with base_model field
105
  yaml_header = f"""---
106
  base_model:
107
- - {model_name}"""
108
-
109
- # Add any original YAML fields except base_model
110
- if original_yaml_header:
111
- in_base_model_section = False
112
- found_tags = False
113
- for line in original_yaml_header.strip().split("\n"):
114
- # Skip if we're in a base_model section that continues to the next line
115
- if in_base_model_section:
116
- if (
117
- line.strip().startswith("-")
118
- or not line.strip()
119
- or line.startswith(" ")
120
- ):
121
- continue
122
- else:
123
- in_base_model_section = False
124
-
125
- # Check for base_model field
126
- if line.strip().startswith("base_model:"):
127
- in_base_model_section = True
128
- # If base_model has inline value (like "base_model: model_name")
129
- if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
130
- in_base_model_section = False
131
- continue
132
-
133
- # Check for tags field and add bnb-my-repo
134
- if line.strip().startswith("tags:"):
135
- found_tags = True
136
- yaml_header += f"\n{line}"
137
- yaml_header += "\n- torchao-my-repo"
138
- continue
139
-
140
- yaml_header += f"\n{line}"
141
-
142
- # If tags field wasn't found, add it
143
- if not found_tags:
144
- yaml_header += "\ntags:"
145
- yaml_header += "\n- torchao-my-repo"
146
- # Complete the YAML header
147
- yaml_header += "\n---"
148
-
149
- # Create the quantization info section
150
- quant_info = f"""
151
  # {model_name} (Quantized)
152
 
153
- ## Description
154
- This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).
155
-
156
- It's quantized using the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space.
157
-
158
  ## Quantization Details
159
  - **Quantization Type**: {quantization_type}
160
  - **Group Size**: {group_size}
161
 
162
  """
 
 
 
163
 
164
- # Combine everything
165
- model_card = yaml_header + quant_info
166
-
167
- # Append original README content if available
168
- if original_readme and not original_readme.isspace():
169
- model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme
170
- return model_card
171
 
172
-
173
- def quantize_model(
174
- model_name, quantization_type, group_size=128, auth_token=None, username=None, progress=gr.Progress()
175
- ):
176
  print(f"Quantizing model: {quantization_type}")
177
  progress(0, desc="Preparing Quantization")
178
- if (
179
- quantization_type == "GemliteUIntXWeightOnly"
180
- ):
181
- quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
182
- group_size=group_size
183
- )
184
- quantization_config = TorchAoConfig(quant_config)
185
  elif quantization_type == "Int4WeightOnly":
186
  from torchao.dtypes import Int4CPULayout
187
-
188
- quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
189
- group_size=group_size, layout=Int4CPULayout()
190
- )
191
- quantization_config = TorchAoConfig(quant_config)
192
  elif quantization_type == "autoquant":
193
- quantization_config = TorchAoConfig(quantization_type)
194
  else:
195
  quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()
196
- quantization_config = TorchAoConfig(quant_config)
 
197
  progress(0.10, desc="Quantizing model")
 
198
  model = AutoModel.from_pretrained(
199
  model_name,
200
  torch_dtype="auto",
201
  quantization_config=quantization_config,
202
  device_map="cpu",
203
- use_auth_token=auth_token.token,
204
  )
205
  progress(0.45, desc="Quantization completed")
206
  return model
207
 
208
 
209
- def save_model(
210
- model,
211
- model_name,
212
- quantization_type,
213
- group_size=128,
214
- username=None,
215
- auth_token=None,
216
- quantized_model_name=None,
217
- public=True,
218
- progress=gr.Progress(),
219
- ):
220
  progress(0.50, desc="Preparing to push")
221
  print("Saving quantized model")
 
222
  with tempfile.TemporaryDirectory() as tmpdirname:
223
- # Load and save the tokenizer
224
- tokenizer = AutoTokenizer.from_pretrained(
225
- model_name, use_auth_token=auth_token.token
226
- )
227
- tokenizer.save_pretrained(tmpdirname, use_auth_token=auth_token.token)
228
 
229
- # Save the model
230
- progress(0.60, desc="Saving model")
231
- model.save_pretrained(
232
- tmpdirname, safe_serialization=False, use_auth_token=auth_token.token
233
- )
234
-
235
  if quantized_model_name:
236
  repo_name = f"{username}/{quantized_model_name}"
237
  else:
238
- if (
239
- quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
240
- ) and (group_size is not None):
241
  repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
242
  else:
243
  repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
 
244
  progress(0.70, desc="Creating model card")
245
  model_card = create_model_card(model_name, quantization_type, group_size)
246
  with open(os.path.join(tmpdirname, "README.md"), "w") as f:
247
  f.write(model_card)
248
- # Push to Hub
249
- api = HfApi(token=auth_token.token)
250
  api.create_repo(repo_name, exist_ok=True, private=not public)
251
  progress(0.80, desc="Pushing to Hub")
252
- api.upload_folder(
253
- folder_path=tmpdirname,
254
- repo_id=repo_name,
255
- repo_type="model",
256
- )
257
- progress(1.00, desc="Pushing to Hub completed")
258
-
259
- import io
260
- from contextlib import redirect_stdout
261
- import html
262
-
263
- # Capture the model architecture string
264
- f = io.StringIO()
265
- with redirect_stdout(f):
266
- print(model)
267
- model_architecture_str = f.getvalue()
268
-
269
- # Escape HTML characters and format with line breaks
270
- model_architecture_str_html = html.escape(model_architecture_str).replace(
271
- "\n", "<br/>"
272
- )
273
-
274
- # Format it for display in markdown with proper styling
275
- model_architecture_info = f"""
276
- <div class="model-architecture-container" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
277
- <h3 style="margin-top: 0; color: #2E7D32;">📋 Model Architecture</h3>
278
- <div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
279
- <div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
280
- </div>
281
- </div>
282
- """
283
 
284
  repo_link = f"""
285
- <div class="repo-link" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
286
- <h3 style="margin-top: 0; color: #2E7D32;">🔗 Repository Link</h3>
287
- <p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a></p>
288
  </div>
289
  """
290
- return (
291
- f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}{model_architecture_info}"
292
- )
293
 
294
 
295
- def quantize_and_save(
296
- profile: gr.OAuthProfile | None,
297
- oauth_token: gr.OAuthToken | None,
298
- model_name,
299
- quantization_type,
300
- group_size,
301
- quantized_model_name,
302
- public,
303
- ):
304
- if oauth_token is None:
305
- return """
306
- <div class="error-box">
307
- <h3>❌ Authentication Error</h3>
308
- <p>Please sign in to your HuggingFace account to use the quantizer.</p>
309
- </div>
310
- """
311
- if not profile:
312
- return """
313
- <div class="error-box">
314
- <h3>❌ Authentication Error</h3>
315
- <p>Please sign in to your HuggingFace account to use the quantizer.</p>
316
- </div>
317
- """
318
- if not group_size.isdigit():
319
- if group_size != "":
320
- return """
321
- <div class="error-box">
322
- <h3>❌ Group Size Error</h3>
323
- <p>Group Size is a parameter for Int4WeightOnly or GemliteUIntXWeightOnly</p>
324
- </div>
325
- """
326
 
327
  if group_size and group_size.strip():
328
- group_size = int(group_size)
 
 
 
329
  else:
330
  group_size = None
331
 
332
- exists_message = check_model_exists(
333
- oauth_token,
334
- profile.username,
335
- quantization_type,
336
- group_size,
337
- model_name,
338
- quantized_model_name,
339
- )
340
  if exists_message:
341
- return f"""
342
- <div class="warning-box">
343
- <h3>⚠️ Model Already Exists</h3>
344
- <p>{exists_message}</p>
345
- </div>
346
- """
347
- # if quantization_type == "int4_weight_only" :
348
- # return "int4_weight_only not supported on cpu"
349
 
350
  try:
351
- quantized_model = quantize_model(
352
- model_name, quantization_type, group_size, oauth_token, profile.username
353
- )
354
- return save_model(
355
- quantized_model,
356
- model_name,
357
- quantization_type,
358
- group_size,
359
- profile.username,
360
- oauth_token,
361
- quantized_model_name,
362
- public,
363
- )
364
  except Exception as e:
365
- # raise e
366
- return str(e)
367
 
368
 
369
- def get_model_size(model):
370
- """
371
- Calculate the size of a PyTorch model in gigabytes.
372
-
373
- Args:
374
- model: PyTorch model
375
-
376
- Returns:
377
- float: Size of the model in GB
378
- """
379
- # Get model state dict
380
- state_dict = model.state_dict()
381
-
382
- # Calculate total size in bytes
383
- total_size = 0
384
- for param in state_dict.values():
385
- # Calculate bytes for each parameter
386
- total_size += param.nelement() * param.element_size()
387
-
388
- # Convert bytes to gigabytes (1 GB = 1,073,741,824 bytes)
389
- size_gb = total_size / (1024**3)
390
- size_gb = round(size_gb, 2)
391
-
392
- return size_gb
393
-
394
-
395
- # Add enhanced CSS styling
396
- css = """
397
- /* Custom CSS for enhanced UI */
398
- .gradio-container {overflow-y: auto;}
399
-
400
- /* Fix alignment for radio buttons and dropdowns */
401
- .gradio-radio, .gradio-dropdown {
402
- display: flex !important;
403
- align-items: center !important;
404
- margin: 10px 0 !important;
405
- }
406
-
407
- /* Consistent spacing and alignment */
408
- .gradio-dropdown, .gradio-textbox, .gradio-radio {
409
- margin-bottom: 12px !important;
410
- width: 100% !important;
411
- }
412
-
413
-
414
- button[variant="primary"]::before {
415
- content: "🔥 "; /* PyTorch flame icon */
416
- }
417
-
418
- button[variant="primary"]:hover {
419
- transform: translateY(-5px) scale(1.05) !important;
420
- box-shadow: 0 10px 25px rgba(238, 76, 44, 0.7) !important;
421
- }
422
-
423
- @keyframes pytorch-glow {
424
- from {
425
- box-shadow: 0 0 10px rgba(238, 76, 44, 0.5);
426
- }
427
- to {
428
- box-shadow: 0 0 20px rgba(238, 76, 44, 0.8), 0 0 30px rgba(255, 156, 0, 0.5);
429
- }
430
- }
431
-
432
- /* Login button styling */
433
- #login-button {
434
- background: linear-gradient(135deg, #EE4C2C, #FF9C00) !important;
435
- color: white !important;
436
- font-weight: 700 !important;
437
- border: none !important;
438
- border-radius: 15px !important;
439
- box-shadow: 0 0 15px rgba(238, 76, 44, 0.5) !important;
440
- transition: all 0.3s ease !important;
441
- max-width: 300px !important;
442
- margin: 0 auto !important;
443
- }
444
-
445
- .quantize-button {
446
- background: linear-gradient(135deg, #EE4C2C, #FF9C00) !important;
447
- color: white !important;
448
- font-weight: 700 !important;
449
- border: none !important;
450
- border-radius: 15px !important;
451
- box-shadow: 0 0 15px rgba(238, 76, 44, 0.5) !important;
452
- transition: all 0.3s ease !important;
453
- animation: pytorch-glow 1.5s infinite alternate !important;
454
- transform-origin: center !important;
455
- letter-spacing: 0.5px !important;
456
- text-shadow: 0 1px 2px rgba(0, 0, 0, 0.2) !important;
457
- }
458
-
459
- .quantize-button:hover {
460
- transform: translateY(-3px) scale(1.03) !important;
461
- box-shadow: 0 8px 20px rgba(238, 76, 44, 0.7) !important;
462
- }
463
- """
464
-
465
- # Update the main app layout
466
- with gr.Blocks(css=css) as demo:
467
- gr.Markdown(
468
- """
469
- # 🤗 TorchAO Model Quantizer ✨
470
-
471
- Quantize your favorite Hugging Face models using TorchAO and save them to your profile!
472
-
473
- <br/>
474
- """
475
- )
476
-
477
- gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
478
-
479
- m1 = gr.Markdown()
480
- demo.load(hello, inputs=None, outputs=m1)
481
 
482
  with gr.Row():
483
- with gr.Column():
484
- with gr.Row():
485
- model_name = HuggingfaceHubSearch(
486
- label="🔍 Hub Model ID",
487
- placeholder="Search for model id on Huggingface",
488
- search_type="model",
489
- )
490
-
491
- gr.Markdown("""### ⚙️ Quantization Settings""")
492
- with gr.Row():
493
- with gr.Column():
494
- quantization_type = gr.Dropdown(
495
- info="Select the Quantization method",
496
- choices=[
497
- "Int4WeightOnly",
498
- "GemliteUIntXWeightOnly",
499
- "Int8WeightOnly",
500
- "Int8DynamicActivationInt8Weight",
501
- "Float8WeightOnly",
502
- "Float8DynamicActivationFloat8Weight",
503
- "autoquant",
504
- ],
505
- value="Int8WeightOnly",
506
- filterable=False,
507
- show_label=False,
508
- )
509
-
510
- group_size = gr.Textbox(
511
- info="Group Size (only for int4_weight_only and int8_weight_only)",
512
- value="128",
513
- interactive=(quantization_type.value == "Int4WeightOnly" or quantization_type.value == "Int8WeightOnly"),
514
- show_label=False,
515
- )
516
-
517
- gr.Markdown(
518
- """
519
- ### 💾 Saving Settings
520
- """
521
- )
522
- with gr.Row():
523
- quantized_model_name = gr.Textbox(
524
- label="✏️ Model Name",
525
- info="Model Name (optional : to override default)",
526
- value="",
527
- interactive=True,
528
- elem_classes="model-name-textbox",
529
- show_label=False,
530
- )
531
- with gr.Row():
532
- public = gr.Checkbox(
533
- label="🌐 Make model public",
534
- info="If checked, the model will be publicly accessible",
535
- value=True,
536
- interactive=True,
537
- show_label=True,
538
- )
539
-
540
- with gr.Column():
541
- quantize_button = gr.Button(
542
- "🚀 Quantize and Push to Hub", elem_classes="quantize-button", elem_id="quantize-button"
543
- )
544
- output_link = gr.Markdown(
545
- label="🔗 Quantized Model Info", container=True, min_height=200
546
- )
547
-
548
- # Add information section
549
- with gr.Accordion("📚 About TorchAO Quantization", open=True):
550
- gr.Markdown(
551
- """
552
- ## 📝 Quantization Options
553
-
554
- ### Quantization Types
555
- "Int4WeightOnly",
556
- "GemliteUIntXWeightOnly"
557
- "Int8WeightOnly",
558
- "Int8DynamicActivationInt8Weight",
559
- "Float8WeightOnly",
560
- "Float8DynamicActivationFloat8Weight",
561
- - **Int4WeightOnly**: 4-bit weight-only quantization
562
- - **GemliteUIntXWeightOnly**: uintx gemlite quantization (default to 4 bit only for now)
563
- - **Int8WeightOnly**: 8-bit weight-only quantization
564
- - **Int8DynamicActivationInt8Weight**: 8-bit quantization for both weights and activations
565
- - **Float8WeightOnly**: float8-bit weight-only quantization
566
- - **Float8DynamicActivationFloat8Weight**: float8-bit quantization for both weights and activations
567
- - **autoquant**: automatic quantization (uses the best quantization method for the model)
568
-
569
- ### Group Size
570
- - Only applicable for int4_weight_only and int8_weight_only quantization
571
- - Default value is 128
572
- - Affects the granularity of quantization
573
-
574
- ## 🔍 How It Works
575
- 1. Downloads the original model
576
- 2. Applies TorchAO quantization with your selected settings
577
- 3. Uploads the quantized model to your HuggingFace account
578
-
579
- ## 📊 Memory Benefits
580
- - int4 quantization can reduce model size by up to 75%
581
- - int8 quantization typically reduces size by about 50%
582
- """
583
  )
584
- # Keep existing click handler
 
 
 
 
 
585
  quantize_button.click(
586
  fn=quantize_and_save,
587
  inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
588
- outputs=[output_link],
589
  )
590
 
591
- # Launch the app
592
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import TorchAoConfig, AutoModel, AutoTokenizer
4
  import tempfile
5
+ from huggingface_hub import HfApi, snapshot_download, list_models
 
6
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
7
  from packaging import version
8
  import os
 
15
  GemliteUIntXWeightOnlyConfig,
16
  )
17
 
18
# === Hugging Face token (read once from the environment at import time) ===
# Every Hub call in this module authenticates with this token, so refuse to
# start when it is absent or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("❌ Missing HF_TOKEN environment variable. Please set it before running the app.")
22
+
23
+ # === Quantization configuration maps ===
24
  MAP_QUANT_TYPE_TO_NAME = {
25
  "Int4WeightOnly": "int4wo",
26
  "GemliteUIntXWeightOnly": "intxwo-gemlite",
 
39
  "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
40
  }
41
 
42
+ # === Helper functions ===
43
def get_username():
    """Return the Hub account name that owns ``HF_TOKEN``.

    Calls ``HfApi.whoami()`` with the module-level token and returns the
    ``"name"`` field of the response.

    Returns:
        str: the account name, or the sentinel ``"anonymous"`` when the lookup
        fails for any reason (callers treat the sentinel as "not
        authenticated").
    """
    try:
        return HfApi(token=HF_TOKEN).whoami()["name"]
    except Exception as e:
        # Best-effort lookup: keep the sentinel contract, but leave a trace so
        # the cause of the failure is diagnosable instead of silently dropped.
        print(f"Could not resolve Hub username: {e}")
        return "anonymous"
50
 
51
 
52
def check_model_exists(username, quantization_type, group_size, model_name, quantized_model_name):
    """Tell whether the destination repo is already listed under ``username``.

    The destination is ``username/quantized_model_name`` when a custom name is
    given; otherwise it is derived from the source model's last path segment,
    the short quantization name and, for the group-size-aware types, a
    ``-gs<group_size>`` suffix.

    Returns:
        str | None: a message when the repo already exists or when the lookup
        raised, ``None`` when the name is free.
    """
    try:
        existing = [entry.id for entry in list_models(author=username, token=HF_TOKEN)]

        if quantized_model_name:
            target = f"{username}/{quantized_model_name}"
        else:
            base = model_name.split('/')[-1]
            short = MAP_QUANT_TYPE_TO_NAME[quantization_type]
            target = f"{username}/{base}-ao-{short}"
            if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and group_size is not None:
                target += f"-gs{group_size}"

        if target not in existing:
            return None
        return f"Model '{target}' already exists in your repository."
    except Exception as err:
        return f"Error checking model existence: {str(err)}"
70
 
71
 
72
def _strip_front_matter(text):
    """Return README ``text`` without a leading ``---``-delimited YAML block."""
    if not text.startswith("---"):
        return text
    parts = text.split("---", 2)
    # Three parts means the opening fence was closed: ["", yaml, body].
    # An unterminated fence is left untouched rather than guessed at.
    return parts[2] if len(parts) == 3 else text


def create_model_card(model_name, quantization_type, group_size):
    """Build the README.md text for the quantized repo.

    The card starts with a YAML header (``base_model`` and the
    ``torchao-my-repo`` tag), followed by a short "Quantization Details"
    section. When the source model's README can be downloaded, its body is
    appended under an "Original Model Info" heading.

    Args:
        model_name: Hub id of the source model.
        quantization_type: quantization type label written into the card.
        group_size: group size written into the card (rendered as-is, so
            ``None`` appears as ``None``).

    Returns:
        str: the complete model card.
    """
    original_readme = ""
    try:
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=HF_TOKEN
        )
        readme_path = os.path.join(model_path, "README.md")
        if os.path.exists(readme_path):
            with open(readme_path, "r", encoding="utf-8") as f:
                # The source README carries its own YAML header; appending it
                # verbatim would put a second front-matter block in the body.
                original_readme = _strip_front_matter(f.read())
    except Exception as e:
        # The original README is optional: report the problem and continue
        # with a card that only contains the quantization details.
        print(f"Error reading original README: {e}")
        original_readme = ""

    model_card = f"""---
base_model:
- {model_name}
tags:
- torchao-my-repo
---
# {model_name} (Quantized)

## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}

"""
    # Skip the heading when nothing but whitespace is left to show under it.
    if original_readme and not original_readme.isspace():
        model_card += "\n\n# 📄 Original Model Info\n\n" + original_readme
    return model_card
99
 
 
 
 
 
 
 
 
100
 
101
def quantize_model(model_name, quantization_type, group_size=128, progress=gr.Progress()):
    """Load ``model_name`` on CPU with a TorchAO quantization config applied.

    Args:
        model_name: Hub id of the model to load.
        quantization_type: a key of ``MAP_QUANT_TYPE_TO_CONFIG`` or the
            literal ``"autoquant"``.
        group_size: forwarded to the config for ``"GemliteUIntXWeightOnly"``
            and ``"Int4WeightOnly"``; ignored for the other types. ``None``
            means "not specified".
        progress: Gradio progress tracker updated as the work advances.

    Returns:
        The model returned by ``AutoModel.from_pretrained``.
    """
    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")

    if quantization_type == "autoquant":
        quant_config = "autoquant"
    else:
        config_kwargs = {}
        if quantization_type in ("GemliteUIntXWeightOnly", "Int4WeightOnly"):
            # The UI hands over None when the group-size box is left empty;
            # omit the argument then so the config class falls back to its
            # own default instead of receiving None.
            if group_size is not None:
                config_kwargs["group_size"] = group_size
        if quantization_type == "Int4WeightOnly":
            # Imported lazily: only this branch needs the CPU layout.
            from torchao.dtypes import Int4CPULayout

            config_kwargs["layout"] = Int4CPULayout()
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](**config_kwargs)

    quantization_config = TorchAoConfig(quant_config)
    progress(0.10, desc="Quantizing model")

    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=HF_TOKEN,
    )
    progress(0.45, desc="Quantization completed")
    return model
127
 
128
 
129
def save_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, progress=gr.Progress()):
    """Write the quantized model, tokenizer and model card, then push to the Hub.

    Everything is staged in a temporary directory that is removed when the
    upload has finished.

    Args:
        model: the quantized model to save (must provide ``save_pretrained``).
        model_name: Hub id of the source model; also used to load the tokenizer.
        quantization_type: quantization type label, used for the default repo
            name and the model card.
        group_size: group size; appended to the default repo name for
            ``"Int4WeightOnly"`` / ``"GemliteUIntXWeightOnly"`` when not ``None``.
        quantized_model_name: optional custom repo name (without the user prefix).
        public: when false the repo is created private.
        progress: Gradio progress tracker updated as the work advances.

    Returns:
        str: an HTML snippet with a link to the destination repo.
    """
    username = get_username()
    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")

    with tempfile.TemporaryDirectory() as tmpdirname:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
        tokenizer.save_pretrained(tmpdirname)
        model.save_pretrained(tmpdirname, safe_serialization=False)

        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and (group_size is not None):
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"

        progress(0.70, desc="Creating model card")
        model_card = create_model_card(model_name, quantization_type, group_size)
        # The card contains non-ASCII characters (emoji heading, and whatever
        # the source README holds), so the encoding must not depend on the
        # platform default.
        with open(os.path.join(tmpdirname, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

        api = HfApi(token=HF_TOKEN)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(folder_path=tmpdirname, repo_id=repo_name, repo_type="model")
        progress(1.00, desc="Done")

    repo_link = f"""
<div class="repo-link">
<h3>🔗 Repository Link</h3>
<p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a></p>
</div>
"""
    return f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}"
 
 
165
 
166
 
167
def quantize_and_save(model_name, quantization_type, group_size, quantized_model_name, public):
    """Gradio click handler: validate the form, quantize, and push to the Hub.

    Args:
        model_name: Hub id of the model to quantize.
        quantization_type: selected quantization type.
        group_size: text of the group-size box; empty means "not specified".
        quantized_model_name: optional custom repo name.
        public: whether the destination repo is public.

    Returns:
        str: an HTML snippet describing success, a warning (destination repo
        already exists) or an error. Exceptions from the quantize/push steps
        are caught and reported in the snippet rather than propagated.
    """
    import html  # local import: only needed to escape text placed into HTML

    def _error_box(title, message):
        # Exception text may contain markup characters such as "<" and ">".
        return f"<div class='error-box'><h3>❌ {title}</h3><p>{html.escape(message)}</p></div>"

    # Validate the form first: it is cheap and needs no network round trip.
    if not model_name or not str(model_name).strip():
        return _error_box("Model Error", "Please select a model to quantize.")

    raw_group_size = "" if group_size is None else str(group_size).strip()
    if raw_group_size:
        try:
            group_size = int(raw_group_size)
        except ValueError:
            # Report the typo instead of silently quantizing with a group
            # size the user did not ask for.
            return _error_box("Group Size Error", "Group Size must be a positive whole number or left empty.")
        if group_size <= 0:
            return _error_box("Group Size Error", "Group Size must be a positive whole number or left empty.")
    else:
        group_size = None

    username = get_username()
    if not username or username == "anonymous":
        return "<div class='error-box'><h3>❌ Authentication Error</h3><p>Invalid or missing HF_TOKEN.</p></div>"

    exists_message = check_model_exists(username, quantization_type, group_size, model_name, quantized_model_name)
    if exists_message:
        # check_model_exists reports lookup failures through the same return
        # channel; show those as errors, not as "already exists".
        if exists_message.startswith("Error checking model existence"):
            return _error_box("Error", exists_message)
        return f"<div class='warning-box'><h3>⚠️ Model Already Exists</h3><p>{html.escape(exists_message)}</p></div>"

    try:
        quantized_model = quantize_model(model_name, quantization_type, group_size)
        return save_model(quantized_model, model_name, quantization_type, group_size, quantized_model_name, public)
    except Exception as e:
        return _error_box("Error", str(e))
 
189
 
190
 
191
+ # === Gradio UI ===
# Builds the Blocks page `demo`: two Markdown headers, the form widgets, an
# output Markdown area and the button wired to quantize_and_save.
# NOTE(review): indentation was lost in this view, so which widgets are nested
# inside gr.Row() cannot be read off here — confirm against the real file.
192
+ with gr.Blocks() as demo:
193
+ gr.Markdown("# 🤗 TorchAO Quantizer (Token Mode) 🔥")
194
+ gr.Markdown("Uses your environment HF_TOKEN — no login required.")












































































































195
 
196
  with gr.Row():
197
+ model_name = HuggingfaceHubSearch(label="🔍 Hub Model ID", placeholder="Search a model", search_type="model")
198
+ quantization_type = gr.Dropdown(
199
# The selectable types are the keys of MAP_QUANT_TYPE_TO_NAME; "Int8WeightOnly" is preselected.
+ choices=list(MAP_QUANT_TYPE_TO_NAME.keys()), value="Int8WeightOnly", label="Quantization Type"

































































































200
  )
201
+ group_size = gr.Textbox(label="Group Size (optional)", value="128")
202
+ quantized_model_name = gr.Textbox(label="Custom Model Name", value="")
203
+ public = gr.Checkbox(label="Make Public", value=True)
204
+ output_link = gr.Markdown()
205
+ quantize_button = gr.Button("🚀 Quantize and Push")
206
+
207
# The five inputs are listed in the same order as quantize_and_save's
# parameters; the string it returns is shown in output_link.
  quantize_button.click(
208
  fn=quantize_and_save,
209
  inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
210
+ outputs=output_link,
211
  )
212


213
# NOTE(review): share=True presumably requests a public Gradio share link; every
# request runs with the server-side HF_TOKEN, so confirm public exposure is intended.
  demo.launch(share=True)