import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc


def get_model_size_mb(model):
    """Rough estimate of model size in MB (parameters only)."""
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    return round(param_size / (1024 ** 2), 1)


def parse_layer_string(layer_str, max_layers):
    """Parse strings like '0, 5, 20-21' into a sorted list of valid layer indices."""
    if not layer_str or not str(layer_str).strip():
        return list(range(max_layers))
    layers = []
    try:
        for part in str(layer_str).split(","):
            part = part.strip()
            if "-" in part:
                start, end = map(int, part.split("-"))
                layers.extend(range(start, end + 1))
            else:
                layers.append(int(part))
        # Deduplicate, sort, and filter out-of-bounds indices
        return sorted({l for l in layers if 0 <= l < max_layers})
    except Exception:
        return None


def apply_pruning(model, layers_to_keep_indices):
    """Helper function to perform the actual lobotomy."""
    new_layers_list = []
    for new_idx, old_idx in enumerate(layers_to_keep_indices):
        layer = model.model.layers[old_idx]
        # Reset internal layer indices so KV caching doesn't crash
        if hasattr(layer, "layer_idx"):
            layer.layer_idx = new_idx
        if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
            layer.self_attn.layer_idx = new_idx
        new_layers_list.append(layer)
    # Overwrite the model's layer stack and keep the config consistent
    model.model.layers = torch.nn.ModuleList(new_layers_list)
    model.config.num_hidden_layers = len(new_layers_list)
    return model


def prune_and_test(model_id: str, layer_input: str, test_prompt: str):
    status_lines = [f"Loading base model: {model_id}"]
    model = None
    tokenizer = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Only Llama-style architectures that expose model.model.layers are supported
        orig_layers = (
            len(model.model.layers)
            if hasattr(model, "model") and hasattr(model.model, "layers")
            else 0
        )
        if orig_layers == 0:
            return "\n".join(status_lines) + "\n\n❌ Model architecture not supported."

        orig_size_mb = get_model_size_mb(model)
        status_lines.append(f"→ Original layers: {orig_layers} (Indices 0 to {orig_layers - 1})")

        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."

        status_lines.append(f"\nTargeting layers: {layers_to_keep_indices}")

        # PRUNE
        model = apply_pruning(model, layers_to_keep_indices)
        gc.collect()

        new_size_mb = get_model_size_mb(model)
        status_lines.append(f"→ After pruning: {len(layers_to_keep_indices)} layers")
        status_lines.append(f"→ Size reduced from {orig_size_mb} MB to {new_size_mb} MB")

        # TEST
        try:
            prompt_to_use = test_prompt if test_prompt.strip() else "Hello, the future of AI is"
            inputs = tokenizer(prompt_to_use, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs.to(model.device),
                    max_new_tokens=40,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                    use_cache=False,  # avoid KV-cache edge cases after surgery
                )
            text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            status_lines.append(f"\nQuick generation test:\n→ {text.strip()}")
        except Exception as gen_e:
            status_lines.append(f"\nGeneration test failed: {str(gen_e)}")

        status_lines.append("\nPruning test successful ✓")
        return "\n".join(status_lines)
    except Exception as e:
        return "\n".join(status_lines) + f"\n\n❌ Failed: {str(e)}"
    finally:
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
        gc.collect()


def push_pruned_model(model_id: str, layer_input: str, hf_token: str, repo_id: str):
    if not hf_token or not repo_id:
        return "❌ Please provide both a Hugging Face Write Token and a Repo Name."

    status_lines = [f"Preparing to push pruned {model_id} to {repo_id}..."]
    model = None
    tokenizer = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        orig_layers = len(model.model.layers)
        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."

        status_lines.append(f"Pruning to layers: {layers_to_keep_indices}...")
        model = apply_pruning(model, layers_to_keep_indices)

        status_lines.append("Pushing model and tokenizer to Hub (this may take a minute)...")
        # PUSH COMMANDS
        model.push_to_hub(repo_id, token=hf_token)
        tokenizer.push_to_hub(repo_id, token=hf_token)

        status_lines.append(f"\n✅ SUCCESS! Model pushed to: https://huggingface.co/{repo_id}")
        status_lines.append("You can now load this in any script using AutoModelForCausalLM.from_pretrained()")
        return "\n".join(status_lines)
    except Exception as e:
        return "\n".join(status_lines) + f"\n\n❌ Push Failed: {str(e)}"
    finally:
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
        gc.collect()


# ────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────
CSS = """.gradio-container { max-width: 950px !important; }"""

with gr.Blocks(title="The Frankenstein Pruner", theme=gr.themes.Soft(), css=CSS) as demo:
    gr.Markdown("""
    # 🧠 The Frankenstein Pruner
    Lobotomize language models by ripping out their middle layers.
    Test the output, and if it produces beautiful, semi-coherent gibberish, save your abomination to the Hugging Face Hub!
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""
            ### ✂️ Mad Scientist Cheat Sheet
            *The models in the dropdown have between 22 and 28 layers. If a model has 22 layers, valid indices are `0` to `21`.*

            **The Bookend (The Best Starter)**
            Keep the first 2 and last 2 layers. It usually remembers grammar but forgets facts.
            *Try:* `0, 1, 20, 21`

            **The Swiss Cheese**
            Keep every other layer. The model's logic gets incredibly confused.
            *Try:* `0, 2, 4, 6, 8, 10, 12, 14`

            **The Brainless**
            Keep only the first layer and the last layer. Absolute chaos.
            *Try:* `0, 21`
            """)

        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=[
                    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    "Qwen/Qwen2.5-0.5B-Instruct",
                    "Qwen/Qwen2.5-1.5B-Instruct",
                    "google/gemma-2-2b-it",
                ],
                label="Base Model",
                value="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            )
            layer_choice = gr.Textbox(
                label="Layers to Keep (e.g. '0, 1, 20, 21')",
                value="0, 1, 20, 21",
            )
            prompt_choice = gr.Textbox(
                label="Test Prompt",
                value="The capital of France is Paris. The capital of Japan is",
            )
            btn_test = gr.Button("🧪 Prune & Test Generate", variant="primary")

            with gr.Accordion("🚀 Push to Hugging Face Hub", open=False):
                gr.Markdown("Love your broken model? Save it! You'll need a **Write Token** from your HF account settings.")
                hf_token_input = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
                repo_name_input = gr.Textbox(label="Target Repo Name", placeholder="your-username/my-weird-llama")
                btn_push = gr.Button("Push to Hub", variant="secondary")

        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Surgery Log & Output",
                lines=25,
                interactive=False,
            )

    btn_test.click(
        prune_and_test,
        inputs=[model_choice, layer_choice, prompt_choice],
        outputs=status,
    )
    btn_push.click(
        push_pruned_model,
        inputs=[model_choice, layer_choice, hf_token_input, repo_name_input],
        outputs=status,
    )

if __name__ == "__main__":
    demo.launch()
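
# ── Not part of the app: a minimal sketch of reloading a pushed abomination ──
# Assumes a hypothetical repo id (the UI placeholder "your-username/my-weird-llama")
# and that push_to_hub above uploaded the pruned weights together with the updated
# config (num_hidden_layers), as the success message in push_pruned_model suggests:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     pruned = AutoModelForCausalLM.from_pretrained("your-username/my-weird-llama")
#     tok = AutoTokenizer.from_pretrained("your-username/my-weird-llama")
#     out = pruned.generate(**tok("Hello, the future of AI is", return_tensors="pt"),
#                           max_new_tokens=20)
#     print(tok.decode(out[0], skip_special_tokens=True))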