File size: 9,646 Bytes
17f949e
 
 
71db69e
17f949e
71db69e
 
 
 
 
 
 
1f5c9f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5574df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f5c9f6
71db69e
 
17f949e
80b56c0
 
 
17f949e
 
71db69e
80b56c0
17f949e
71db69e
 
 
b5574df
80b56c0
1f5c9f6
 
b5574df
1f5c9f6
71db69e
1f5c9f6
2a6178a
b5574df
 
 
1f5c9f6
 
 
b5574df
 
71db69e
80b56c0
71db69e
b5574df
 
80b56c0
b5574df
71db69e
1f5c9f6
 
 
71db69e
 
 
 
8449b1c
2a6178a
b5574df
71db69e
 
b5574df
71db69e
b5574df
80b56c0
b5574df
71db69e
80b56c0
71db69e
b5574df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80b56c0
b5574df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71db69e
b5574df
 
80b56c0
71db69e
 
 
 
b5574df
17f949e
b5574df
71db69e
1f5c9f6
b5574df
71db69e
1f5c9f6
 
b5574df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f5c9f6
 
 
 
 
 
 
 
 
 
 
 
b5574df
 
1f5c9f6
 
 
 
 
 
 
b5574df
 
 
 
 
 
 
1f5c9f6
b5574df
1f5c9f6
 
b5574df
1f5c9f6
 
80b56c0
b5574df
1f5c9f6
 
71db69e
 
b5574df
 
 
 
 
 
17f949e
80b56c0
b5574df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc

def get_model_size_mb(model):
    """Return a rough size of ``model`` in MiB, counting parameter tensors only.

    Buffers are not included. The byte total is divided by ``1024 ** 2`` and
    rounded to one decimal place.
    """
    total_bytes = sum(
        tensor.nelement() * tensor.element_size() for tensor in model.parameters()
    )
    return round(total_bytes / (1024 ** 2), 1)

def parse_layer_string(layer_str, max_layers):
    """Parse a layer selection such as ``'0, 5, 20-21'`` into sorted indices.

    Args:
        layer_str: Comma-separated layer indices and inclusive ``a-b`` ranges.
            A blank/empty/falsy value selects every layer.
        max_layers: Number of layers in the model; valid indices are
            ``0 .. max_layers - 1``.

    Returns:
        A sorted, de-duplicated list of in-range indices (possibly empty when
        nothing requested is in range), or ``None`` when the string cannot be
        parsed (e.g. non-numeric parts, negative numbers, ``'1-2-3'``).
    """
    if not layer_str or not str(layer_str).strip():
        return list(range(max_layers))

    selected = set()
    try:
        for part in str(layer_str).split(","):
            part = part.strip()
            if not part:
                # Tolerate stray/trailing commas such as "0, 1, ".
                continue
            if "-" in part:
                start, end = map(int, part.split("-"))
                if start > end:
                    # Accept reversed ranges like "21-20" as "20-21".
                    start, end = end, start
                # Clamp before expanding so a huge range such as "0-999999999"
                # cannot allocate an enormous list; out-of-range values would
                # be discarded below anyway.
                start = max(start, 0)
                end = min(end, max_layers - 1)
                selected.update(range(start, end + 1))
            else:
                selected.add(int(part))
    except ValueError:
        # Non-numeric text, a bare "-", a leading "-" (negative index) or a
        # multi-dash part like "1-2-3" all end up here.
        return None

    return sorted(idx for idx in selected if 0 <= idx < max_layers)

def apply_pruning(model, layers_to_keep_indices):
    """Keep only the selected decoder layers of ``model`` (in place) and return it.

    Args:
        model: A causal LM exposing its decoder stack as ``model.model.layers``
            and a ``config`` with ``num_hidden_layers``.
        layers_to_keep_indices: Original indices of the layers to keep, in the
            order they should appear in the pruned stack.

    Returns:
        The same ``model`` object after mutation.

    Raises:
        ValueError: If no indices are given, an index is out of range, or an
            index is repeated.
    """
    old_layers = model.model.layers
    num_old = len(old_layers)
    keep = list(layers_to_keep_indices or [])

    # Validate before touching the model so a bad selection leaves it intact.
    if not keep:
        raise ValueError("At least one layer must be kept.")
    out_of_range = [i for i in keep if not 0 <= i < num_old]
    if out_of_range:
        raise ValueError(f"Layer indices out of range 0..{num_old - 1}: {out_of_range}")
    if len(set(keep)) != len(keep):
        # A repeated index would put the same module in the stack twice, and
        # its layer_idx could only match one of the two positions.
        raise ValueError(f"Duplicate layer indices: {keep}")

    new_layers_list = []
    for new_idx, old_idx in enumerate(keep):
        layer = old_layers[old_idx]

        # Reset internal layer indices so KV caching doesn't crash
        if hasattr(layer, "layer_idx"):
            layer.layer_idx = new_idx
        attn = getattr(layer, "self_attn", None)
        if attn is not None and hasattr(attn, "layer_idx"):
            attn.layer_idx = new_idx

        new_layers_list.append(layer)

    # Overwrite the model's layers
    model.model.layers = torch.nn.ModuleList(new_layers_list)
    model.config.num_hidden_layers = len(new_layers_list)

    # NOTE(review): some configs carry a per-layer ``layer_types`` list (e.g.
    # sliding vs. full attention in recent transformers versions — verify for
    # the installed version). Keep it aligned with the surviving layers so the
    # saved config stays consistent with num_hidden_layers.
    layer_types = getattr(model.config, "layer_types", None)
    if isinstance(layer_types, (list, tuple)) and len(layer_types) == num_old:
        model.config.layer_types = [layer_types[i] for i in keep]

    return model

def prune_and_test(model_id: str, layer_input: str, test_prompt: str):
    """Prune a model's decoder layers and run a short greedy generation test.

    The model is loaded on CPU in bfloat16 and reduced to the layers selected
    by ``layer_input``. It is then asked to continue the prompt for 40 new
    tokens.

    Args:
        model_id: Hugging Face Hub id of the causal language model.
        layer_input: Layer selection such as ``"0, 1, 20-21"``; blank keeps all
            layers (see ``parse_layer_string``).
        test_prompt: Prompt for the generation test; blank or ``None`` uses a
            default prompt.

    Returns:
        A newline-joined status log for the UI. Failures are reported in the
        returned text, never raised.
    """
    status_lines = [f"Loading base model: {model_id}"]
    model = None
    tokenizer = None

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Pruning relies on the decoder stack living at ``model.model.layers``.
        # getattr keeps a missing ``.model`` on the "not supported" path instead
        # of escaping as an AttributeError.
        decoder_layers = getattr(getattr(model, "model", None), "layers", None)
        orig_layers = len(decoder_layers) if decoder_layers is not None else 0
        if orig_layers == 0:
            return "\n".join(status_lines) + "\n\n❌ Model architecture not supported."

        orig_size_mb = get_model_size_mb(model)
        status_lines.append(f"→ Original layers: {orig_layers} (Indices 0 to {orig_layers - 1})")

        # Both None (unparsable input) and [] (nothing in range) are rejected.
        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."

        status_lines.append(f"\nTargeting layers: {layers_to_keep_indices}")

        # PRUNE
        model = apply_pruning(model, layers_to_keep_indices)
        # The replaced layer list should now be unreferenced; reclaim it before
        # measuring the new size.
        gc.collect()

        new_size_mb = get_model_size_mb(model)
        status_lines.append(f"→ After pruning: {len(layers_to_keep_indices)} layers")
        status_lines.append(f"→ Size reduced from {orig_size_mb} MB to {new_size_mb} MB")

        # TEST: a failed generation is reported, but it does not abort the run.
        generation_ok = False
        try:
            prompt_to_use = (
                test_prompt if test_prompt and test_prompt.strip() else "Hello, the future of AI is"
            )
            inputs = tokenizer(prompt_to_use, return_tensors="pt")

            with torch.no_grad():
                outputs = model.generate(
                    **inputs.to(model.device),
                    max_new_tokens=40,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                    # Cache disabled, presumably as a safeguard for the
                    # re-indexed layers.
                    use_cache=False,
                )
            text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            status_lines.append(f"\nQuick generation test:\n→ {text.strip()}")
            generation_ok = True
        except Exception as gen_e:
            status_lines.append(f"\nGeneration test failed: {gen_e}")

        # Only claim success when the pruned model actually produced text.
        if generation_ok:
            status_lines.append("\nPruning test successful ✓")
        else:
            status_lines.append("\n⚠️ Pruning finished, but the generation test failed.")
        return "\n".join(status_lines)

    except Exception as e:
        return "\n".join(status_lines) + f"\n\n❌ Failed: {e}"
    finally:
        # Drop local references so the weights can be reclaimed between requests.
        model = None
        tokenizer = None
        gc.collect()

def push_pruned_model(model_id: str, layer_input: str, hf_token: str, repo_id: str):
    """Prune ``model_id`` to the selected layers and upload it to the Hub.

    Args:
        model_id: Hub id of the base causal language model.
        layer_input: Layer selection such as ``"0, 1, 20-21"`` (see
            ``parse_layer_string``); blank keeps every layer.
        hf_token: Hugging Face token with write access; surrounding whitespace
            is ignored.
        repo_id: Target repository, e.g. ``"username/my-model"``; surrounding
            whitespace is ignored.

    Returns:
        A newline-joined status log. Failures are reported in the text, not raised.
    """
    # Tokens and repo names are often pasted with stray whitespace/newlines.
    hf_token = (hf_token or "").strip()
    repo_id = (repo_id or "").strip()
    if not hf_token or not repo_id:
        return "❌ Please provide both a Hugging Face Write Token and a Repo Name."

    status_lines = [f"Preparing to push pruned {model_id} to {repo_id}..."]
    model = None
    tokenizer = None

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        decoder_layers = getattr(getattr(model, "model", None), "layers", None)
        orig_layers = len(decoder_layers) if decoder_layers is not None else 0
        if orig_layers == 0:
            return "\n".join(status_lines) + "\n\n❌ Model architecture not supported."

        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        # Refuse to upload when nothing valid was selected. An unparsable string
        # (None) would otherwise crash during pruning, and an empty list would
        # push a model with zero decoder layers.
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."

        status_lines.append(f"Pruning to layers: {layers_to_keep_indices}...")
        model = apply_pruning(model, layers_to_keep_indices)

        status_lines.append("Pushing model and tokenizer to Hub (this may take a minute)...")

        # PUSH COMMANDS
        model.push_to_hub(repo_id, token=hf_token)
        tokenizer.push_to_hub(repo_id, token=hf_token)

        status_lines.append(f"\n✅ SUCCESS! Model pushed to: https://huggingface.co/{repo_id}")
        status_lines.append("You can now load this in any script using AutoModelForCausalLM.from_pretrained()")
        return "\n".join(status_lines)

    except Exception as e:
        return "\n".join(status_lines) + f"\n\n❌ Push Failed: {e}"
    finally:
        # Drop local references so the weights can be reclaimed between requests.
        model = None
        tokenizer = None
        gc.collect()

# ────────────────────────────────────────────────
#                 Gradio Interface
# ────────────────────────────────────────────────
CSS = """.gradio-container { max-width: 950px !important; }"""

with gr.Blocks(title="The Frankenstein Pruner", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 The Frankenstein Pruner
    Lobotomize language models by ripping out their middle layers. Test the output, and if it produces beautiful, semi-coherent gibberish, save your abomination to the Hugging Face Hub!
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("""
            ### βœ‚οΈ Mad Scientist Cheat Sheet
            *Models have between 16 and 28 layers. If a model has 22 layers, valid indices are `0` to `21`.*
            
            **The Bookend (The Best Starter)**
            Keep the first 2 and last 2 layers. It usually remembers grammar but forgets facts.
            *Try:* `0, 1, 20, 21`
            
            **The Swiss Cheese**
            Keep every other layer. The model's logic gets incredibly confused.
            *Try:* `0, 2, 4, 6, 8, 10, 12, 14`
            
            **The Brainless**
            Keep only the first layer and the last layer. Absolute chaos.
            *Try:* `0, 21`
            """)
            
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=[
                    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    "Qwen/Qwen2.5-0.5B-Instruct",
                    "Qwen/Qwen2.5-1.5B-Instruct",
                    "google/gemma-2-2b-it",
                ],
                label="Base Model",
                value="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            )
            
            layer_choice = gr.Textbox(
                label="Layers to Keep (e.g. '0, 1, 20, 21')",
                value="0, 1, 20, 21"
            )
            
            prompt_choice = gr.Textbox(
                label="Test Prompt",
                value="The capital of France is Paris. The capital of Japan is",
            )
            
            btn_test = gr.Button("πŸ§ͺ Prune & Test Generate", variant="primary")
            
            with gr.Accordion("πŸš€ Push to Hugging Face Hub", open=False):
                gr.Markdown("Love your broken model? Save it! You'll need a **Write Token** from your HF account settings.")
                hf_token_input = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
                repo_name_input = gr.Textbox(label="Target Repo Name", placeholder="your-username/my-weird-llama")
                btn_push = gr.Button("Push to Hub", variant="secondary")
            
        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Surgery Log & Output",
                lines=25,
                interactive=False
            )
        
    btn_test.click(
        prune_and_test,
        inputs=[model_choice, layer_choice, prompt_choice],
        outputs=status
    )
    
    btn_push.click(
        push_pruned_model,
        inputs=[model_choice, layer_choice, hf_token_input, repo_name_input],
        outputs=status
    )

if __name__ == "__main__":
    demo.launch(css=CSS)