import gradio as gr
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from huggingface_hub import create_repo, list_models
from transformers.modeling_utils import PreTrainedModel
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import torch
from torch.nn.utils import prune
from tqdm import tqdm  # used by the pruning loop below
import subprocess
import logging
import sys
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Ensure sentencepiece is installed (required by some tokenizers)
try:
    import sentencepiece  # noqa: F401
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'sentencepiece'])
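# Note: sentencepiece backs tokenizers such as T5 and LLaMA; installing it at
# runtime is a fallback -- in a Space it belongs in requirements.txt.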
# Fetch open-weight LLM models from the Hugging Face Hub
def fetch_open_weight_models():
    try:
        # An unbounded list_models() call pages through the entire Hub,
        # so restrict to text-generation models and cap the count.
        models = list(list_models(filter="text-generation", sort="downloads", limit=100))
        return models
    except Exception as e:
        logging.error(f"Error fetching models: {e}")
        return []
# Retrieve just the names from the models list
def get_model_names():
    models = fetch_open_weight_models()
    model_names = [model.modelId for model in models if model.modelId is not None]
    return model_names
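# get_model_names() yields plain repo ids, e.g. ["gpt2", "distilgpt2", ...]
# (illustrative values; the actual list depends on the Hub query above).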
# Full merge-kit pruning function
def merge_kit_prune(model: PreTrainedModel, target_num_parameters: int, progress: gr.Progress) -> PreTrainedModel:
    """Prunes a model using a merge-kit approach.

    Args:
        model (PreTrainedModel): The model to be pruned.
        target_num_parameters (int): The target number of parameters after pruning.
        progress (gr.Progress): The progress object for visual feedback.

    Returns:
        PreTrainedModel: The pruned model.
    """
    total_params = sum(p.numel() for p in model.parameters())
    # Fraction of weights to zero so that roughly target_num_parameters remain
    amount = 1 - (target_num_parameters / total_params)
    num_modules = len(list(model.named_modules()))
    try:
        # Prune the model
        for i, (name, module) in enumerate(tqdm(model.named_modules(), desc="Pruning", total=num_modules, file=sys.stdout)):
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.random_unstructured(module, name="weight", amount=amount)
            progress(0.5 * (i + 1) / num_modules, desc="Pruning")  # first half of the bar
        # Make the pruning permanent by removing the re-parametrization
        for i, (name, module) in enumerate(model.named_modules()):
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.remove(module, name="weight")
            progress(0.5 + 0.5 * (i + 1) / num_modules, desc="Finalizing")  # second half
        return model
    except Exception as e:
        logging.error(f"Error during pruning: {e}")
        raise
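# Worked example (hypothetical numbers): with total_params = 1_000_000 and
# target_num_parameters = 600_000, amount = 1 - 600000/1000000 = 0.4, so 40%
# of the weights in each Linear/Conv2d layer are zeroed at random.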
# Prune a model and push the result to the Hub
def prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name=None, progress=gr.Progress(track_tqdm=True)):
    log_messages = []
    try:
        # Load the LLM model and tokenizer
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        llm_model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.float16,
        )
        log_messages.append('Model and tokenizer loaded successfully.')
        logging.info('Model and tokenizer loaded successfully.')
        total_params = sum(p.numel() for p in llm_model.parameters())
        target_num_parameters = int(total_params * (target_size / 100))
        # Prune the model
        pruned_model = merge_kit_prune(llm_model, target_num_parameters, progress)
        log_messages.append('Model pruned successfully.')
        logging.info('Model pruned successfully.')
        # Save the pruned model (use_auth_token is deprecated; pass token= instead)
        create_repo(repo_name, token=hf_write_token, private=False, exist_ok=True)
        pruned_model.push_to_hub(repo_name, token=hf_write_token)
        llm_tokenizer.push_to_hub(repo_name, token=hf_write_token)
        log_messages.append(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
        logging.info(f"Pruned model saved to Hugging Face Hub in repository {repo_name}")
        # Create a visualization comparing parameter counts.
        # Unstructured pruning zeroes weights rather than deleting tensors,
        # so count non-zero elements to reflect the effective size.
        pruned_params = sum(int(torch.count_nonzero(p)) for p in pruned_model.parameters())
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.bar(['Original', 'Pruned'], [total_params, pruned_params])
        ax.set_ylabel('Number of Parameters')
        ax.set_title('Model Size Comparison')
        buf = BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)  # free the figure
        buf.seek(0)
        image_base64 = base64.b64encode(buf.read()).decode('utf-8')
        return f"Pruned model saved to Hugging Face Hub in repository {repo_name}", f"data:image/png;base64,{image_base64}", '\n'.join(log_messages)
    except Exception as e:
        error_message = f"Detailed error: {repr(e)}"
        log_messages.append(error_message)
        logging.error(error_message)
        return error_message, None, '\n'.join(log_messages)
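# prune_model returns (status_message, base64_image_or_None, joined_logs);
# the Gradio click handler below consumes the first two values.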
# Generate text with the pruned model
def generate_text(text, repo_name, hf_write_token):
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_name, token=hf_write_token)
        model = AutoModelForCausalLM.from_pretrained(repo_name, token=hf_write_token)
        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
        generated_text = generator(text, max_length=50, num_beams=5, num_return_sequences=1)[0]['generated_text']
        return generated_text
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        return f"Error: {repr(e)}"
# Build the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Create a Smaller LLM")
        # Fetch available model names
        model_names = get_model_names()
        # Input components
        llm_model_name = gr.Dropdown(label="Choose a Large Language Model", choices=model_names, interactive=True)
        base_model_name = gr.Dropdown(label="Base Model Name (if required)", choices=model_names, interactive=True, visible=False)
        target_size = gr.Slider(label="Target Model Size (%)", minimum=1, maximum=100, step=1, value=50, interactive=True)
        hf_write_token = gr.Textbox(label="Hugging Face Write Token", placeholder="Enter your HF write token", interactive=True, type="password")
        repo_name = gr.Textbox(label="Repository Name", placeholder="Enter the name of the repository", interactive=True)
        pruned_func_choice = gr.Radio(label="Pruning Function", choices=["merge-kit"], value="merge-kit", interactive=True)
        pruning_status = gr.Textbox(label="Pruning Status", interactive=False)
        prune_button = gr.Button("Prune Model")
        visualization = gr.Image(label="Model Size Comparison", interactive=False)
        # Prune the model with progress feedback; gr.Progress is not a layout
        # component -- Gradio injects it when it appears as a default argument
        # of the event handler.
        def prune_model_with_progress(llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice, progress=gr.Progress(track_tqdm=True)):
            if pruned_func_choice == "merge-kit":
                status, image, _logs = prune_model(llm_model_name, target_size, hf_write_token, repo_name, base_model_name, progress)
                return status, image
            return f"Pruning function '{pruned_func_choice}' not implemented.", None
        prune_button.click(fn=prune_model_with_progress, inputs=[llm_model_name, base_model_name, target_size, hf_write_token, repo_name, pruned_func_choice], outputs=[pruning_status, visualization])
        text_input = gr.Textbox(label="Input Text")
        text_output = gr.Textbox(label="Generated Text")
        generate_button = gr.Button("Generate Text")
        generate_button.click(fn=generate_text, inputs=[text_input, repo_name, hf_write_token], outputs=text_output)
    return demo

# Create and launch the Gradio interface
demo = create_interface()
demo.launch()
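# On Hugging Face Spaces the app is served automatically; when run locally,
# launch() starts a server on http://127.0.0.1:7860 by default.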