import gc

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from verbatim_llm import TokenSwapProcessor

# Predefined (main model, auxiliary model) HF hub id pairs offered in the UI.
MODEL_PAIRS = {
    "Pythia 6.9B + 70M": ("EleutherAI/pythia-6.9b", "EleutherAI/pythia-70m"),
    "OLMo-2 13B Instruct + SmolLM 135M Instruct": ("allenai/OLMo-2-1124-13B-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 7B Chat + SmolLM 135M Instruct": ("deepseek-ai/deepseek-llm-7b-chat", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 70B Chat + SmolLM 135M Instruct": ("deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "HuggingFaceTB/SmolLM-135M-Instruct"),
}

# Module-level cache of the currently loaded pair:
# loaded_models holds {'main_model', 'main_tokenizer', 'processor'};
# current_pair is the MODEL_PAIRS key that is loaded, or None.
loaded_models = {}
current_pair = None


def clear_models():
    """Release the cached models and free CPU/GPU memory.

    Returns:
        A human-readable status string for the UI.
    """
    global loaded_models, current_pair
    try:
        # Move any GPU-resident objects back to CPU first so CUDA memory
        # can actually be reclaimed once the references are dropped.
        # (Tokenizers have no .to(), hence the hasattr guard.)
        for value in loaded_models.values():
            if hasattr(value, 'to'):
                value.to('cpu')

        # Rebinding the dict drops the last references to the models.
        # Note: `del value` on a loop variable would only remove the local
        # name and free nothing — the rebind below is what matters.
        loaded_models = {}
        current_pair = None

        # Force a collection, then return cached CUDA blocks to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return "✅ Models cleared from memory"
    except Exception as e:
        return f"❌ Error clearing models: {str(e)}"


def load_models(model_pair):
    """Load the main + auxiliary models for *model_pair* into the cache.

    Args:
        model_pair: A key of MODEL_PAIRS selected in the dropdown.

    Returns:
        A human-readable status string for the UI.
    """
    global loaded_models, current_pair

    if current_pair == model_pair:
        return "Models already loaded!"

    try:
        # Free the previous pair before pulling in a new one so peak
        # memory stays bounded when switching.
        if loaded_models:
            clear_models()

        main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the small auxiliary model first (cheap), then the main model.
        # NOTE: `resume_download=True` was dropped — it is deprecated in
        # huggingface_hub (resuming is the default behavior) and raises on
        # newer versions.
        aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name).to(device)

        main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        main_model = AutoModelForCausalLM.from_pretrained(main_model_name).to(device)

        # The processor swaps main-model tokens using the auxiliary model.
        processor = TokenSwapProcessor(aux_model, main_tokenizer, aux_tokenizer=aux_tokenizer)

        loaded_models = {
            'main_model': main_model,
            'main_tokenizer': main_tokenizer,
            'processor': processor
        }
        current_pair = model_pair
        return f"✅ Loaded {model_pair}"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def generate_text(prompt, max_tokens, use_tokenswap):
    """Greedily generate a continuation of *prompt* with the loaded model.

    Args:
        prompt: Input text.
        max_tokens: Maximum number of new tokens to generate.
        use_tokenswap: If True, apply the TokenSwap logits processor.

    Returns:
        The generated continuation only (prompt excluded), or an error string.
    """
    if not loaded_models:
        return "Please load models first!"

    try:
        tokenizer = loaded_models['main_tokenizer']
        inputs = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        logits_processor = [loaded_models['processor']] if use_tokenswap else []

        outputs = loaded_models['main_model'].generate(
            inputs.input_ids,
            # Passing the mask avoids the pad-token == eos-token ambiguity
            # warning and guarantees correct handling of the prompt.
            attention_mask=inputs.attention_mask,
            logits_processor=logits_processor,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Strip the prompt by *token count*, not character count:
        # decode(encode(prompt)) need not reproduce the prompt string
        # exactly, so slicing the decoded text with len(prompt) can clip
        # the output or leak prompt fragments.
        prompt_len = inputs.input_ids.shape[1]
        generated_ids = outputs[0][prompt_len:]
        return tokenizer.decode(generated_ids, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating: {str(e)}"


def compare_outputs(prompt, max_tokens):
    """Generate with and without TokenSwap for side-by-side comparison."""
    standard = generate_text(prompt, max_tokens, False)
    tokenswap = generate_text(prompt, max_tokens, True)
    return standard, tokenswap


# Gradio interface
with gr.Blocks(title="Verbatim-LLM Demo") as app:
    gr.Markdown("# Verbatim-LLM: Mitigate Memorization in LLMs")
    gr.Markdown("Compare standard generation vs TokenSwap method")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PAIRS.keys()),
            value=list(MODEL_PAIRS.keys())[0],
            label="Model Pair"
        )
        with gr.Column():
            load_btn = gr.Button("Load Models", variant="primary")
            clear_btn = gr.Button("Clear Models", variant="secondary")

    status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=3
        )
        max_tokens = gr.Slider(10, 200, value=100, label="Max Tokens")

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        compare_btn = gr.Button("Compare Both", variant="secondary")

    with gr.Row():
        standard_output = gr.Textbox(label="Standard Generation", lines=5)
        tokenswap_output = gr.Textbox(label="TokenSwap Generation", lines=5)

    # Event handlers
    load_btn.click(
        fn=load_models,
        inputs=[model_dropdown],
        outputs=[status]
    )

    clear_btn.click(
        fn=clear_models,
        outputs=[status]
    )

    # Both buttons run the same comparison; the old inline lambda duplicated
    # compare_outputs exactly, so reuse it.
    generate_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )

    compare_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )

if __name__ == "__main__":
    app.launch()