parjanya20 committed on
Commit
30bc257
·
verified ·
1 Parent(s): 614654f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from verbatim_llm import TokenSwapProcessor
5
+
6
# Model pairs selectable in the UI.
# Each entry maps a display label to (main model checkpoint, auxiliary model checkpoint).
MODEL_PAIRS = {
    "Pythia 6.9B + 70M": (
        "EleutherAI/pythia-6.9b",
        "EleutherAI/pythia-70m",
    ),
    "OLMo-2 13B Instruct + SmolLM 135M Instruct": (
        "allenai/OLMo-2-1124-13B-Instruct",
        "HuggingFaceTB/SmolLM-135M-Instruct",
    ),
    "DeepSeek 7B Chat + SmolLM 135M Instruct": (
        "deepseek-ai/deepseek-llm-7b-chat",
        "HuggingFaceTB/SmolLM-135M-Instruct",
    ),
}

# Module-level state shared by the UI callbacks:
# the objects produced by the last successful load, and the label they belong to.
loaded_models = {}
current_pair = None
16
+
17
def load_models(model_pair):
    """Load the main and auxiliary models for ``model_pair`` into module state.

    On success the main model, its tokenizer and the ``TokenSwapProcessor``
    are stored in the global ``loaded_models`` and ``current_pair`` is set
    to ``model_pair``.

    Any previously loaded pair is released *before* the new one is loaded so
    that two sets of weights are never resident at the same time. As a
    consequence, if loading the new pair fails, no models remain loaded.

    Args:
        model_pair: A key of ``MODEL_PAIRS``.

    Returns:
        A status string for display: "already loaded", a success message, or
        an error message containing the exception text.
    """
    global loaded_models, current_pair

    if current_pair == model_pair:
        return "Models already loaded!"

    import gc

    try:
        # Resolve the names first so an unknown label fails before anything
        # already loaded is thrown away.
        main_model_name, aux_model_name = MODEL_PAIRS[model_pair]
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Drop references to the previous pair and return its memory before
        # allocating the next one.
        loaded_models = {}
        current_pair = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Load auxiliary model
        aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name).to(device)

        # Load main model
        main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        main_model = AutoModelForCausalLM.from_pretrained(main_model_name).to(device)

        # Create processor
        processor = TokenSwapProcessor(aux_model, main_tokenizer, aux_tokenizer=aux_tokenizer)

        loaded_models = {
            'main_model': main_model,
            'main_tokenizer': main_tokenizer,
            'processor': processor,
        }
        current_pair = model_pair

        return f"✅ Loaded {model_pair}"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return f"❌ Error: {str(e)}"
48
+
49
def generate_text(prompt, max_tokens, use_tokenswap):
    """Greedily generate a continuation of ``prompt`` with the loaded main model.

    Args:
        prompt: Text to continue.
        max_tokens: Maximum number of new tokens to generate; coerced to
            ``int`` because UI sliders may deliver a float.
        use_tokenswap: When true, the loaded ``TokenSwapProcessor`` is passed
            to ``generate`` as a logits processor.

    Returns:
        Only the newly generated text (the prompt is not included), or a
        message string if no models are loaded or generation raised.
    """
    if not loaded_models:
        return "Please load models first!"

    try:
        tokenizer = loaded_models['main_tokenizer']
        model = loaded_models['main_model']

        # Follow the model's actual device rather than assuming CUDA.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs.input_ids.shape[1]

        logits_processor = [loaded_models['processor']] if use_tokenswap else []

        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            logits_processor=logits_processor,
            max_new_tokens=int(max_tokens),
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Strip the prompt by token count, not by character count: the decoded
        # text is not guaranteed to reproduce the prompt string exactly, so
        # slicing with len(prompt) can cut at the wrong place.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Error generating: {str(e)}"
73
+
74
def compare_outputs(prompt, max_tokens):
    """Generate for ``prompt`` twice: first without TokenSwap, then with it.

    Returns:
        A ``(standard, tokenswap)`` tuple of the two results from
        ``generate_text``.
    """
    return tuple(
        generate_text(prompt, max_tokens, tokenswap_enabled)
        for tokenswap_enabled in (False, True)
    )
78
+
79
# Gradio interface
with gr.Blocks(title="Verbatim-LLM Demo") as app:
    gr.Markdown("# Verbatim-LLM: Mitigate Memorization in LLMs")
    gr.Markdown("Compare standard generation vs TokenSwap method")

    # Model selection and loading; the dropdown defaults to the first pair.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_PAIRS),
            value=next(iter(MODEL_PAIRS)),
            label="Model Pair",
        )
        load_btn = gr.Button("Load Models", variant="primary")

    status = gr.Textbox(label="Status", interactive=False)

    # Generation inputs.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=3,
        )
        max_tokens = gr.Slider(minimum=10, maximum=200, value=100, label="Max Tokens")

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        compare_btn = gr.Button("Compare Both", variant="secondary")

    # Side-by-side results.
    with gr.Row():
        standard_output = gr.Textbox(label="Standard Generation", lines=5)
        tokenswap_output = gr.Textbox(label="TokenSwap Generation", lines=5)

    # Event handlers
    load_btn.click(fn=load_models, inputs=[model_dropdown], outputs=[status])

    # "Generate" fills both output boxes, exactly like "Compare Both".
    generate_btn.click(
        fn=lambda text, limit: compare_outputs(text, limit),
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output],
    )

    compare_btn.click(
        fn=compare_outputs,
        inputs=[prompt_box, max_tokens],
        outputs=[standard_output, tokenswap_output],
    )

if __name__ == "__main__":
    app.launch()