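"""Gradio demo for Verbatim-LLM.

Loads a large "main" model together with a small auxiliary model and lets the
user compare standard greedy decoding against decoding with the TokenSwap
logits processor, which is intended to mitigate verbatim memorization.
"""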
import gc

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from verbatim_llm import TokenSwapProcessor

# Predefined model pairs
MODEL_PAIRS = {
    "Pythia 6.9B + 70M": ("EleutherAI/pythia-6.9b", "EleutherAI/pythia-70m"),
    "OLMo-2 13B Instruct + SmolLM 135M Instruct": ("allenai/OLMo-2-1124-13B-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 7B Chat + SmolLM 135M Instruct": ("deepseek-ai/deepseek-llm-7b-chat", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek R1 Distill Llama 70B + SmolLM 135M Instruct": ("deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "HuggingFaceTB/SmolLM-135M-Instruct"),
}
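# In each pair above, the first (large) model generates the text, while the second
# (much smaller) model is passed to the TokenSwap processor as an auxiliary reference.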
# Global variables to store loaded models
loaded_models = {}
current_pair = None

def clear_models():
    """Release any loaded models and free GPU memory."""
    global loaded_models, current_pair
    try:
        # Clear models from memory
        if loaded_models:
            # Move models to CPU if they were on GPU
            for key, value in loaded_models.items():
                if hasattr(value, 'to'):
                    value.to('cpu')
                del value
        loaded_models = {}
        current_pair = None
        # Force garbage collection
        gc.collect()
        # Clear GPU cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return "✅ Models cleared from memory"
    except Exception as e:
        return f"❌ Error clearing models: {str(e)}"

def load_models(model_pair):
    """Load the selected main/auxiliary model pair and build the TokenSwap processor."""
    global loaded_models, current_pair
    if current_pair == model_pair:
        return "Models already loaded!"
    try:
        # Clear existing models first if switching
        if loaded_models:
            clear_models()
        main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load auxiliary model
        aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name, resume_download=True).to(device)
        # Load main model
        main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        main_model = AutoModelForCausalLM.from_pretrained(main_model_name, resume_download=True).to(device)
        # Create processor
        processor = TokenSwapProcessor(aux_model, main_tokenizer, aux_tokenizer=aux_tokenizer)
        loaded_models = {
            'main_model': main_model,
            'main_tokenizer': main_tokenizer,
            'processor': processor
        }
        current_pair = model_pair
        return f"✅ Loaded {model_pair}"
    except Exception as e:
        return f"❌ Error: {str(e)}"

def generate_text(prompt, max_tokens, use_tokenswap):
    """Generate text with the main model, optionally applying the TokenSwap processor."""
    if not loaded_models:
        return "Please load models first!"
    try:
        inputs = loaded_models['main_tokenizer'](prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")
        logits_processor = [loaded_models['processor']] if use_tokenswap else []
        outputs = loaded_models['main_model'].generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            logits_processor=logits_processor,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=loaded_models['main_tokenizer'].eos_token_id
        )
        result = loaded_models['main_tokenizer'].decode(outputs[0], skip_special_tokens=True)
        return result[len(prompt):]  # Return only the newly generated part
    except Exception as e:
        return f"Error generating: {str(e)}"

def compare_outputs(prompt, max_tokens):
    standard = generate_text(prompt, max_tokens, False)
    tokenswap = generate_text(prompt, max_tokens, True)
    return standard, tokenswap

# Gradio interface
with gr.Blocks(title="Verbatim-LLM Demo") as app:
    gr.Markdown("# Verbatim-LLM: Mitigate Memorization in LLMs")
    gr.Markdown("Compare standard generation vs the TokenSwap method")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PAIRS.keys()),
            value=list(MODEL_PAIRS.keys())[0],
            label="Model Pair"
        )
        with gr.Column():
            load_btn = gr.Button("Load Models", variant="primary")
            clear_btn = gr.Button("Clear Models", variant="secondary")

    status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=3
        )
        max_tokens = gr.Slider(10, 200, value=100, label="Max Tokens")

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        compare_btn = gr.Button("Compare Both", variant="secondary")

    with gr.Row():
        standard_output = gr.Textbox(label="Standard Generation", lines=5)
        tokenswap_output = gr.Textbox(label="TokenSwap Generation", lines=5)

    # Event handlers
    load_btn.click(
        fn=load_models,
        inputs=[model_dropdown],
        outputs=[status]
    )
    clear_btn.click(
        fn=clear_models,
        outputs=[status]
    )
    generate_btn.click(
        fn=lambda p, t: (generate_text(p, t, False), generate_text(p, t, True)),
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )
    compare_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )

if __name__ == "__main__":
    app.launch()
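# Running locally (assumption, not part of the original Space config): `python app.py`
# starts the Gradio server; the larger main models (7B-70B parameters) need a
# correspondingly large GPU to load, so the first "Load Models" click can take a while.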