llaa33219 commited on
Commit
60026f3
·
verified ·
1 Parent(s): fe6b518

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +27 -7
  2. app.py +351 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,14 +1,34 @@
1
  ---
2
  title: Context Window Extender
3
- emoji: 🐨
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
- python_version: '3.12'
9
  app_file: app.py
 
10
  pinned: false
11
- short_description: llm context-window-extender
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Context Window Extender
3
+ emoji: 🧠
4
+ colorFrom: purple
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.44.0
 
8
  app_file: app.py
9
+ suggested_hardware: cpu-basic
10
  pinned: false
 
11
  ---
12
 
13
+ # Context Window Extender
14
+
15
+ Load any causal language model from Hugging Face Hub and extend its context window.
16
+
17
+ ## Features
18
+
19
+ - **Model Loading**: Enter any Hugging Face model ID
20
+ - **Context Extension**:
21
+ - Raw: Simply increase max_position_embeddings
22
+ - RoPE: Apply RoPE scaling (linear, dynamic, yarn)
23
+ - **CPU Only**: Runs on free CPU hardware
24
+
25
+ ## Usage
26
+
27
+ 1. Enter a Hugging Face model ID (e.g., `gpt2`, `meta-llama/Llama-2-7b-hf`)
28
+ 2. Choose extension method:
29
+ - **None**: Use original context
30
+ - **Raw**: Increase max_position_embeddings
31
+ - **RoPE**: Apply RoPE scaling
32
+ 3. If RoPE selected, choose type and factor
33
+ 4. Set target context length
34
+ 5. Enter prompt and click Generate
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import warnings
import os  # NOTE(review): imported but unused in this file — TODO confirm it can be dropped

# Silence transformers / torch deprecation noise in the Space logs.
warnings.filterwarnings("ignore")

# Global model cache to avoid reloading.
# Keyed by "{model_id}_{method}_{length}_{rope_type}_{rope_factor}"; values are
# dicts holding the model, tokenizer, and context metadata. Entries are never
# evicted, so memory grows with each distinct configuration loaded.
model_cache = {}
12
def load_model_with_extension(
    model_id: str,
    extension_method: str,
    new_context_length: int,
    rope_type: str,
    rope_factor: float,
):
    """Load a causal LM on CPU, optionally extending its context window.

    Args:
        model_id: Hugging Face model ID.
        extension_method: "none", "raw", or "rope".
        new_context_length: Target context length in tokens.
        rope_type: "linear", "dynamic", or "yarn" (used only for "rope").
        rope_factor: RoPE scaling factor (used only for "rope").

    Returns:
        Dict with keys "model", "tokenizer", "original_context",
        "applied_context", and "extension_method".
    """
    # Cache on every parameter so different extension settings coexist.
    cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
    if cache_key in model_cache:
        return model_cache[cache_key]

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal LMs (e.g. GPT-2) ship without a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    original_context = getattr(config, "max_position_embeddings", 4096)

    if extension_method == "raw":
        # Raw extension: just enlarge the position budget. Positions beyond
        # the original context are untrained, so quality may degrade.
        config.max_position_embeddings = new_context_length

    elif extension_method == "rope":
        config.max_position_embeddings = new_context_length

        # RoPE scaling only applies to rotary-embedding models, detected
        # here via the rope_theta config attribute.
        if hasattr(config, "rope_theta"):
            if rope_type in ("linear", "dynamic"):
                # Use the standard transformers `rope_scaling` mechanism.
                # (Previous code hand-tuned rope_theta: "linear" multiplied
                # theta by the factor — NTK-style base adjustment, not linear
                # position interpolation — and "dynamic" used
                # theta*(2*factor - 1), which matches no published scheme.)
                config.rope_scaling = {"type": rope_type, "factor": rope_factor}
            elif rope_type == "yarn":
                # YaRN: NTK-by-parts interpolation with attention scaling.
                config.rope_scaling = {
                    "type": "yarn",
                    "factor": rope_factor,
                    "original_max_position_embeddings": original_context,
                    "attn_factor": 1.0,
                    "beta_fast": 32.0,
                    "beta_slow": 1.0,
                }

    # Load model on CPU in full precision (Space runs on cpu-basic hardware).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model.eval()

    result = {
        "model": model,
        "tokenizer": tokenizer,
        "original_context": original_context,
        "applied_context": new_context_length,
        "extension_method": extension_method,
    }
    model_cache[cache_key] = result
    return result
104
+
105
+
106
def generate(
    model_id: str,
    extension_method: str,
    new_context_length: int,
    rope_type: str,
    rope_factor: float,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate text with a (possibly context-extended) model.

    Args:
        model_id: Hugging Face model ID.
        extension_method: "none", "raw", or "rope".
        new_context_length: Target context length; values <= 0 fall back to 4096.
        rope_type: RoPE scaling type (used only when extension_method == "rope").
        rope_factor: RoPE scaling factor.
        prompt: Input text for generation.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; 0 selects greedy decoding.
        top_p: Nucleus-sampling threshold (ignored when greedy).

    Returns:
        The decoded generation (prompt included), or a human-readable
        "Error: ..." string — errors are returned instead of raised so they
        render in the Gradio output textbox.
    """
    # Validate inputs before doing any expensive work.
    if not model_id.strip():
        return "Error: Please enter a model ID"

    if not prompt.strip():
        return "Error: Please enter a prompt"

    # Fall back to a sane default context length.
    if new_context_length <= 0:
        new_context_length = 4096

    # Load the model, or fetch it from the cache.
    try:
        model_data = load_model_with_extension(
            model_id,
            extension_method,
            new_context_length,
            rope_type,
            rope_factor,
        )
    except Exception as e:
        return f"Error loading model: {str(e)}"

    model = model_data["model"]
    tokenizer = model_data["tokenizer"]

    # Tokenize the full prompt (no truncation — the point is long context).
    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=False,
            padding=False,
        )
    except Exception as e:
        return f"Error tokenizing input: {str(e)}"

    # Only pass sampling parameters when actually sampling. The previous
    # version passed temperature/top_p alongside do_sample=False, which makes
    # transformers emit warnings — and temperature == 0.0 raises a
    # ValueError in some transformers versions.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    # Generate and decode.
    try:
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If nothing new was produced, say so rather than echoing the prompt.
        if generated_text.strip() == prompt.strip():
            return "Model generated same text as input. Try adjusting parameters."

        return generated_text

    except Exception as e:
        return f"Error during generation: {str(e)}"
182
+
183
+
184
def update_rope_options(extension_method: str):
    """Show or hide the RoPE controls based on the chosen extension method.

    Both event wirings pass ``outputs=[rope_type, rope_factor]`` (two
    components), so this must return one update per output. The previous
    version returned a single ``gr.update``, which Gradio rejects when the
    event fires.

    Args:
        extension_method: "none", "raw", or "rope".

    Returns:
        A pair of ``gr.update`` objects for (rope_type, rope_factor).
    """
    visible = extension_method == "rope"
    return gr.update(visible=visible), gr.update(visible=visible)
192
+
193
+
194
# Build Gradio UI.
# Layout: model picker + extension method, conditionally-visible RoPE
# controls, generation parameters, then a single output textbox. Both the
# Generate button and Enter in the prompt box trigger generate().
with gr.Blocks(title="Context Window Extender") as demo:
    gr.Markdown("""
    # 🧠 Model Context Window Extender

    Load any causal language model from Hugging Face Hub and extend its context window.
    Supports both **Raw Extension** and **RoPE Scaling** methods.

    **Extension Methods:**
    - **None**: Use model's original context length
    - **Raw**: Simply increase max_position_embeddings (simple but may degrade quality)
    - **RoPE**: Apply RoPE scaling for better quality (supports linear, dynamic, yarn)
    """)

    with gr.Row():
        with gr.Column(scale=2):
            # Free-text model ID plus clickable example IDs below it.
            model_id = gr.Textbox(
                label="🤗 Model ID",
                placeholder="meta-llama/Llama-2-7b-hf, gpt2, EleutherAI/gpt-neo-1.3B",
                value="gpt2",
                info="Enter Hugging Face model ID"
            )
            gr.Examples(
                examples=[
                    ["gpt2"],
                    ["EleutherAI/gpt-neo-1.3B"],
                    ["microsoft/phi-2"],
                    ["facebook/opt-1.3b"],
                ],
                inputs=model_id
            )

        with gr.Column(scale=1):
            extension_method = gr.Radio(
                choices=["none", "raw", "rope"],
                value="none",
                label="Extension Method",
                info="Choose how to extend context window"
            )

    # RoPE options (shown when rope is selected; hidden by default and
    # toggled by the extension_method.change handler below).
    with gr.Row():
        with gr.Column(scale=1):
            rope_type = gr.Dropdown(
                choices=["linear", "dynamic", "yarn"],
                value="linear",
                label="RoPE Type",
                visible=False,
                info="linear: simple scaling, dynamic: better quality, yarn: best quality"
            )
        with gr.Column(scale=1):
            rope_factor = gr.Slider(
                minimum=1.0,
                maximum=8.0,
                step=0.5,
                value=2.0,
                label="RoPE Factor",
                visible=False,
                info="Multiply context by this factor"
            )

    with gr.Row():
        new_context_length = gr.Slider(
            minimum=512,
            maximum=32768,
            step=512,
            value=2048,
            label="Target Context Length",
            info="Desired context window size (tokens)"
        )

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="📝 Prompt",
                lines=6,
                placeholder="Enter your prompt here...",
                info="Input text for generation"
            )
        with gr.Column():
            # Generation parameters; temperature 0 means greedy decoding.
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=10,
                    maximum=1024,
                    step=10,
                    value=100,
                    label="Max New Tokens"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=0.7,
                    label="Temperature"
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    label="Top-p"
                )

    generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

    output = gr.Textbox(
        label="📄 Generated Output",
        lines=10
    )

    # Event handlers.
    # NOTE(review): update_rope_options is wired to TWO outputs here — it
    # must return one update per output component.
    extension_method.change(
        fn=update_rope_options,
        inputs=[extension_method],
        outputs=[rope_type, rope_factor]
    )

    generate_btn.click(
        fn=generate,
        inputs=[
            model_id,
            extension_method,
            new_context_length,
            rope_type,
            rope_factor,
            prompt,
            max_new_tokens,
            temperature,
            top_p
        ],
        outputs=[output]
    )

    # Also allow Enter key to generate (same inputs/outputs as the button).
    prompt.submit(
        fn=generate,
        inputs=[
            model_id,
            extension_method,
            new_context_length,
            rope_type,
            rope_factor,
            prompt,
            max_new_tokens,
            temperature,
            top_p
        ],
        outputs=[output]
    )

if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard bind for Hugging Face Spaces.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ transformers>=4.35.0
3
+ torch>=2.0.0
4
+ accelerate>=0.25.0