llaa33219 committed on
Commit 048e09e · verified · 1 Parent(s): dd51a85

Upload 3 files

Files changed (2)
  1. README.md +19 -3
  2. app.py +293 -16
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Context Window Extender
+title: Context Window Extender - Chat Mode
 emoji: 🧠
 colorFrom: purple
 colorTo: indigo
@@ -9,6 +9,22 @@ app_file: app.py
 pinned: false
 ---
 
-# Context Window Extender
+# Context Window Extender - Chat Mode
 
-Load any causal language model from Hugging Face Hub and extend its context window.
+Load any causal language model from Hugging Face Hub and extend its context window dynamically.
+
+## Features
+
+- **Recent Models**: Default model is `Qwen/Qwen3-30B-A3B-Thinking-2507` (256K context, extendable to 1M)
+- **Conversational UI**: Chat-style interface instead of form-based
+- **Dynamic Context Multiplier**: Expand context by 2x, 5x, 10x, 20x, 50x, or 100x
+- **Streaming Responses**: Real-time streaming of model outputs
+- **RoPE Extension**: Support for linear, dynamic, and YaRN RoPE scaling
+
+## Available Models
+
+- Qwen/Qwen3-30B-A3B-Thinking-2507 (default)
+- Qwen/Qwen2.5-1.5B-Instruct
+- Qwen/Qwen2.5-3B-Instruct
+- microsoft/phi-4-mini-instruct
+- deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
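
For orientation, the RoPE-based extension described in the README can be sketched with plain `transformers` config overrides. The snippet below is a minimal sketch, not the Space's actual `load_model_with_extension` implementation: the model ID and factor are illustrative, and the exact `rope_scaling` keys (`"type"` vs. `"rope_type"`) vary across model families and `transformers` versions.

```python
# Minimal sketch: extend a causal LM's context window via RoPE scaling.
# Assumptions: a LLaMA/Qwen-style config that supports `rope_scaling`; the key
# name may be "type" (older transformers) or "rope_type" (newer releases).
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-1.5B-Instruct"   # one of the models listed above
factor = 2                                 # the "2x" multiplier in the UI

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
base_ctx = getattr(config, "max_position_embeddings", 32768)

# Scale RoPE and advertise the longer window in the config.
config.rope_scaling = {"type": "linear", "factor": float(factor)}
config.max_position_embeddings = base_ctx * factor

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
    device_map="auto",
)
print(f"Base context: {base_ctx}, extended context: {config.max_position_embeddings}")
```

Linear scaling trades some short-context quality for a longer window; the app's "dynamic" and "yarn" options use the same `rope_scaling` dict with a different scaling type, where the model supports them.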
app.py CHANGED
@@ -5,6 +5,31 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
 model_cache = {}
 
+def get_model_info(model_id):
+    """Get model's current context length from config."""
+    try:
+        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+        ctx = getattr(config, "max_position_embeddings", None)
+        if ctx is None:
+            return "Unknown"
+        return str(ctx)
+    except:
+        return "Unknown"
+
+
+def calculate_context_length(base_context, multiplier):
+    """Calculate new context length based on multiplier."""
+    multipliers = {
+        "2x": 2,
+        "5x": 5,
+        "10x": 10,
+        "20x": 20,
+        "50x": 50,
+        "100x": 100
+    }
+    return base_context * multipliers.get(multiplier, 2)
+
+
 def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor, device="cuda"):
     cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}_{device}"
 
@@ -76,33 +101,285 @@ def generate(model_id, extension_method, new_context_length, rope_type, rope_fac
         return f"Error during generation: {str(e)}"
 
 
-with gr.Blocks(title="Context Window Extender") as demo:
-    gr.Markdown("# Context Window Extender\n\nLoad any model from Hugging Face Hub and extend its context window.")
+# Chat-based generation function for the conversational UI
+@spaces.GPU(duration=120)
+def chat_generate(message, history, model_id, extension_method, context_multiplier, rope_type, rope_factor, max_new_tokens, temperature, top_p):
+    """Generate a response in conversational format with streaming."""
+    if not model_id.strip():
+        yield "Error: Please select a model ID"
+        return
+
+    if not message.strip():
+        yield "Error: Please enter a message"
+        return
+
+    # Get base context length and calculate the new context
+    base_context = 32768  # Default base for Qwen3
+    new_context_length = calculate_context_length(base_context, context_multiplier)
+
+    # Build the full prompt from the chat history, then append the new message
+    prompt = ""
+    for user_msg, assistant_msg in history:
+        prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
+    prompt += f"User: {message}\nAssistant:"
+
+    try:
+        model_data = load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor)
+    except Exception as e:
+        yield f"Error loading model: {str(e)}"
+        return
+
+    model = model_data["model"]
+    tokenizer = model_data["tokenizer"]
+
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        # Stream generation
+        full_response = ""
+        from transformers import TextIteratorStreamer
+        from threading import Thread
+
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        generation_kwargs = {
+            **inputs,  # input_ids and attention_mask from the tokenizer
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "do_sample": temperature > 0,
+            "pad_token_id": tokenizer.pad_token_id,
+            "eos_token_id": tokenizer.eos_token_id,
+            "streamer": streamer
+        }
+
+        # Run generation in a background thread
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # Yield the streamed response as it grows
+        for text in streamer:
+            full_response += text
+            yield full_response
+
+        thread.join()
+
+        if not full_response.strip():
+            yield "Model generated same text as input. Try adjusting parameters."
+            return
+
+    except Exception as e:
+        yield f"Error during generation: {str(e)}"
+
+
+# Default model - recent Qwen3 series
+DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507"
+
+with gr.Blocks(title="Context Window Extender - Chat") as demo:
+    gr.Markdown("""
+    # 🧠 Context Window Extender - Chat Mode
+
+    Load any model from Hugging Face Hub and extend its context window dynamically.
+    Select a multiplier to expand context by 2x to 100x!
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            # Model selection
+            model_id = gr.Textbox(
+                value=DEFAULT_MODEL,
+                label="🤗 Model ID",
+                placeholder="Enter Hugging Face model ID..."
+            )
+            gr.Examples([
+                ["Qwen/Qwen3-30B-A3B-Thinking-2507"],
+                ["Qwen/Qwen2.5-1.5B-Instruct"],
+                ["Qwen/Qwen2.5-3B-Instruct"],
+                ["microsoft/phi-4-mini-instruct"],
+                ["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"],
+            ], inputs=model_id)
+
+        with gr.Column(scale=1):
+            # Context multiplier selector
+            context_multiplier = gr.Dropdown(
+                choices=["2x", "5x", "10x", "20x", "50x", "100x"],
+                value="2x",
+                label="📈 Context Multiplier",
+                info="Expand context window by this factor"
+            )
 
  with gr.Row():
211
  with gr.Column():
212
+ extension_method = gr.Radio(
213
+ ["none", "raw", "rope"],
214
+ value="rope",
215
+ label="Extension Method"
216
+ )
217
  with gr.Column():
218
+ rope_type = gr.Dropdown(
219
+ ["linear", "dynamic", "yarn"],
220
+ value="linear",
221
+ label="RoPE Type",
222
+ visible=True
223
+ )
224
+ rope_factor = gr.Slider(
225
+ minimum=1.0,
226
+ maximum=8.0,
227
+ value=2.0,
228
+ step=0.5,
229
+ label="RoPE Factor",
230
+ visible=True
231
+ )
232
 
233
+ # Show context info
234
  with gr.Row():
235
+ base_ctx = gr.Number(value=32768, label="Base Context", interactive=False)
236
+ extended_ctx = gr.Number(value=65536, label="Extended Context", interactive=False)
237
+
238
+ # Update extended context when multiplier changes
239
+ def update_extended_context(multiplier, base=32768):
240
+ return calculate_context_length(base, multiplier)
241
+
242
+ context_multiplier.change(
243
+ fn=update_extended_context,
244
+ inputs=[context_multiplier],
245
+ outputs=extended_ctx
246
+ )
247
+
248
+ model_id.change(
249
+ fn=get_model_info,
250
+ inputs=model_id,
251
+ outputs=base_ctx
252
+ )
253
+
254
+ with gr.Row():
255
+ max_new_tokens = gr.Slider(minimum=10, maximum=4096, value=256, step=10, label="Max New Tokens")
256
  temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
257
  top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p")
258
 
259
+ # Hide/show RoPE options based on extension method
260
+ def update_rope_visibility(method):
261
+ return gr.update(visible=method == "rope")
262
 
263
+ extension_method.change(
264
+ update_rope_visibility,
265
+ extension_method,
266
+ [rope_type, rope_factor]
267
+ )
268
+
269
+ gr.Markdown("---")
270
+ gr.Markdown("### 💬 Chat with the Model")
271
 
272
+ # Conversational chat interface
273
+ def respond(
274
+ message: str,
275
+ history: list,
276
+ model_id: str,
277
+ extension_method: str,
278
+ context_multiplier: str,
279
+ rope_type: str,
280
+ rope_factor: float,
281
+ max_new_tokens: int,
282
+ temperature: float,
283
+ top_p: float,
284
+ ):
285
+ """Handle chat response with streaming."""
286
+ if not message.strip():
287
+ yield history + [{"role": "assistant", "content": "Please enter a message."}]
288
+ return
289
+
290
+ # Add user message to history
291
+ history.append({"role": "user", "content": message})
292
+ yield history + [{"role": "assistant", "content": "..."}]
293
+
294
+ # Generate response
295
+ try:
296
+ base_context = 32768
297
+ new_context_length = calculate_context_length(base_context, context_multiplier)
298
+
+            # Build the prompt from prior turns; history entries may be
+            # {"role", "content"} dicts (messages format) or (user, assistant) pairs
+            prompt = ""
+            for turn in history[:-1]:
+                if isinstance(turn, dict):
+                    role = "User" if turn.get("role") == "user" else "Assistant"
+                    prompt += f"{role}: {turn.get('content', '')}\n"
+                else:
+                    user_content, assistant_content = turn
+                    prompt += f"User: {user_content}\nAssistant: {assistant_content}\n"
+
+            prompt += f"User: {message}\nAssistant:"
+
+            model_data = load_model_with_extension(
+                model_id,
+                extension_method,
+                new_context_length,
+                rope_type,
+                rope_factor
+            )
+            model = model_data["model"]
+            tokenizer = model_data["tokenizer"]
+
+            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+            # Stream generation
+            from transformers import TextIteratorStreamer
+            from threading import Thread
+
+            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+            generation_kwargs = {
+                **inputs,  # input_ids and attention_mask from the tokenizer
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": temperature > 0,
+                "pad_token_id": tokenizer.pad_token_id,
+                "eos_token_id": tokenizer.eos_token_id,
+                "streamer": streamer
+            }
+
+            thread = Thread(target=model.generate, kwargs=generation_kwargs)
+            thread.start()
+
+            full_response = ""
+            for text in streamer:
+                full_response += text
+                # Update the last message (assistant response)
+                current_history = history + [{"role": "assistant", "content": full_response}]
+                yield current_history
+
+            thread.join()
+
+            if not full_response.strip():
+                full_response = "Model generated same text as input. Try adjusting parameters."
+
+        except Exception as e:
+            full_response = f"Error: {str(e)}"
+        yield history + [{"role": "assistant", "content": full_response}]
+        return
+
+    # ChatInterface
+    chat_interface = gr.ChatInterface(
+        fn=respond,
+        additional_inputs=[
+            model_id,
+            extension_method,
+            context_multiplier,
+            rope_type,
+            rope_factor,
+            max_new_tokens,
+            temperature,
+            top_p
+        ],
+        title="",
+        description=None,
+        examples=[
+            {"text": "Hello, how are you?"},
+            {"text": "Explain quantum computing in simple terms."},
+            {"text": "Write a short poem about artificial intelligence."}
+        ],
+        autofocus=True
+    )
 
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=7860)