import os
import subprocess
import signal
import threading
import tempfile
from queue import Queue, Empty
from pathlib import Path
from textwrap import dedent

# Disable Gradio analytics before gradio is imported.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

import gradio as gr
import torch
import requests

from huggingface_hub import HfApi, ModelCard, whoami
from gradio_huggingfacehub_search import HuggingfaceHubSearch
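# This Space takes a fine-tuned model and its base model, extracts a
# PEFT-compatible LoRA adapter from the difference between the two with
# mergekit-extract-lora, and uploads the adapter to a new repo on the Hub.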
|
|
def stream_output(pipe, queue):
    """Read output from pipe and put it in the queue."""
    for line in iter(pipe.readline, b''):
        queue.put(line.decode('utf-8').rstrip())
    pipe.close()
|
|
def run_command(command, env_vars):
    """Run a command as a subprocess, streaming and capturing its stdout/stderr."""
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=False,
        env=env_vars,
    )

    # Queues that collect the subprocess output line by line.
    stdout_queue = Queue()
    stderr_queue = Queue()

    # Reader threads so stdout and stderr can be consumed concurrently.
    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout, stdout_queue))
    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr, stderr_queue))
    stdout_thread.daemon = True
    stderr_thread.daemon = True
    stdout_thread.start()
    stderr_thread.start()

    output_stdout = ""
    output_stderr = ""

    # Poll the process and echo its output while it is still running.
    while process.poll() is None:
        try:
            stdout_line = stdout_queue.get_nowait()
            print(f"STDOUT: {stdout_line}")
            output_stdout += stdout_line + "\n"
        except Empty:
            pass

        try:
            stderr_line = stderr_queue.get_nowait()
            print(f"STDERR: {stderr_line}")
            output_stderr += stderr_line + "\n"
        except Empty:
            pass

    # Wait for the reader threads to finish, then drain anything still queued
    # so output printed right before the process exited is not lost.
    stdout_thread.join()
    stderr_thread.join()
    while not stdout_queue.empty():
        output_stdout += stdout_queue.get_nowait() + "\n"
    while not stderr_queue.empty():
        output_stderr += stderr_queue.get_nowait() + "\n"

    return (process.returncode, output_stdout, output_stderr)
|
|
def guess_base_model(ft_model_id):
    """Guess the base model from the fine-tuned repo's base_model tag."""
    res = requests.get(f"https://huggingface.co/api/models/{ft_model_id}")
    res = res.json()
    for tag in res.get("tags", []):
        if tag.startswith("base_model:"):
            return tag.split(":")[-1]
    raise Exception("Cannot guess the base model; please enter it manually")
|
|
def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo: bool, oauth_token: gr.OAuthToken | None):
    # Make sure the user is logged in before doing any work.
    try:
        whoami(oauth_token.token)
    except Exception:
        raise gr.Error("You must be logged in")

    model_name = ft_model_id.split('/')[-1]

    if not os.path.exists("outputs"):
        os.makedirs("outputs")

    try:
        api = HfApi(token=oauth_token.token)

        if not base_model_id:
            base_model_id = guess_base_model(ft_model_id)
            print("guess_base_model", base_model_id)

        with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
            # Extract the LoRA adapter with mergekit-extract-lora.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            cmd = [
                "mergekit-extract-lora",
                ft_model_id,
                base_model_id,
                outputdir,
                f"--rank={rank}",
                f"--device={device}",
            ]
            print("cmd", cmd)
            env_vars = dict(os.environ, HF_TOKEN=oauth_token.token)
            returncode, output_stdout, output_stderr = run_command(cmd, env_vars)
            print("returncode", returncode)
            print("output_stdout", output_stdout)
            print("output_stderr", output_stderr)
            if returncode != 0:
                raise Exception(f"Error converting to LoRA PEFT: {output_stderr}")
            print("Model converted to LoRA PEFT successfully!")
            print(f"Converted model path: {outputdir}")

            if not os.listdir(outputdir):
                raise Exception("Output directory is empty!")

            # Create the destination repo and upload the adapter.
            username = whoami(oauth_token.token)["name"]
            new_repo_url = api.create_repo(repo_id=f"{username}/LoRA-{model_name}", exist_ok=True, private=private_repo)
            new_repo_id = new_repo_url.repo_id
            print("Repo created successfully!", new_repo_url)

            api.upload_folder(
                folder_path=outputdir,
                path_in_repo="",
                repo_id=new_repo_id,
            )
            print("Uploaded", outputdir)

            return (
                f'<h1>✅ DONE</h1><br/><br/>Find your repo here: <a href="{new_repo_url}" target="_blank" style="text-decoration:underline">{new_repo_id}</a>'
            )
    except Exception as e:
        return (f"<h1>❌ ERROR</h1><br/><br/>{e}")
|
|
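# Gradio UI: repo pickers, LoRA rank selection, and the conversion interface.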
css = """/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("You must be logged in.")
    gr.LoginButton(min_width=250)

    ft_model_id = HuggingfaceHubSearch(
        label="Fine-tuned model repository",
        placeholder="Fine-tuned model",
        search_type="model",
    )

    base_model_id = HuggingfaceHubSearch(
        label="Base model repository (optional)",
        placeholder="If empty, it will be guessed from the repo tags",
        search_type="model",
    )
|
|
    rank = gr.Dropdown(
        ["16", "32", "64", "128"],
        label="LoRA rank",
        info="The higher the rank, the better the result, but the heavier the adapter",
        value="32",
        filterable=False,
        visible=True,
    )

    private_repo = gr.Checkbox(
        value=False,
        label="Private Repo",
        info="Create a private repo under your username.",
    )
|
|
    iface = gr.Interface(
        fn=process_model,
        inputs=[
            ft_model_id,
            base_model_id,
            rank,
            private_repo,
        ],
        outputs=[
            gr.Markdown(label="output"),
        ],
        title="Convert a fine-tuned model into a LoRA adapter with mergekit-extract-lora",
        description="This Space takes a fine-tuned model and a base model, then builds a PEFT-compatible LoRA adapter based on the difference between the two models.<br/><br/>NOTE: Each conversion takes about <b>5 to 20 minutes</b>, depending on how big the model is.",
        api_name=False,
    )
|
|
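# Process one conversion at a time; queue at most 5 pending requests.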
demo.queue(default_concurrency_limit=1, max_size=5).launch(debug=True, show_api=False)