# Hugging Face Space status: Sleeping
import os
import subprocess
import textwrap

import gradio as gr
from huggingface_hub import HfApi
def create_config_yaml(
    model_name,
    model1,
    model1_layers,
    model2,
    model2_layers,
    merge_method,
    base_model,
    parameters,
    dtype,
):
    """Build a two-model mergekit config and write it to ``config.yaml``.

    The config has a single slice whose sources are ``model1`` and ``model2``
    (each with its layer range), followed by ``merge_method``, ``base_model``,
    an optional ``parameters`` section and ``dtype``.

    Args:
        model_name: Name chosen for the merged model. It is not part of the
            mergekit config; it is accepted so the signature matches the UI.
        model1: Identifier of the first source model.
        model1_layers: Layer range for ``model1``, inserted verbatim
            (e.g. ``"[0, 32]"``).
        model2: Identifier of the second source model.
        model2_layers: Layer range for ``model2``, inserted verbatim.
        merge_method: Merge method string, e.g. ``"slerp"``.
        base_model: Identifier written as ``base_model``.
        parameters: YAML text for the ``parameters`` section. It may span
            several lines; every line is re-indented under ``parameters:``.
            If it is blank, the section is omitted.
        dtype: Data type string, e.g. ``"bfloat16"``.

    Returns:
        The generated YAML text. It is returned even if writing the file
        fails; in that case the error is printed.
    """
    # Values are interpolated verbatim. Callers must supply valid YAML
    # scalars and flow lists, because nothing is escaped here.
    lines = [
        "slices:",
        "  - sources:",
        f"      - model: {model1}",
        f"        layer_range: {model1_layers}",
        f"      - model: {model2}",
        f"        layer_range: {model2_layers}",
        f"merge_method: {merge_method}",
        f"base_model: {base_model}",
    ]
    # Normalise the user's block, then indent *every* line so that
    # multi-line parameter blocks nest under the `parameters:` key.
    params_text = textwrap.dedent(parameters or "").strip("\n")
    if params_text.strip():
        lines.append("parameters:")
        lines.append(textwrap.indent(params_text, "  "))
    lines.append(f"dtype: {dtype}")
    yaml_config = "\n".join(lines) + "\n"

    print("Writing YAML config to 'config.yaml'...")
    try:
        with open("config.yaml", "w", encoding="utf-8") as f:
            f.write(yaml_config)
    except OSError as e:
        # Best effort: report the failure but still return the YAML so the
        # UI can display it.
        print(f"Error writing file: {e}")
    else:
        print("File 'config.yaml' written successfully.")
    return yaml_config
def execute_merge_command():
    """Run ``mergekit-yaml config.yaml ./output-model-directory``.

    Returns:
        A status string for the UI. On exit status 0 it is ``"Output:"``
        followed by the command's stdout. Otherwise it is ``"Error:"``
        followed by details: stderr when the command ran and failed, or the
        OS error when the executable could not be started (e.g. mergekit is
        not installed).
    """
    command = "mergekit-yaml"
    args = ["config.yaml", "./output-model-directory"]
    try:
        # List form, no shell: the arguments are passed through literally.
        result = subprocess.run([command, *args], capture_output=True, text=True)
    except OSError as e:
        # FileNotFoundError (an OSError subclass) is raised when the
        # executable is missing. Report it instead of crashing the callback.
        print("Error in executing command")
        return f"Error:\nCould not run '{command}': {e}"
    if result.returncode == 0:
        print("Command executed successfully")
        return f"Output:\n{result.stdout}"
    print("Error in executing command")
    return f"Error:\n{result.stderr}"
# Function to push to HF Hub (for the third tab)
def push_to_hf_hub(model_name, yaml_config):
    """Create ``arcee-ai/<model_name>`` on the Hugging Face Hub and upload the config.

    The YAML text is written to ``merge/config.yaml``, and the ``merge``
    folder is uploaded to the model repository. The token is read from the
    ``HF_TOKEN`` environment variable.

    Args:
        model_name: Repository name under the ``arcee-ai`` namespace.
            Surrounding whitespace is ignored.
        yaml_config: YAML text to upload as ``config.yaml``.

    Returns:
        A human-readable status message describing success or the problem
        encountered. Errors are returned as text, not raised.
    """
    username = "arcee-ai"
    api_token = os.getenv("HF_TOKEN")
    if not api_token:
        return "Hugging Face API token not set. Please set the HF_TOKEN environment variable."
    model_name = (model_name or "").strip()
    if not model_name:
        # Without a name the repo id would be "arcee-ai/", which is invalid.
        return "Model name is empty. Please enter a model name before pushing."
    if not (yaml_config or "").strip():
        return "YAML config is empty. Please create a config before pushing."

    repo_id = f"{username}/{model_name}"
    upload_dir = "merge"
    api = HfApi(token=api_token)
    try:
        # exist_ok=True so pushing again to an existing repository does not fail.
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
        # Write the config inside the folder that is uploaded, rather than
        # over the local ./config.yaml used by the merge step.
        os.makedirs(upload_dir, exist_ok=True)
        config_path = os.path.join(upload_dir, "config.yaml")
        with open(config_path, "w", encoding="utf-8") as file:
            file.write(yaml_config)
        api.upload_folder(repo_id=repo_id, folder_path=upload_dir, repo_type="model")
    except Exception as e:  # UI boundary: surface any Hub/IO failure as text.
        return f"Error pushing to HF Hub: {e}"
    return f"Successfully pushed to HF Hub: {repo_id}"
# Gradio UI: Config YAML -> Merge -> Push to HF Hub, using the Soft/indigo theme.


def _create_config_and_prefill_push(
    model_name,
    model1,
    model1_layers,
    model2,
    model2_layers,
    merge_method,
    base_model,
    parameters,
    dtype,
):
    """Call create_config_yaml and route its result to the UI.

    Returns:
        A 3-tuple ``(yaml_text, model_name, yaml_text)``: the YAML preview for
        the Config tab, followed by the model name and YAML that prefill the
        Push tab inputs.
    """
    yaml_config = create_config_yaml(
        model_name,
        model1,
        model1_layers,
        model2,
        model2_layers,
        merge_method,
        base_model,
        parameters,
        dtype,
    )
    return yaml_config, model_name, yaml_config


with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo")) as app:
    gr.Markdown("# Mergekit GUI")  # Title for your Gradio app

    with gr.Tab("Config YAML"):
        # Inputs for the YAML config
        with gr.Row():
            model_name_input = gr.Textbox(label="Model Name")
            model1_input = gr.Textbox(label="Model 1")
            model1_layers_input = gr.Textbox(
                label="Model 1 Layer Range", placeholder="[start, end]"
            )
            model2_input = gr.Textbox(label="Model 2")
            model2_layers_input = gr.Textbox(
                label="Model 2 Layer Range", placeholder="[start, end]"
            )
            # Preselect a method so an untouched dropdown does not put
            # "merge_method: None" into the config.
            merge_method_input = gr.Dropdown(
                label="Merge Method", choices=["slerp", "linear"], value="slerp"
            )
            base_model_input = gr.Textbox(label="Base Model")
            parameters_input = gr.Textbox(
                label="Parameters", placeholder="Formatted as a list of dicts"
            )
            dtype_input = gr.Textbox(label="Data Type", value="bfloat16")
        create_button = gr.Button("Create Config YAML")
        # Shows the YAML returned by the create callback.
        config_output = gr.Textbox(label="Generated YAML", lines=12, interactive=False)

    with gr.Tab("Merge"):
        # Placeholder for Merge tab contents
        # Not yet tested
        merge_output = gr.Textbox(label="Merge Output", interactive=False)
        merge_button = gr.Button("Execute Merge Command")
        merge_button.click(fn=execute_merge_command, inputs=[], outputs=merge_output)

    with gr.Tab("Push to HF Hub"):
        # Editable by the user and prefilled when "Create Config YAML" is clicked.
        push_model_name_input = gr.Textbox(label="Model Name", interactive=True)
        push_yaml_config_input = gr.Textbox(
            label="YAML Config", lines=12, interactive=True
        )
        push_output = gr.Textbox(label="Push Output", interactive=False)
        push_button = gr.Button("Push to HF Hub")
        push_button.click(
            fn=push_to_hf_hub,
            inputs=[push_model_name_input, push_yaml_config_input],
            outputs=push_output,
        )

    # Registered after all tabs exist, because it also fills the Push tab inputs.
    create_button.click(
        fn=_create_config_and_prefill_push,
        inputs=[
            model_name_input,
            model1_input,
            model1_layers_input,
            model2_input,
            model2_layers_input,
            merge_method_input,
            base_model_input,
            parameters_input,
            dtype_input,
        ],
        outputs=[config_output, push_model_name_input, push_yaml_config_input],
    )

app.launch()