llaa33219 commited on
Commit
4b77ff5
·
verified ·
1 Parent(s): e9cb424

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +96 -274
  3. requirements.txt +1 -1
README.md CHANGED
@@ -3,8 +3,8 @@ title: Context Window Extender
3
  emoji: 🧠
4
  colorFrom: purple
5
  colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
  app_file: app.py
9
  suggested_hardware: cpu-basic
10
  pinned: false
 
3
  emoji: 🧠
4
  colorFrom: purple
5
  colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.40.0
8
  app_file: app.py
9
  suggested_hardware: cpu-basic
10
  pinned: false
app.py CHANGED
@@ -1,12 +1,20 @@
1
- import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
4
- import warnings
5
- import os
6
 
7
- warnings.filterwarnings("ignore")
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Global model cache to avoid reloading
10
  model_cache = {}
11
 
12
  def load_model_with_extension(
@@ -16,60 +24,34 @@ def load_model_with_extension(
16
  rope_type: str,
17
  rope_factor: float
18
  ):
19
- """
20
- Load model with optional context window extension.
21
-
22
- Args:
23
- model_id: Hugging Face model ID
24
- extension_method: "none", "raw", or "rope"
25
- new_context_length: Target context length
26
- rope_type: "linear", "dynamic", or "yarn"
27
- rope_factor: RoPE scaling factor
28
- """
29
-
30
- # Create cache key based on all parameters
31
  cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
32
 
33
  if cache_key in model_cache:
34
  return model_cache[cache_key]
35
 
36
- # Load tokenizer
37
- tokenizer = AutoTokenizer.from_pretrained(
38
- model_id,
39
- trust_remote_code=True
40
- )
41
 
42
  if tokenizer.pad_token is None:
43
  tokenizer.pad_token = tokenizer.eos_token
44
 
45
- # Load config and modify
46
  config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
47
 
48
  original_context = getattr(config, "max_position_embeddings", 4096)
49
 
50
- # Apply extension based on method
51
  if extension_method == "raw":
52
- # Raw extension: just increase max_position_embeddings
53
  config.max_position_embeddings = new_context_length
54
 
55
  elif extension_method == "rope":
56
- # RoPE scaling extension
57
  config.max_position_embeddings = new_context_length
58
 
59
- # Set RoPE scaling if model supports it
60
  if hasattr(config, "rope_theta"):
61
- # Get original rope_theta
62
  original_theta = getattr(config, "rope_theta", 10000.0)
63
 
64
- # Apply scaling based on type
65
  if rope_type == "linear":
66
- # Linear scaling - adjust theta by factor
67
  config.rope_theta = original_theta * rope_factor
68
  elif rope_type == "dynamic":
69
- # Dynamic scaling - use higher base frequency
70
  config.rope_theta = original_theta * (rope_factor - 1) + original_theta * rope_factor
71
  elif rope_type == "yarn":
72
- # YaRN - more sophisticated scaling
73
  config.rope_scaling = {
74
  "type": "yarn",
75
  "factor": rope_factor,
@@ -80,7 +62,6 @@ def load_model_with_extension(
80
  }
81
  config.rope_theta = original_theta
82
 
83
- # Load model on CPU
84
  model = AutoModelForCausalLM.from_pretrained(
85
  model_id,
86
  config=config,
@@ -96,256 +77,97 @@ def load_model_with_extension(
96
  "tokenizer": tokenizer,
97
  "original_context": original_context,
98
  "applied_context": new_context_length,
99
- "extension_method": extension_method
100
  }
101
 
102
  model_cache[cache_key] = result
103
  return result
104
 
105
 
106
- def generate(
107
- model_id: str,
108
- extension_method: str,
109
- new_context_length: int,
110
- rope_type: str,
111
- rope_factor: float,
112
- prompt: str,
113
- max_new_tokens: int,
114
- temperature: float,
115
- top_p: float,
116
- ):
117
- """
118
- Generate text with the loaded model.
119
- """
120
-
121
- # Validate inputs
122
- if not model_id.strip():
123
- return "Error: Please enter a model ID"
124
-
125
- if not prompt.strip():
126
- return "Error: Please enter a prompt"
127
-
128
- # Set default context length if not provided
129
- if new_context_length <= 0:
130
- new_context_length = 4096
131
-
132
- # Load or get model from cache
133
- try:
134
- model_data = load_model_with_extension(
135
- model_id,
136
- extension_method,
137
- new_context_length,
138
- rope_type,
139
- rope_factor
140
- )
141
- except Exception as e:
142
- return f"Error loading model: {str(e)}"
143
-
144
- model = model_data["model"]
145
- tokenizer = model_data["tokenizer"]
146
-
147
- # Tokenize input
148
- try:
149
- inputs = tokenizer(
150
- prompt,
151
- return_tensors="pt",
152
- truncation=False,
153
- padding=False
154
  )
155
- except Exception as e:
156
- return f"Error tokenizing input: {str(e)}"
157
-
158
- # Generate
159
- try:
160
- with torch.no_grad():
161
- outputs = model.generate(
162
- **inputs,
163
- max_new_tokens=max_new_tokens,
164
- temperature=temperature,
165
- top_p=top_p,
166
- do_sample=temperature > 0,
167
- pad_token_id=tokenizer.pad_token_id,
168
- eos_token_id=tokenizer.eos_token_id,
169
- )
170
-
171
- # Decode output
172
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
173
-
174
- # If generation is same as input, return a message
175
- if generated_text.strip() == prompt.strip():
176
- return "Model generated same text as input. Try adjusting parameters."
177
-
178
- return generated_text
179
-
180
- except Exception as e:
181
- return f"Error during generation: {str(e)}"
182
 
 
183
 
184
- def update_rope_options(extension_method: str):
185
- """
186
- Update visibility of RoPE options based on extension method.
187
- """
188
- if extension_method == "rope":
189
- return gr.update(visible=True)
190
- else:
191
- return gr.update(visible=False)
192
 
 
 
193
 
194
- # Build Gradio UI
195
- with gr.Blocks(title="Context Window Extender") as demo:
196
- gr.Markdown("""
197
- # 🧠 Model Context Window Extender
198
-
199
- Load any causal language model from Hugging Face Hub and extend its context window.
200
- Supports both **Raw Extension** and **RoPE Scaling** methods.
201
-
202
- **Extension Methods:**
203
- - **None**: Use model's original context length
204
- - **Raw**: Simply increase max_position_embeddings (simple but may degrade quality)
205
- - **RoPE**: Apply RoPE scaling for better quality (supports linear, dynamic, yarn)
206
- """)
207
-
208
- with gr.Row():
209
- with gr.Column(scale=2):
210
- model_id = gr.Textbox(
211
- label="🤗 Model ID",
212
- placeholder="meta-llama/Llama-2-7b-hf, gpt2, EleutherAI/gpt-neo-1.3B",
213
- value="gpt2",
214
- info="Enter Hugging Face model ID"
215
- )
216
- gr.Examples(
217
- examples=[
218
- ["gpt2"],
219
- ["EleutherAI/gpt-neo-1.3B"],
220
- ["microsoft/phi-2"],
221
- ["facebook/opt-1.3b"],
222
- ],
223
- inputs=model_id
224
- )
225
-
226
- with gr.Column(scale=1):
227
- extension_method = gr.Radio(
228
- choices=["none", "raw", "rope"],
229
- value="none",
230
- label="Extension Method",
231
- info="Choose how to extend context window"
232
- )
233
-
234
- # RoPE options (shown when rope is selected)
235
- with gr.Row():
236
- with gr.Column(scale=1):
237
- rope_type = gr.Dropdown(
238
- choices=["linear", "dynamic", "yarn"],
239
- value="linear",
240
- label="RoPE Type",
241
- visible=False,
242
- info="linear: simple scaling, dynamic: better quality, yarn: best quality"
243
- )
244
- with gr.Column(scale=1):
245
- rope_factor = gr.Slider(
246
- minimum=1.0,
247
- maximum=8.0,
248
- step=0.5,
249
- value=2.0,
250
- label="RoPE Factor",
251
- visible=False,
252
- info="Multiply context by this factor"
253
- )
254
-
255
- with gr.Row():
256
- new_context_length = gr.Slider(
257
- minimum=512,
258
- maximum=32768,
259
- step=512,
260
- value=2048,
261
- label="Target Context Length",
262
- info="Desired context window size (tokens)"
263
- )
264
-
265
- with gr.Row():
266
- with gr.Column():
267
- prompt = gr.Textbox(
268
- label="📝 Prompt",
269
- lines=6,
270
- placeholder="Enter your prompt here...",
271
- info="Input text for generation"
272
- )
273
- with gr.Column():
274
- with gr.Row():
275
- max_new_tokens = gr.Slider(
276
- minimum=10,
277
- maximum=1024,
278
- step=10,
279
- value=100,
280
- label="Max New Tokens"
281
- )
282
- with gr.Row():
283
- temperature = gr.Slider(
284
- minimum=0.0,
285
- maximum=2.0,
286
- step=0.1,
287
- value=0.7,
288
- label="Temperature"
289
- )
290
- with gr.Row():
291
- top_p = gr.Slider(
292
- minimum=0.0,
293
- maximum=1.0,
294
- step=0.05,
295
- value=0.9,
296
- label="Top-p"
297
- )
298
-
299
- generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
300
-
301
- output = gr.Textbox(
302
- label="📄 Generated Output",
303
- lines=10
304
- )
305
-
306
- # Event handlers
307
- extension_method.change(
308
- fn=update_rope_options,
309
- inputs=[extension_method],
310
- outputs=[rope_type, rope_factor]
311
- )
312
-
313
- generate_btn.click(
314
- fn=generate,
315
- inputs=[
316
- model_id,
317
- extension_method,
318
- new_context_length,
319
- rope_type,
320
- rope_factor,
321
- prompt,
322
- max_new_tokens,
323
- temperature,
324
- top_p
325
- ],
326
- outputs=[output]
327
- )
328
-
329
- # Also allow Enter key to generate
330
- prompt.submit(
331
- fn=generate,
332
- inputs=[
333
- model_id,
334
- extension_method,
335
- new_context_length,
336
- rope_type,
337
- rope_factor,
338
- prompt,
339
- max_new_tokens,
340
- temperature,
341
- top_p
342
- ],
343
- outputs=[output]
344
- )
345
 
346
- if __name__ == "__main__":
347
- demo.launch(
348
- server_name="0.0.0.0",
349
- server_port=7860,
350
- share=False
351
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit front-end for loading Hugging Face causal LMs with an
# optionally extended context window (raw or RoPE-scaled).
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# set_page_config must be the first Streamlit call executed in the script.
st.set_page_config(page_title="Context Window Extender", page_icon="🧠")

st.title("🧠 Model Context Window Extender")

st.markdown("""
Load any causal language model from Hugging Face Hub and extend its context window.

**Extension Methods:**
- **None**: Use model's original context length
- **Raw**: Simply increase max_position_embeddings (simple but may degrade quality)
- **RoPE**: Apply RoPE scaling for better quality (supports linear, dynamic, yarn)
""")
17
 
 
18
  model_cache = {}
19
 
20
  def load_model_with_extension(
 
24
  rope_type: str,
25
  rope_factor: float
26
  ):
 
 
 
 
 
 
 
 
 
 
 
 
27
  cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
28
 
29
  if cache_key in model_cache:
30
  return model_cache[cache_key]
31
 
32
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
 
 
33
 
34
  if tokenizer.pad_token is None:
35
  tokenizer.pad_token = tokenizer.eos_token
36
 
 
37
  config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
38
 
39
  original_context = getattr(config, "max_position_embeddings", 4096)
40
 
 
41
  if extension_method == "raw":
 
42
  config.max_position_embeddings = new_context_length
43
 
44
  elif extension_method == "rope":
 
45
  config.max_position_embeddings = new_context_length
46
 
 
47
  if hasattr(config, "rope_theta"):
 
48
  original_theta = getattr(config, "rope_theta", 10000.0)
49
 
 
50
  if rope_type == "linear":
 
51
  config.rope_theta = original_theta * rope_factor
52
  elif rope_type == "dynamic":
 
53
  config.rope_theta = original_theta * (rope_factor - 1) + original_theta * rope_factor
54
  elif rope_type == "yarn":
 
55
  config.rope_scaling = {
56
  "type": "yarn",
57
  "factor": rope_factor,
 
62
  }
63
  config.rope_theta = original_theta
64
 
 
65
  model = AutoModelForCausalLM.from_pretrained(
66
  model_id,
67
  config=config,
 
77
  "tokenizer": tokenizer,
78
  "original_context": original_context,
79
  "applied_context": new_context_length,
 
80
  }
81
 
82
  model_cache[cache_key] = result
83
  return result
84
 
85
 
86
# ---- Input widgets --------------------------------------------------------
# Wide column for the model picker, narrow column for the extension method.
col1, col2 = st.columns([2, 1])

with col1:
    model_id = st.text_input(
        "🤗 Model ID",
        value="gpt2",
        help="Enter Hugging Face model ID"
    )
    st.caption("Examples: gpt2, EleutherAI/gpt-neo-1.3B, microsoft/phi-2")

with col2:
    extension_method = st.radio(
        "Extension Method",
        ["none", "raw", "rope"],
        index=0,
        help="Choose how to extend context window"
    )

# RoPE-specific controls are rendered only when the RoPE method is selected;
# otherwise bind neutral defaults so downstream code can always read the names.
if extension_method == "rope":
    col_rope1, col_rope2 = st.columns(2)
    with col_rope1:
        rope_type = st.selectbox(
            "RoPE Type",
            ["linear", "dynamic", "yarn"],
            help="linear: simple scaling, dynamic: better quality, yarn: best quality"
        )
    with col_rope2:
        rope_factor = st.slider("RoPE Factor", 1.0, 8.0, 2.0, 0.5, help="Multiply context by this factor")
else:
    rope_type = "linear"
    rope_factor = 1.0

new_context_length = st.slider("Target Context Length", 512, 32768, 2048, 512, help="Desired context window size (tokens)")

# Prompt on the left, generation hyper-parameters on the right.
col_p1, col_p2 = st.columns(2)

with col_p1:
    prompt = st.text_area("📝 Prompt", height=150, placeholder="Enter your prompt here...")

with col_p2:
    max_new_tokens = st.slider("Max New Tokens", 10, 1024, 100, 10)
    temperature = st.slider("Temperature", 0.0, 2.0, 0.7, 0.1)
    top_p = st.slider("Top-p", 0.0, 1.0, 0.9, 0.05)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
# ---- Generation -----------------------------------------------------------
# Runs on button click: validate input, load (or fetch cached) model, then
# generate and render the output. All failures surface as st.error messages.
if st.button("🚀 Generate", type="primary"):
    if not model_id.strip():
        st.error("Please enter a model ID")
    elif not prompt.strip():
        st.error("Please enter a prompt")
    else:
        with st.spinner("Loading model..."):
            try:
                model_data = load_model_with_extension(
                    model_id,
                    extension_method,
                    new_context_length,
                    rope_type,
                    rope_factor
                )

                model = model_data["model"]
                tokenizer = model_data["tokenizer"]

                st.success(f"Model loaded! Original context: {model_data['original_context']}, Applied: {model_data['applied_context']}")

                with st.spinner("Generating..."):
                    inputs = tokenizer(prompt, return_tensors="pt", truncation=False, padding=False)

                    # truncation=False means an over-long prompt would fail
                    # deep inside the model with an opaque indexing error —
                    # warn the user up front instead.
                    prompt_tokens = inputs["input_ids"].shape[1]
                    if prompt_tokens > model_data["applied_context"]:
                        st.warning(
                            f"Prompt is {prompt_tokens} tokens but the context window "
                            f"is {model_data['applied_context']} tokens; generation may fail."
                        )

                    # Only pass sampling knobs when sampling is enabled:
                    # transformers warns — and some versions raise — when
                    # temperature/top_p are supplied with do_sample=False
                    # or temperature is 0.
                    gen_kwargs = {
                        "max_new_tokens": max_new_tokens,
                        "pad_token_id": tokenizer.pad_token_id,
                        "eos_token_id": tokenizer.eos_token_id,
                    }
                    if temperature > 0:
                        gen_kwargs["do_sample"] = True
                        gen_kwargs["temperature"] = temperature
                        gen_kwargs["top_p"] = top_p
                    else:
                        gen_kwargs["do_sample"] = False

                    with torch.no_grad():
                        outputs = model.generate(**inputs, **gen_kwargs)

                    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

                    if generated_text.strip() == prompt.strip():
                        st.warning("Model generated same text as input. Try adjusting parameters.")
                    else:
                        st.text_area("📄 Generated Output", value=generated_text, height=250)

            except Exception as e:
                st.error(f"Error: {str(e)}")
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- gradio>=5.0.0
2
  transformers>=4.35.0
3
  accelerate>=0.25.0
 
1
+ streamlit>=1.40.0
2
  transformers>=4.35.0
3
  accelerate>=0.25.0