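"""Gradio demo for Verbatim-LLM's TokenSwap method.

Loads a large main model together with a small auxiliary model and compares
standard greedy generation against generation with the TokenSwap logits
processor, which aims to mitigate verbatim memorization in LLMs.
"""
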
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from verbatim_llm import TokenSwapProcessor
import gc

# Predefined model pairs
MODEL_PAIRS = {
    "Pythia 6.9B + 70M": ("EleutherAI/pythia-6.9b", "EleutherAI/pythia-70m"),
    "OLMo-2 13B Instruct + SmolLM 135M Instruct": ("allenai/OLMo-2-1124-13B-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 7B Chat + SmolLM 135M Instruct": ("deepseek-ai/deepseek-llm-7b-chat", "HuggingFaceTB/SmolLM-135M-Instruct"),
    "DeepSeek 70B Chat + SmolLM 135M Instruct": ("deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "HuggingFaceTB/SmolLM-135M-Instruct"),
}
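# Each pair is (large main model, small auxiliary model). TokenSwap relies on
# the auxiliary model to adjust the main model's next-token distribution; the
# exact swap rule lives in verbatim_llm's TokenSwapProcessor.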

# Global variables to store loaded models
loaded_models = {}
current_pair = None
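# Only one pair is kept resident at a time; switching pairs clears the old one
# first (see clear_models) to avoid exhausting GPU memory.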

def clear_models():
    """Release the loaded models and free GPU memory."""
    global loaded_models, current_pair

    try:
        # Move GPU-resident objects back to CPU so their CUDA memory is freed.
        for value in loaded_models.values():
            if hasattr(value, 'to'):
                value.to('cpu')

        # Drop all references, then force garbage collection.
        loaded_models = {}
        current_pair = None
        gc.collect()

        # Release cached blocks held by the CUDA allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return "✅ Models cleared from memory"
    except Exception as e:
        return f"❌ Error clearing models: {str(e)}"

def load_models(model_pair):
    """Load the selected (main, auxiliary) pair and build the TokenSwap processor."""
    global loaded_models, current_pair

    if current_pair == model_pair:
        return "Models already loaded!"

    try:
        # Free the previously loaded pair before bringing in a new one.
        if loaded_models:
            clear_models()

        main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the small auxiliary model used by TokenSwap.
        aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name).to(device)

        # Load the main generation model.
        main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        main_model = AutoModelForCausalLM.from_pretrained(main_model_name).to(device)

        # The processor is applied at generation time via `logits_processor`.
        processor = TokenSwapProcessor(aux_model, main_tokenizer, aux_tokenizer=aux_tokenizer)

        loaded_models = {
            'main_model': main_model,
            'main_tokenizer': main_tokenizer,
            'processor': processor
        }
        current_pair = model_pair

        return f"✅ Loaded {model_pair}"
    except Exception as e:
        return f"❌ Error: {str(e)}"

def generate_text(prompt, max_tokens, use_tokenswap):
    """Generate a completion, optionally applying the TokenSwap processor."""
    if not loaded_models:
        return "Please load models first!"

    try:
        tokenizer = loaded_models['main_tokenizer']
        model = loaded_models['main_model']

        # Tokenize and move inputs to the same device as the model.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # TokenSwap plugs into generate() as a standard logits processor.
        logits_processor = [loaded_models['processor']] if use_tokenswap else []

        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            logits_processor=logits_processor,
            max_new_tokens=int(max_tokens),
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens; slicing the decoded string
        # by len(prompt) can misalign when tokenization changes whitespace.
        new_tokens = outputs[0][inputs.input_ids.shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

    except Exception as e:
        return f"Error generating: {str(e)}"

def compare_outputs(prompt, max_tokens):
    """Run both generation modes on the same prompt for side-by-side output."""
    standard = generate_text(prompt, max_tokens, False)
    tokenswap = generate_text(prompt, max_tokens, True)
    return standard, tokenswap

# Gradio interface
with gr.Blocks(title="Verbatim-LLM Demo") as app:
    gr.Markdown("# Verbatim-LLM: Mitigate Memorization in LLMs")
    gr.Markdown("Compare standard generation vs TokenSwap method")
    
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PAIRS.keys()),
            value=list(MODEL_PAIRS.keys())[0],
            label="Model Pair"
        )
        with gr.Column():
            load_btn = gr.Button("Load Models", variant="primary")
            clear_btn = gr.Button("Clear Models", variant="secondary")
    
    status = gr.Textbox(label="Status", interactive=False)
    
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=3
        )
        max_tokens = gr.Slider(10, 200, value=100, label="Max Tokens")
    
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        compare_btn = gr.Button("Compare Both", variant="secondary")
    
    with gr.Row():
        standard_output = gr.Textbox(label="Standard Generation", lines=5)
        tokenswap_output = gr.Textbox(label="TokenSwap Generation", lines=5)
    
    # Event handlers
    load_btn.click(
        fn=load_models,
        inputs=[model_dropdown],
        outputs=[status]
    )
    
    clear_btn.click(
        fn=clear_models,
        outputs=[status]
    )
    
    generate_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )
    
    compare_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output]
    )
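
# For long generations on the larger model pairs, enabling Gradio's request
# queue (app.queue(), called before launch()) helps avoid browser timeouts.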

if __name__ == "__main__":
    app.launch()