# TokenSwap / app.py — Hugging Face Space by parjanya20 (commit 2bf8fe3, verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from verbatim_llm import TokenSwapProcessor
import gc
# Predefined model pairs: display label -> (main model repo id, auxiliary model repo id).
# The auxiliary (small) model drives the TokenSwapProcessor; the main model generates.
MODEL_PAIRS = {
    "Pythia 6.9B + 70M": ("EleutherAI/pythia-6.9b", "EleutherAI/pythia-70m"),
    "OLMo-2 13B Instruct + SmolLM 135M Instruct": ("allenai/OLMo-2-1124-13B-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 7B Chat + SmolLM 135M Instruct": ("deepseek-ai/deepseek-llm-7b-chat", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 70B Chat + SmolLM 135M Instruct": ("deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "HuggingFaceTB/SmolLM-135M-Instruct"),
}
# Global variables to store loaded models
loaded_models = {}   # populated by load_models(): 'main_model', 'main_tokenizer', 'processor'
current_pair = None  # display label of the pair currently in memory (None when nothing loaded)
def clear_models():
    """Release the currently loaded model pair and free CPU/GPU memory.

    Returns a human-readable status string (the UI shows it in the Status box);
    errors are reported in the return value rather than raised.
    """
    global loaded_models, current_pair
    try:
        if loaded_models:
            # Move any model objects back to CPU first so the CUDA allocator
            # can actually reclaim their device memory on empty_cache().
            # NOTE: the original iterated .items() with an unused key and
            # `del value`, which only unbinds the loop-local name — dropping
            # the dict below is what really releases the references.
            for value in loaded_models.values():
                if hasattr(value, 'to'):
                    value.to('cpu')
        loaded_models = {}
        current_pair = None
        # Force garbage collection before asking CUDA to release cached blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return "βœ… Models cleared from memory"
    except Exception as e:
        return f"❌ Error clearing models: {str(e)}"
def load_models(model_pair):
    """Load the (main, auxiliary) model pair selected in the dropdown.

    Args:
        model_pair: a key of MODEL_PAIRS identifying which pair to load.

    Returns a status string for the UI; failures are reported in the return
    value rather than raised. On success, populates the module-level
    `loaded_models` dict and sets `current_pair`.
    """
    global loaded_models, current_pair
    # Skip reloading only when the pair is genuinely resident; the original
    # checked current_pair alone, which could report "already loaded" even
    # if loaded_models had been emptied out from under it.
    if current_pair == model_pair and loaded_models:
        return "Models already loaded!"
    try:
        # Clear existing models first if switching pairs.
        if loaded_models:
            clear_models()
        main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load auxiliary (small) model — drives the TokenSwap processor.
        # (resume_download=True was dropped: it is deprecated in huggingface_hub;
        # interrupted downloads are always resumed now.)
        aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name).to(device)
        # Load main (large) model.
        main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        main_model = AutoModelForCausalLM.from_pretrained(main_model_name).to(device)
        # Create the logits processor that swaps tokens using the aux model.
        processor = TokenSwapProcessor(aux_model, main_tokenizer, aux_tokenizer=aux_tokenizer)
        loaded_models = {
            'main_model': main_model,
            'main_tokenizer': main_tokenizer,
            'processor': processor
        }
        current_pair = model_pair
        return f"βœ… Loaded {model_pair}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def generate_text(prompt, max_tokens, use_tokenswap):
    """Generate a greedy continuation of `prompt` with the loaded main model.

    Args:
        prompt: input text.
        max_tokens: maximum number of new tokens to generate.
        use_tokenswap: if True, apply the TokenSwap logits processor.

    Returns only the newly generated text (prompt excluded), or an error
    message string — errors are reported in the return value, not raised.
    """
    if not loaded_models:
        return "Please load models first!"
    try:
        tokenizer = loaded_models['main_tokenizer']
        inputs = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")
        logits_processor = [loaded_models['processor']] if use_tokenswap else []
        outputs = loaded_models['main_model'].generate(
            inputs.input_ids,
            # Pass the attention mask explicitly; relying on the default
            # triggers a transformers warning and can mishandle padding.
            attention_mask=inputs.attention_mask,
            logits_processor=logits_processor,
            max_new_tokens=max_tokens,
            do_sample=False,  # greedy decoding so both modes are deterministic
            pad_token_id=tokenizer.eos_token_id
        )
        # Decode only the newly generated token ids. The original sliced the
        # decoded string with result[len(prompt):], which is fragile because
        # decode(encode(prompt)) is not guaranteed to reproduce the prompt
        # byte-for-byte (tokenizer normalization, special tokens).
        prompt_len = inputs.input_ids.shape[1]
        new_tokens = outputs[0][prompt_len:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating: {str(e)}"
def compare_outputs(prompt, max_tokens):
    """Generate the same prompt twice — once standard, once with TokenSwap.

    Returns a (standard_text, tokenswap_text) pair for the two output boxes.
    """
    return (
        generate_text(prompt, max_tokens, False),
        generate_text(prompt, max_tokens, True),
    )
# Gradio interface — top-level script: builds the UI layout, then wires each
# button to its handler. `app` is launched by the __main__ guard below.
with gr.Blocks(title="Verbatim-LLM Demo") as app:
    gr.Markdown("# Verbatim-LLM: Mitigate Memorization in LLMs")
    gr.Markdown("Compare standard generation vs TokenSwap method")
    # Model selection row: pick a pair, then load/unload it.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PAIRS.keys()),
            value=list(MODEL_PAIRS.keys())[0],
            label="Model Pair"
        )
        with gr.Column():
            load_btn = gr.Button("Load Models", variant="primary")
            clear_btn = gr.Button("Clear Models", variant="secondary")
    # Read-only status line updated by load/clear handlers.
    status = gr.Textbox(label="Status", interactive=False)
    # Prompt input and generation-length control.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=3
        )
        max_tokens = gr.Slider(10, 200, value=100, label="Max Tokens")
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        compare_btn = gr.Button("Compare Both", variant="secondary")
    # Side-by-side outputs: plain greedy decoding vs TokenSwap-processed decoding.
    with gr.Row():
        standard_output = gr.Textbox(label="Standard Generation", lines=5)
        tokenswap_output = gr.Textbox(label="TokenSwap Generation", lines=5)
    # Event handlers
    load_btn.click(
        fn=load_models,
        inputs=[model_dropdown],
        outputs=[status]
    )
    clear_btn.click(
        fn=clear_models,
        outputs=[status]
    )
    # NOTE(review): both buttons fill both output boxes — "Generate" runs the
    # two modes via a lambda, "Compare Both" via compare_outputs; they appear
    # to be functionally equivalent.
    generate_btn.click(
        fn=lambda p, t: (generate_text(p, t, False), generate_text(p, t, True)),
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )
    compare_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )
# Launch the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    app.launch()