# Captured from a Hugging Face Space page (Space status at capture: Paused).
import os

if os.environ.get("SPACES_ZERO_GPU") is not None:
    # Running on ZeroGPU hardware: use the real Hugging Face `spaces` package.
    import spaces
else:
    class spaces:
        """Local no-op stand-in for the Hugging Face ``spaces`` package.

        Only ``spaces.GPU`` is provided, so code decorated for ZeroGPU also runs
        unchanged when the ``SPACES_ZERO_GPU`` environment variable is not set.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Works both as ``@spaces.GPU`` and as ``@spaces.GPU(duration=...)``;
            any keyword arguments are accepted and ignored.

            Returns:
                A wrapper around ``func`` that simply calls it, or, when called
                without a function, a decorator that produces such a wrapper.
            """
            import functools

            def decorate(fn):
                # functools.wraps copies __name__/__doc__ and sets __wrapped__, so
                # signature introspection (e.g. Gradio detecting gr.Progress
                # defaults) still sees fn's real parameters.
                @functools.wraps(fn)
                def wrapper(*args, **kw):
                    return fn(*args, **kw)
                return wrapper

            if func is None:
                return decorate
            return decorate(func)
| import gradio as gr | |
| from pathlib import Path | |
| import gc | |
| import shutil | |
| import torch | |
| from utils import set_token, upload_repo, is_repo_exists, is_repo_name | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from transformers import BitsAndBytesConfig | |
def fake_gpu():
    """Do nothing and return ``None``.

    NOTE(review): presumably kept so it can be decorated with ``spaces.GPU``
    (ZeroGPU Spaces expect at least one GPU function); no such decorator is
    visible in this file — confirm.
    """
    return None
# Registry of selectable model classes: name -> [model loader class, tokenizer loader class].
MODEL_CLASS = dict(
    AutoModelForCausalLM=[AutoModelForCausalLM, AutoTokenizer],
)
# Maps user-facing precision labels to torch dtypes (looked up by get_dtype).
DTYPE_DICT = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}
# torch.float8_e4m3fn only exists in newer torch releases (2.1+); register "fp8"
# only when available so importing this module does not fail on older torch.
if hasattr(torch, "float8_e4m3fn"):
    DTYPE_DICT["fp8"] = torch.float8_e4m3fn
def get_model_class():
    """Return the registered model-class names as a list, in MODEL_CLASS insertion order."""
    return [name for name in MODEL_CLASS]
def _model_class_pair(mclass: str):
    """Return the ``[model_cls, tokenizer_cls]`` pair registered for ``mclass``.

    Unknown names fall back to ``[AutoModelForCausalLM, AutoTokenizer]`` so both
    getters below share one default and cannot drift apart.
    """
    return MODEL_CLASS.get(mclass, [AutoModelForCausalLM, AutoTokenizer])


def get_model(mclass: str):
    """Return the model loader class for ``mclass`` (causal-LM loader if unknown)."""
    return _model_class_pair(mclass)[0]


def get_tokenizer(mclass: str):
    """Return the tokenizer loader class for ``mclass`` (AutoTokenizer if unknown)."""
    return _model_class_pair(mclass)[1]
def get_dtype(dtype: str):
    """Map a precision label (a key of DTYPE_DICT, e.g. "fp16", "bf16") to a torch dtype.

    Labels not present in DTYPE_DICT, including "default", resolve to
    ``torch.bfloat16``.
    """
    try:
        return DTYPE_DICT[dtype]
    except KeyError:
        return torch.bfloat16
def save_readme_md(dir, repo_id):
    """Write a minimal model-card ``README.md`` into ``dir`` for a quantized copy of ``repo_id``.

    The card's YAML front matter records ``repo_id`` as ``base_model``, and the
    body links back to the source repo on huggingface.co. An existing
    ``README.md`` in ``dir`` is overwritten.
    """
    source_url = f"https://huggingface.co/{repo_id}/"
    card_lines = [
        "---",
        "license: other",
        "language:",
        "- en",
        "library_name: transformers",
        f"base_model: {repo_id}",
        "tags:",
        "- transformers",
        "---",
        f"Quants of [{repo_id}]({source_url}).",
    ]
    readme_path = Path(dir) / "README.md"
    with open(readme_path, mode="w", encoding="utf-8") as fh:
        fh.write("\n".join(card_lines) + "\n")
def quantize_repo(repo_id: str, dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Load ``repo_id``, optionally quantize it to 4-bit NF4, and save it locally.

    Args:
        repo_id: Source Hub repo ("owner/name").
        dtype: Precision label passed to ``get_dtype``; "default" leaves the
            loader's ``torch_dtype`` unset. Also used as the NF4 storage and
            compute dtype.
        qtype: "nf4" loads the model with a bitsandbytes 4-bit NF4 config;
            any other value loads and re-saves it without quantization.
        mclass: Key selecting the model/tokenizer loader classes.
        progress: Gradio progress callback.

    Returns:
        The output directory (last path segment of ``repo_id``, relative to
        the current working directory) holding tokenizer, safetensors weights
        and a README.md.
    """
    progress(0, desc="Start quantizing...")
    out_dir = repo_id.split("/")[-1]
    load_kwargs = {}
    if dtype != "default":
        load_kwargs["torch_dtype"] = get_dtype(dtype)
    if qtype == "nf4":
        # Only build the bitsandbytes config when it is actually used.
        compute_dtype = get_dtype(dtype)
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
    tokenizer = model = None
    try:
        progress(0.1, desc="Loading...")
        # `legacy=False` opts tokenizers that support it (e.g. Llama/T5 style)
        # into non-legacy behaviour; the previous misspelling `legathy` was not
        # a recognized option.
        tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legacy=False)
        model = get_model(mclass).from_pretrained(repo_id, **load_kwargs)
        progress(0.5, desc="Saving...")
        tokenizer.save_pretrained(out_dir)
        model.save_pretrained(out_dir, safe_serialization=True)
        if Path(out_dir).exists():
            save_readme_md(out_dir, repo_id)
    finally:
        # Release the (possibly GPU-resident) model even if loading/saving failed.
        # Collect garbage first so empty_cache can return the freed blocks.
        del tokenizer, model
        gc.collect()
        torch.cuda.empty_cache()
    progress(1, desc="Quantized.")
    return out_dir
def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool=True, is_overwrite: bool=False,
                dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: quantize ``repo_id`` and upload the result to ``newrepo_id``.

    ``hf_token`` falls back to the ``HF_TOKEN`` environment variable and
    ``newrepo_id`` to ``HF_OUTPUT_REPO``. The local output directory is removed
    after the upload attempt, whether or not it succeeded.

    Returns:
        A pair of ``gr.update`` objects: the URL list (previous ``urls`` plus the
        new repo URL) as value/choices, and a Markdown list of repo links.

    Raises:
        gr.Error: missing token, invalid repo names, or an existing target repo
            when ``is_overwrite`` is False.
    """
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")  # default huggingface token
    if not hf_token:
        raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)
    if not newrepo_id:
        newrepo_id = os.environ.get("HF_OUTPUT_REPO")  # default repo id
    if not is_repo_name(repo_id):
        raise gr.Error(f"Invalid repo name: {repo_id}")
    # newrepo_id is still None when neither the argument nor the env var is set.
    if not newrepo_id or not is_repo_name(newrepo_id):
        raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id):
        raise gr.Error(f"Repo already exists: {newrepo_id}")
    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    if not new_path:
        # Match the two-output shape of the success path (no component changes).
        return gr.update(), gr.update()
    try:
        progress(0.5, desc="Start uploading...")
        repo_url = upload_repo(newrepo_id, new_path, is_private)
        progress(1, desc="Processing...")
    finally:
        # Always remove the local copy so a failed upload does not leave stale
        # files behind for a later run writing to the same directory.
        shutil.rmtree(new_path, ignore_errors=True)
    # Build a new list instead of appending to the caller's object.
    urls = [*(urls or []), repo_url]
    links = []
    for u in urls:
        url = str(u)
        parts = url.split("/")
        links.append(f"[{parts[-2]}/{parts[-1]}]({url})<br>")
    md = "### Your new repo:\n" + "".join(links)
    gc.collect()
    torch.cuda.empty_cache()
    return gr.update(value=urls, choices=urls), gr.update(value=md)