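"""Gradio demo for Verbatim-LLM.

Compares standard greedy decoding against the TokenSwap method, which pairs a
large main model with a small auxiliary model (applied as a logits processor)
to mitigate verbatim memorization.
"""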
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from verbatim_llm import TokenSwapProcessor
import gc

# Predefined model pairs
MODEL_PAIRS = {
    "Pythia 6.9B + 70M": ("EleutherAI/pythia-6.9b", "EleutherAI/pythia-70m"),
    "OLMo-2 13B Instruct + SmolLM 135M Instruct": ("allenai/OLMo-2-1124-13B-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 7B Chat + SmolLM 135M Instruct": ("deepseek-ai/deepseek-llm-7b-chat", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 70B Chat + SmolLM 135M Instruct": ("deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "HuggingFaceTB/SmolLM-135M-Instruct"),
}
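# Each value is (main model, auxiliary model): the large main model generates,
# while the small auxiliary model drives the TokenSwapProcessor created in load_models().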

# Global variables to store loaded models
loaded_models = {}
current_pair = None
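
# Release the currently loaded pair so a different one can be selected without
# exhausting CPU/GPU memory.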
def clear_models():
    global loaded_models, current_pair
    try:
        # Clear models from memory
        if loaded_models:
            # Move models to CPU if they were on GPU
            for value in loaded_models.values():
                if hasattr(value, 'to'):
                    value.to('cpu')
        loaded_models = {}
        current_pair = None

        # Force garbage collection
        gc.collect()

        # Clear GPU cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return "✅ Models cleared from memory"
    except Exception as e:
        return f"❌ Error clearing models: {str(e)}"


def load_models(model_pair):
    global loaded_models, current_pair

    if current_pair == model_pair:
        return "Models already loaded!"

    try:
        # Clear existing models first if switching
        if loaded_models:
            clear_models()

        main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load auxiliary model
        aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name, resume_download=True).to(device)

        # Load main model
        main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        main_model = AutoModelForCausalLM.from_pretrained(main_model_name, resume_download=True).to(device)

        # Create processor
        processor = TokenSwapProcessor(aux_model, main_tokenizer, aux_tokenizer=aux_tokenizer)

        loaded_models = {
            'main_model': main_model,
            'main_tokenizer': main_tokenizer,
            'processor': processor
        }
        current_pair = model_pair

        return f"✅ Loaded {model_pair}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def generate_text(prompt, max_tokens, use_tokenswap):
    if not loaded_models:
        return "Please load models first!"

    try:
        inputs = loaded_models['main_tokenizer'](prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        logits_processor = [loaded_models['processor']] if use_tokenswap else []

        outputs = loaded_models['main_model'].generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            logits_processor=logits_processor,
            max_new_tokens=int(max_tokens),
            do_sample=False,
            pad_token_id=loaded_models['main_tokenizer'].eos_token_id
        )

        # Return only the generated part: slice off the prompt tokens before decoding,
        # which is more robust than trimming the decoded string by len(prompt)
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        return loaded_models['main_tokenizer'].decode(generated_ids, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating: {str(e)}"


def compare_outputs(prompt, max_tokens):
    standard = generate_text(prompt, max_tokens, False)
    tokenswap = generate_text(prompt, max_tokens, True)
    return standard, tokenswap


# Gradio interface
with gr.Blocks(title="Verbatim-LLM Demo") as app:
    gr.Markdown("# Verbatim-LLM: Mitigate Memorization in LLMs")
    gr.Markdown("Compare standard generation vs TokenSwap method")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PAIRS.keys()),
            value=list(MODEL_PAIRS.keys())[0],
            label="Model Pair"
        )
        with gr.Column():
            load_btn = gr.Button("Load Models", variant="primary")
            clear_btn = gr.Button("Clear Models", variant="secondary")

    status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=3
        )
        max_tokens = gr.Slider(10, 200, value=100, label="Max Tokens")

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        compare_btn = gr.Button("Compare Both", variant="secondary")

    with gr.Row():
        standard_output = gr.Textbox(label="Standard Generation", lines=5)
        tokenswap_output = gr.Textbox(label="TokenSwap Generation", lines=5)

    # Event handlers
    load_btn.click(
        fn=load_models,
        inputs=[model_dropdown],
        outputs=[status]
    )
    clear_btn.click(
        fn=clear_models,
        outputs=[status]
    )
    generate_btn.click(
        fn=lambda p, t: (generate_text(p, t, False), generate_text(p, t, True)),
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )
    compare_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )


if __name__ == "__main__":
    app.launch()
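
# Minimal standalone sketch of the same TokenSwap setup without the Gradio UI
# (hypothetical example; it assumes the verbatim_llm API exactly as used above):
#
#   main_tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-6.9b")
#   main = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-6.9b")
#   aux_tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
#   aux = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
#   swap = TokenSwapProcessor(aux, main_tok, aux_tokenizer=aux_tok)
#   ids = main_tok("Once upon a time", return_tensors="pt").input_ids
#   out = main.generate(ids, logits_processor=[swap], max_new_tokens=50, do_sample=False)
#   print(main_tok.decode(out[0], skip_special_tokens=True))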