import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from transformers.modeling_utils import PreTrainedModel
from huggingface_hub import list_models
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import torch
from torch.nn.utils import prune
# Function to fetch popular open-weight LLM models
def fetch_open_weight_models():
    # The Hub has no "open-weight" tag; filter to seq2seq models instead, since the
    # app loads them with AutoModelForSeq2SeqLM. direction=-1 sorts by most downloads.
    models = list_models(filter="text2text-generation", sort="downloads", direction=-1, limit=12)
    # list_models yields ModelInfo objects, so read the `id` attribute rather than indexing
    return [model.id for model in models]
# Function to prune a model using the "merge-kit" style approach below
def prune_model(llm_model_name, target_size, output_dir):
    # Load the LLM model and tokenizer
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    llm_model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
    # num_parameters() is a method on PreTrainedModel, not a config attribute
    original_num_parameters = llm_model.num_parameters()
    # Calculate the target number of parameters
    target_num_parameters = int(original_num_parameters * (target_size / 100))
    # Prune the model
    pruned_model = merge_kit_prune(llm_model, target_num_parameters)
    # Save the pruned model together with its tokenizer so it can be reloaded from output_dir
    pruned_model.save_pretrained(output_dir)
    llm_tokenizer.save_pretrained(output_dir)
    # Unstructured pruning zeroes weights rather than deleting them, so count the
    # remaining nonzero parameters for the comparison chart
    pruned_num_parameters = sum(int((p != 0).sum()) for p in pruned_model.parameters())
    # Create a visualization
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(["Original", "Pruned"], [original_num_parameters, pruned_num_parameters])
    ax.set_ylabel("Number of Parameters")
    ax.set_title("Model Size Comparison")
    buf = BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    # Return a PIL image, which gr.Image accepts directly as an output value
    chart = Image.open(buf)
    return f"Pruned model saved to {output_dir}", chart
# "Merge-kit" pruning function. Note: despite the name, this is plain random
# unstructured pruning via torch.nn.utils.prune, not the mergekit library.
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int) -> PreTrainedModel:
    """Prunes a model with random unstructured pruning.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.

    Returns:
        PreTrainedModel: The pruned model.
    """
    # Calculate the fraction of weights to remove (num_parameters is a method)
    amount = 1 - (target_num_parameters / model.num_parameters())
    # Prune every Linear and Conv2d layer's weights at random
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.random_unstructured(module, name="weight", amount=amount)
    # Make the pruning permanent (pruned weights are zeroed, not removed)
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            prune.remove(module, name="weight")
    return model
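
# Alternative sketch (an addition, not part of the original app): global magnitude
# pruning usually preserves quality better than random pruning, because it drops the
# weights with the smallest absolute values pooled across the whole model. It uses only
# documented torch.nn.utils.prune APIs; `global_magnitude_prune` is a hypothetical name.
def global_magnitude_prune(model: PreTrainedModel, amount: float) -> PreTrainedModel:
    # Collect every prunable weight tensor
    parameters_to_prune = [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))
    ]
    # Remove the `amount` fraction of smallest-magnitude weights, globally
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )
    # Make the pruning permanent
    for module, name in parameters_to_prune:
        prune.remove(module, name)
    return model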
# Function to create a Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")
        # Fetch open-weight models from Hugging Face
        available_models = gr.Dropdown(
            label="Choose a Large Language Model",
            choices=fetch_open_weight_models(),
            interactive=True,
        )
        # Input for target model size
        target_size = gr.Slider(
            label="Target Model Size (%)",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
            interactive=True,
        )
        # Output for pruning status
        pruning_status = gr.Textbox(label="Pruning Status")
        # Input for where to save the pruned model
        save_model_path = gr.Textbox(label="Save Model Path", placeholder="Path to save the pruned model", interactive=True)
        # Button to start pruning
        prune_button = gr.Button("Prune Model")
        # Output for visualization
        visualization = gr.Image(label="Model Size Comparison")
        # Connect components
        prune_button.click(
            fn=prune_model,
            inputs=[available_models, target_size, save_model_path],
            outputs=[pruning_status, visualization],
        )
        # Example usage of the pruned model (optional)
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        # Generate text button
        generate_button = gr.Button("Generate Text")
        def generate_text(text, model_path):
            # Load the pruned model and tokenizer from the save path
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
            # Seq2seq models use the "text2text-generation" pipeline, not "text-generation"
            generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
            generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]["generated_text"]
            return generated_text
        generate_button.click(fn=generate_text, inputs=[text_input, save_model_path], outputs=text_output)
    return demo
# Create and launch the Gradio interface
demo = create_interface()
demo.launch(share=True)
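
# Quick local smoke-test sketch (commented out; assumes the small seq2seq checkpoint
# "t5-small" from the Hub, swap in any other). Not run by the Space itself.
# status, chart = prune_model("t5-small", 50, "./pruned-t5-small")
# print(status)
# print(layer_sparsity(AutoModelForSeq2SeqLM.from_pretrained("./pruned-t5-small")))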