import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
def get_model_size_mb(model):
"""Rough estimate of model size in MB (parameters only)"""
param_size = 0
for param in model.parameters():
param_size += param.nelement() * param.element_size()
return round(param_size / (1024 ** 2), 1)
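
# Sanity check of the arithmetic: a 1.1B-parameter model in bfloat16
# (2 bytes per parameter) comes out to roughly 1.1e9 * 2 / 1024**2 ≈ 2098 MB.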
def parse_layer_string(layer_str, max_layers):
"""Parses strings like '0, 5, 20-21' into a sorted list of valid integers."""
if not layer_str or not str(layer_str).strip():
return list(range(max_layers))
layers = []
try:
for part in str(layer_str).split(","):
part = part.strip()
if "-" in part:
start, end = map(int, part.split("-"))
layers.extend(range(start, end + 1))
else:
layers.append(int(part))
        # Deduplicate, sort, and drop out-of-range indices
        valid_layers = sorted({l for l in layers if 0 <= l < max_layers})
return valid_layers
except Exception:
return None
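
# Illustrative parser behaviour for a hypothetical 22-layer model:
#   parse_layer_string("0, 5, 20-21", 22) -> [0, 5, 20, 21]
#   parse_layer_string("0, 99", 22)       -> [0]              (out-of-range dropped)
#   parse_layer_string("", 22)            -> [0, 1, ..., 21]  (empty keeps everything)
#   parse_layer_string("oops", 22)        -> None             (unparseable)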
def apply_pruning(model, layers_to_keep_indices):
"""Helper function to perform the actual lobotomy"""
new_layers_list = []
for new_idx, old_idx in enumerate(layers_to_keep_indices):
layer = model.model.layers[old_idx]
# Reset internal layer indices so KV caching doesn't crash
if hasattr(layer, "layer_idx"):
layer.layer_idx = new_idx
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
layer.self_attn.layer_idx = new_idx
new_layers_list.append(layer)
# Overwrite the model's layers
model.model.layers = torch.nn.ModuleList(new_layers_list)
model.config.num_hidden_layers = len(new_layers_list)
return model
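
# Illustrative effect on a hypothetical 22-layer model with layers_to_keep_indices=[0, 1, 20, 21]:
#   before: model.model.layers holds 22 blocks with layer_idx 0..21
#   after:  model.model.layers holds 4 blocks re-indexed 0..3 and
#           config.num_hidden_layers == 4, so old layer 20 now answers to index 2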
def prune_and_test(model_id: str, layer_input: str, test_prompt: str):
status_lines = []
status_lines.append(f"Loading base model: {model_id}")
model = None
tokenizer = None
try:
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="cpu",
low_cpu_mem_usage=True,
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # Most decoder-only models expose their transformer blocks at model.model.layers
        has_layers = hasattr(model, "model") and hasattr(model.model, "layers")
        orig_layers = len(model.model.layers) if has_layers else 0
        if orig_layers == 0:
            return "\n".join(status_lines) + "\n\n❌ Model architecture not supported (no model.model.layers found)."
orig_size_mb = get_model_size_mb(model)
status_lines.append(f"β†’ Original layers: {orig_layers} (Indices 0 to {orig_layers - 1})")
layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection. Use comma-separated indices or ranges, e.g. '0, 1, 20-21'."
status_lines.append(f"\nTargeting layers: {layers_to_keep_indices}")
# PRUNE
model = apply_pruning(model, layers_to_keep_indices)
gc.collect()
new_size_mb = get_model_size_mb(model)
status_lines.append(f"β†’ After pruning: {len(layers_to_keep_indices)} layers")
status_lines.append(f"β†’ Size reduced from {orig_size_mb} MB to {new_size_mb} MB")
# TEST
try:
            prompt_to_use = (test_prompt or "").strip() or "Hello, the future of AI is"
inputs = tokenizer(prompt_to_use, return_tensors="pt")
with torch.no_grad():
outputs = model.generate(
**inputs.to(model.device),
max_new_tokens=40,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
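                    # layer_idx was re-mapped in apply_pruning, but disabling the
                    # KV cache is still the safest setting right after surgery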
use_cache=False
)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
status_lines.append(f"\nQuick generation test:\n→ {text.strip()}")
except Exception as gen_e:
status_lines.append(f"\nGeneration test failed: {str(gen_e)}")
status_lines.append("\nPruning test successful βœ“")
return "\n".join(status_lines)
except Exception as e:
return "\n".join(status_lines) + f"\n\n❌ Failed: {str(e)}"
finally:
if model is not None: del model
if tokenizer is not None: del tokenizer
gc.collect()
def push_pruned_model(model_id: str, layer_input: str, hf_token: str, repo_id: str):
if not hf_token or not repo_id:
return "❌ Please provide both a Hugging Face Write Token and a Repo Name."
status_lines = [f"Preparing to push pruned {model_id} to {repo_id}..."]
model = None
tokenizer = None
try:
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        orig_layers = len(model.model.layers)
        layers_to_keep_indices = parse_layer_string(layer_input, orig_layers)
        # Bail out before pruning if the layer string is invalid or empty
        if not layers_to_keep_indices:
            return "\n".join(status_lines) + "\n\n❌ Invalid layer selection."
        status_lines.append(f"Pruning to layers: {layers_to_keep_indices}...")
        model = apply_pruning(model, layers_to_keep_indices)
status_lines.append("Pushing model and tokenizer to Hub (this may take a minute)...")
# PUSH COMMANDS
model.push_to_hub(repo_id, token=hf_token)
tokenizer.push_to_hub(repo_id, token=hf_token)
status_lines.append(f"\nβœ… SUCCESS! Model pushed to: https://huggingface.co/{repo_id}")
status_lines.append("You can now load this in any script using AutoModelForCausalLM.from_pretrained()")
return "\n".join(status_lines)
except Exception as e:
return "\n".join(status_lines) + f"\n\n❌ Push Failed: {str(e)}"
finally:
if model is not None: del model
if tokenizer is not None: del tokenizer
gc.collect()
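
# Reloading a pushed model elsewhere (sketch; the repo id below is the UI
# placeholder, substitute your own):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained(
#       "your-username/my-weird-llama", trust_remote_code=True
#   )
#   tokenizer = AutoTokenizer.from_pretrained("your-username/my-weird-llama")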
# ────────────────────────────────────────────────
# Gradio Interface
# ────────────────────────────────────────────────
CSS = """.gradio-container { max-width: 950px !important; }"""
with gr.Blocks(title="The Frankenstein Pruner", theme=gr.themes.Soft(), css=CSS) as demo:
gr.Markdown("""
    # 🧠 The Frankenstein Pruner
Lobotomize language models by ripping out their middle layers. Test the output, and if it produces beautiful, semi-coherent gibberish, save your abomination to the Hugging Face Hub!
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("""
            ### ✂️ Mad Scientist Cheat Sheet
            *The models in the dropdown have between 22 and 28 layers. If a model has 22 layers, valid indices are `0` to `21`.*
**The Bookend (The Best Starter)**
Keep the first 2 and last 2 layers. It usually remembers grammar but forgets facts.
*Try:* `0, 1, 20, 21`
**The Swiss Cheese**
Keep every other layer. The model's logic gets incredibly confused.
            *Try:* `0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20`
**The Brainless**
Keep only the first layer and the last layer. Absolute chaos.
*Try:* `0, 21`
""")
with gr.Column(scale=1):
model_choice = gr.Dropdown(
choices=[
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Qwen/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-1.5B-Instruct",
"google/gemma-2-2b-it",
],
label="Base Model",
value="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
)
layer_choice = gr.Textbox(
label="Layers to Keep (e.g. '0, 1, 20, 21')",
value="0, 1, 20, 21"
)
prompt_choice = gr.Textbox(
label="Test Prompt",
value="The capital of France is Paris. The capital of Japan is",
)
            btn_test = gr.Button("🧪 Prune & Test Generate", variant="primary")
            with gr.Accordion("🚀 Push to Hugging Face Hub", open=False):
gr.Markdown("Love your broken model? Save it! You'll need a **Write Token** from your HF account settings.")
hf_token_input = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
repo_name_input = gr.Textbox(label="Target Repo Name", placeholder="your-username/my-weird-llama")
btn_push = gr.Button("Push to Hub", variant="secondary")
with gr.Column(scale=1):
status = gr.Textbox(
label="Surgery Log & Output",
lines=25,
interactive=False
)
btn_test.click(
prune_and_test,
inputs=[model_choice, layer_choice, prompt_choice],
outputs=status
)
btn_push.click(
push_pruned_model,
inputs=[model_choice, layer_choice, hf_token_input, repo_name_input],
outputs=status
)
if __name__ == "__main__":
    demo.launch()