# File size: 9,646 Bytes
import gc

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_model_size_mb(model):
    """Return a rough size of *model* in MiB, counting parameters only.

    Sums ``nelement() * element_size()`` over ``model.parameters()`` (buffers are
    not included) and rounds the result to one decimal place.
    """
    bytes_per_mb = 1024 ** 2
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    return round(total_bytes / bytes_per_mb, 1)
def parse_layer_string(layer_str, max_layers):
    """Parse a layer spec like '0, 5, 20-21' into a sorted list of unique valid indices.

    Args:
        layer_str: Comma-separated integers and inclusive ``start-end`` ranges.
            ``None``, empty, or whitespace-only input selects every layer.
            Empty entries (e.g. from a trailing comma) are ignored, and reversed
            ranges such as ``21-20`` are accepted.
        max_layers: Number of layers in the model; indices outside
            ``[0, max_layers)`` are silently dropped.

    Returns:
        A sorted, de-duplicated list of in-range indices (empty if every index was
        out of range), or ``None`` if the spec could not be parsed (non-numeric
        tokens, negative numbers, or malformed ranges such as ``1-2-3``).
    """
    if not layer_str or not str(layer_str).strip():
        return list(range(max_layers))

    selected = set()
    try:
        for token in str(layer_str).split(","):
            token = token.strip()
            if not token:
                # Tolerate stray or trailing commas, e.g. "0, 1, 20, 21,".
                continue
            if "-" in token:
                start, end = map(int, token.split("-"))
                low, high = min(start, end), max(start, end)
                # Clamp before expanding so a huge range like "0-999999999"
                # cannot materialise an enormous set of out-of-range indices.
                selected.update(range(max(low, 0), min(high, max_layers - 1) + 1))
            else:
                selected.add(int(token))
    except ValueError:
        # int() failures and "too many values to unpack" both raise ValueError.
        return None

    return sorted(idx for idx in selected if 0 <= idx < max_layers)
def apply_pruning(model, layers_to_keep_indices):
    """Keep only the selected decoder layers of *model* (the actual lobotomy).

    The model is modified in place and also returned for convenience.

    Args:
        model: A causal LM whose decoder layers live in ``model.model.layers``.
        layers_to_keep_indices: Indices into the ORIGINAL layer list, in the order
            the surviving layers should run.

    Returns:
        The same *model* object, now holding only the kept layers, with
        ``config.num_hidden_layers`` updated to match.

    Raises:
        ValueError: if no layers are selected (a zero-layer model cannot run).
    """
    if not layers_to_keep_indices:
        raise ValueError("At least one layer must be kept.")

    original_layers = model.model.layers
    new_layers_list = []
    for new_idx, old_idx in enumerate(layers_to_keep_indices):
        layer = original_layers[old_idx]
        # Reset internal layer indices so KV caching doesn't crash
        if hasattr(layer, "layer_idx"):
            layer.layer_idx = new_idx
        if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
            layer.self_attn.layer_idx = new_idx
        new_layers_list.append(layer)

    # Overwrite the model's layers
    model.model.layers = torch.nn.ModuleList(new_layers_list)
    model.config.num_hidden_layers = len(new_layers_list)

    # Newer transformers configs may carry a per-layer `layer_types` list. Keep it
    # aligned with the surviving layers so its length agrees with
    # num_hidden_layers; this is a no-op when the attribute is absent.
    layer_types = getattr(model.config, "layer_types", None)
    if layer_types is not None and len(layer_types) == len(original_layers):
        model.config.layer_types = [layer_types[i] for i in layers_to_keep_indices]
    return model
def prune_and_test(model_id: str, layer_input: str, test_prompt: str):
    """Load a model, keep only the requested decoder layers, and run a short generation test.

    Args:
        model_id: Hugging Face Hub id of the causal LM to load (on CPU, in bfloat16).
        layer_input: Layer selection understood by ``parse_layer_string``, e.g.
            ``"0, 1, 20-21"``; blank keeps every layer.
        test_prompt: Prompt for the greedy generation check; a blank or missing
            prompt falls back to a default.

    Returns:
        A multi-line log string for the UI. Failures are reported inside the log
        instead of being raised.
    """
    status_lines = [f"Loading base model: {model_id}"]
    model = None
    tokenizer = None
    try:
        # NOTE(review): trust_remote_code=True executes Python code shipped in the
        # model repo; confirm it is actually required for the offered models.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Only architectures exposing `model.model.layers` can be pruned. getattr
        # keeps a model without a `.model` attribute on the friendly message path
        # instead of raising AttributeError.
        decoder = getattr(model, "model", None)
        orig_layers = len(decoder.layers) if hasattr(decoder, "layers") else 0
        if orig_layers == 0:
            return "\n".join(status_lines) + "\n\n❌ Model architecture not supported."

        orig_size_mb = get_model_size_mb(model)
        status_lines.append(f"✓ Original layers: {orig_layers} (Indices 0 to {orig_layers - 1})")

        # parse_layer_string returns None for unparsable input and [] when every
        # requested index is out of range; both are rejected here.
        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."
        status_lines.append(f"\nTargeting layers: {layers_to_keep_indices}")

        # PRUNE
        model = apply_pruning(model, layers_to_keep_indices)
        gc.collect()  # free memory held by the dropped layers
        new_size_mb = get_model_size_mb(model)
        status_lines.append(f"✓ After pruning: {len(layers_to_keep_indices)} layers")
        status_lines.append(f"✓ Size reduced from {orig_size_mb} MB to {new_size_mb} MB")

        # TEST: a failed generation is logged but does not fail the whole run.
        try:
            if test_prompt and test_prompt.strip():
                prompt_to_use = test_prompt
            else:
                prompt_to_use = "Hello, the future of AI is"
            inputs = tokenizer(prompt_to_use, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs.to(model.device),
                    max_new_tokens=40,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                    # presumably disabled to sidestep KV-cache issues after pruning
                    use_cache=False,
                )
            text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            status_lines.append(f"\nQuick generation test:\n→ {text.strip()}")
        except Exception as gen_e:
            status_lines.append(f"\nGeneration test failed: {str(gen_e)}")

        status_lines.append("\nPruning test successful ✓")
        return "\n".join(status_lines)
    except Exception as e:
        # UI boundary: surface any loading/pruning error in the log box.
        return "\n".join(status_lines) + f"\n\n❌ Failed: {str(e)}"
    finally:
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
        gc.collect()
def push_pruned_model(model_id: str, layer_input: str, hf_token: str, repo_id: str):
    """Prune *model_id* to the requested layers and upload model + tokenizer to the Hub.

    Args:
        model_id: Hugging Face Hub id of the base causal LM.
        layer_input: Layer selection understood by ``parse_layer_string``.
        hf_token: Hugging Face token with write access.
        repo_id: Target repository, e.g. ``"your-username/my-weird-llama"``.

    Returns:
        A multi-line log string for the UI; failures are reported in the log
        instead of being raised.
    """
    # Strip pasted whitespace so a blank-looking field counts as empty.
    hf_token = (hf_token or "").strip()
    repo_id = (repo_id or "").strip()
    if not hf_token or not repo_id:
        return "❌ Please provide both a Hugging Face Write Token and a Repo Name."

    status_lines = [f"Preparing to push pruned {model_id} to {repo_id}..."]
    model = None
    tokenizer = None
    try:
        # NOTE(review): trust_remote_code=True executes Python code shipped in the
        # model repo; confirm it is actually required for the offered models.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        decoder = getattr(model, "model", None)
        orig_layers = len(decoder.layers) if hasattr(decoder, "layers") else 0
        if orig_layers == 0:
            return "\n".join(status_lines) + "\n\n❌ Model architecture not supported."

        # Reject unparsable (None) or empty selections so we never push a
        # zero-layer model.
        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."

        status_lines.append(f"Pruning to layers: {layers_to_keep_indices}...")
        model = apply_pruning(model, layers_to_keep_indices)

        status_lines.append("Pushing model and tokenizer to Hub (this may take a minute)...")
        # PUSH COMMANDS
        model.push_to_hub(repo_id, token=hf_token)
        tokenizer.push_to_hub(repo_id, token=hf_token)

        status_lines.append(f"\n✅ SUCCESS! Model pushed to: https://huggingface.co/{repo_id}")
        status_lines.append("You can now load this in any script using AutoModelForCausalLM.from_pretrained()")
        return "\n".join(status_lines)
    except Exception as e:
        # UI boundary: surface any load/prune/upload error in the log box.
        return "\n".join(status_lines) + f"\n\n❌ Push Failed: {str(e)}"
    finally:
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
        gc.collect()
# ────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────
# Caps the app container width; handed to demo.launch() in the __main__ guard.
CSS = """.gradio-container { max-width: 950px !important; }"""

with gr.Blocks(title="The Frankenstein Pruner", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 The Frankenstein Pruner
    Lobotomize language models by ripping out their middle layers. Test the output, and if it produces beautiful, semi-coherent gibberish, save your abomination to the Hugging Face Hub!
    """)

    with gr.Row():
        # Left column: usage tips. Blank lines keep each entry on its own
        # Markdown paragraph instead of running together.
        with gr.Column(scale=1):
            gr.Markdown("""
            ### ⚗️ Mad Scientist Cheat Sheet

            *Models have between 16 and 28 layers. If a model has 22 layers, valid indices are `0` to `21`. Out-of-range indices are ignored.*

            **The Bookend (The Best Starter)**

            Keep the first 2 and last 2 layers. It usually remembers grammar but forgets facts.

            *Try:* `0, 1, 20, 21`

            **The Swiss Cheese**

            Keep every other layer. The model's logic gets incredibly confused.

            *Try:* `0, 2, 4, 6, 8, 10, 12, 14`

            **The Brainless**

            Keep only the first layer and the last layer. Absolute chaos.

            *Try:* `0, 21`
            """)

        # Middle column: model/layer/prompt inputs plus the optional Hub upload.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=[
                    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    "Qwen/Qwen2.5-0.5B-Instruct",
                    "Qwen/Qwen2.5-1.5B-Instruct",
                    "google/gemma-2-2b-it",
                ],
                label="Base Model",
                value="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            )
            layer_choice = gr.Textbox(
                label="Layers to Keep (e.g. '0, 1, 20, 21')",
                value="0, 1, 20, 21",
            )
            prompt_choice = gr.Textbox(
                label="Test Prompt",
                value="The capital of France is Paris. The capital of Japan is",
            )
            btn_test = gr.Button("🧪 Prune & Test Generate", variant="primary")

            with gr.Accordion("🚀 Push to Hugging Face Hub", open=False):
                gr.Markdown("Love your broken model? Save it! You'll need a **Write Token** from your HF account settings.")
                hf_token_input = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
                repo_name_input = gr.Textbox(label="Target Repo Name", placeholder="your-username/my-weird-llama")
                btn_push = gr.Button("Push to Hub", variant="secondary")

        # Right column: shared log/output box written by both actions.
        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Surgery Log & Output",
                lines=25,
                interactive=False,
            )

    btn_test.click(
        prune_and_test,
        inputs=[model_choice, layer_choice, prompt_choice],
        outputs=status,
    )
    btn_push.click(
        push_pruned_model,
        inputs=[model_choice, layer_choice, hf_token_input, repo_name_input],
        outputs=status,
    )
if __name__ == "__main__":
    # Start the Gradio server, applying the container-width CSS defined above.
    demo.launch(css=CSS)