Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -47,7 +47,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
|
|
| 47 |
target_num_parameters = int(config.num_parameters * (target_size / 100))
|
| 48 |
|
| 49 |
# Prune the model
|
| 50 |
-
pruned_model = merge_kit_prune(llm_model, target_num_parameters
|
| 51 |
|
| 52 |
log_messages.append("Model pruned successfully.")
|
| 53 |
logging.info("Model pruned successfully.")
|
|
@@ -81,7 +81,7 @@ def prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress
|
|
| 81 |
return error_message, None, "\n".join(log_messages)
|
| 82 |
|
| 83 |
# Merge-kit Pruning Function (adjust as needed)
|
| 84 |
-
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int
|
| 85 |
"""Prunes a model using a merge-kit approach.
|
| 86 |
Args:
|
| 87 |
model (PreTrainedModel): The model to be pruned.
|
|
@@ -120,9 +120,9 @@ def create_interface():
|
|
| 120 |
pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
|
| 121 |
prune_button = gr.Button("Prune Model")
|
| 122 |
visualization = gr.Image(label="Model Size Comparison", interactive=False)
|
| 123 |
-
progress_bar = gr.Progress()
|
| 124 |
logs_button = gr.Button("Show Logs")
|
| 125 |
logs_output = gr.Textbox(label="Logs", interactive=False)
|
|
|
|
| 126 |
|
| 127 |
def show_logs():
|
| 128 |
with open("pruning.log", "r") as log_file:
|
|
@@ -131,7 +131,11 @@ def create_interface():
|
|
| 131 |
|
| 132 |
logs_button.click(fn=show_logs, outputs=logs_output)
|
| 133 |
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
text_input = gr.Textbox(label="Input Text")
|
| 137 |
text_output = gr.Textbox(label="Generated Text")
|
|
|
|
| 47 |
target_num_parameters = int(config.num_parameters * (target_size / 100))
|
| 48 |
|
| 49 |
# Prune the model
|
| 50 |
+
pruned_model = merge_kit_prune(llm_model, target_num_parameters)
|
| 51 |
|
| 52 |
log_messages.append("Model pruned successfully.")
|
| 53 |
logging.info("Model pruned successfully.")
|
|
|
|
| 81 |
return error_message, None, "\n".join(log_messages)
|
| 82 |
|
| 83 |
# Merge-kit Pruning Function (adjust as needed)
|
| 84 |
+
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
|
| 85 |
"""Prunes a model using a merge-kit approach.
|
| 86 |
Args:
|
| 87 |
model (PreTrainedModel): The model to be pruned.
|
|
|
|
| 120 |
pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
|
| 121 |
prune_button = gr.Button("Prune Model")
|
| 122 |
visualization = gr.Image(label="Model Size Comparison", interactive=False)
|
|
|
|
| 123 |
logs_button = gr.Button("Show Logs")
|
| 124 |
logs_output = gr.Textbox(label="Logs", interactive=False)
|
| 125 |
+
progress_bar = gr.Progress()
|
| 126 |
|
| 127 |
def show_logs():
|
| 128 |
with open("pruning.log", "r") as log_file:
|
|
|
|
| 131 |
|
| 132 |
logs_button.click(fn=show_logs, outputs=logs_output)
|
| 133 |
|
| 134 |
+
def prune_model_with_progress(llm_model_name, target_size, hf_write_token, repo_name, progress=gr.Progress()):
    """Click handler for the Prune button: run prune_model with Gradio progress tracking.

    Args:
        llm_model_name: value of the model-name textbox.
        target_size: value of the target-size input (percent of parameters to keep).
        hf_write_token: Hugging Face write token from the UI.
        repo_name: destination repo name from the UI.
        progress: Gradio progress tracker. `gr.Progress` is NOT a context
            manager — the previous `with progress_bar:` would raise at runtime.
            Declaring it as a default-valued parameter is the documented idiom;
            Gradio detects it and injects a live tracker on each click.

    Returns:
        Whatever prune_model returns: (status_message, visualization, logs).
    """
    # NOTE(review): the hunk header shows prune_model's signature ends with a
    # `progress` parameter, so the tracker is forwarded positionally — confirm
    # against the full definition of prune_model.
    return prune_model(llm_model_name, target_size, hf_write_token, repo_name, progress)

prune_button.click(
    fn=prune_model_with_progress,
    inputs=[llm_model_name, target_size, hf_write_token, repo_name],
    outputs=[pruning_status, visualization, logs_output],
)
|
| 139 |
|
| 140 |
text_input = gr.Textbox(label="Input Text")
|
| 141 |
text_output = gr.Textbox(label="Generated Text")
|