File size: 14,516 Bytes
91aa250
 
 
 
 
 
 
 
 
 
 
72ad230
dd51a85
60026f3
 
 
 
 
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adbe710
 
 
 
 
60026f3
 
 
 
4b77ff5
60026f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ad230
60026f3
 
dd51a85
 
 
 
 
 
adbe710
dd51a85
 
 
60026f3
 
72ad230
60026f3
 
 
 
2e2d23d
72ad230
4b77ff5
72ad230
 
 
 
 
 
 
 
 
 
 
 
 
dd51a85
72ad230
 
 
 
 
 
 
 
 
 
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e2d23d
72ad230
 
048e09e
 
 
 
 
72ad230
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
72ad230
2e2d23d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
048e09e
72ad230
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fc13ab
72ad230
 
 
7fc13ab
 
 
 
 
 
 
 
 
 
 
 
048e09e
 
9e0d513
72ad230
048e09e
 
 
 
 
 
 
 
72ad230
048e09e
2e2d23d
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
360a4ff
048e09e
 
 
360a4ff
 
048e09e
 
 
 
 
 
 
 
360a4ff
 
 
 
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
adbe710
 
 
 
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360a4ff
 
048e09e
 
 
 
 
360a4ff
048e09e
 
 
360a4ff
048e09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72ad230
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# Runtime upgrade to fix huggingface_hub compatibility
import subprocess
import sys

def upgrade_package(package):
    """Quietly upgrade *package* using the running interpreter's pip."""
    pip_cmd = [sys.executable, "-m", "pip", "install", "--upgrade", package, "--quiet"]
    subprocess.check_call(pip_cmd)

# Upgrade packages before importing gradio
upgrade_package("gradio>=5.0.0")
upgrade_package("huggingface-hub")

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Module-level memo of loaded models, keyed by the settings string built in
# load_model_with_extension(); repeated loads with identical settings are free.
model_cache = {}

def get_model_info(model_id):
    """Return the model's configured context length as a string.

    Reads ``max_position_embeddings`` from the Hub config. Returns
    ``"Unknown"`` when the config cannot be fetched or lacks that field.
    """
    try:
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        ctx = getattr(config, "max_position_embeddings", None)
        return "Unknown" if ctx is None else str(ctx)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any fetch/parse failure maps to "Unknown".
        return "Unknown"


def calculate_context_length(base_context, multiplier):
    """Scale *base_context* by the multiplier label ("2x" … "100x").

    Unrecognized labels fall back to a factor of 2.
    """
    factor_by_label = {
        "2x": 2,
        "5x": 5,
        "10x": 10,
        "20x": 20,
        "50x": 50,
        "100x": 100,
    }
    factor = factor_by_label.get(multiplier, 2)
    return base_context * factor


def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor):
    """Load tokenizer + causal LM with a config patched for context extension.

    Results are memoized in module-level ``model_cache`` keyed by all five
    arguments. Returns a dict with keys "model", "tokenizer",
    "original_context", "applied_context".
    """
    device = "cpu"  # Use CPU, ZeroGPU will handle GPU allocation
    
    cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
    
    if cache_key in model_cache:
        return model_cache[cache_key]
    
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some models ship without a pad token; reuse EOS so generation can pad.
        tokenizer.pad_token = tokenizer.eos_token
    
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    original_context = getattr(config, "max_position_embeddings", 4096)
    
    if extension_method == "raw":
        # Bump the positional limit only; no RoPE rescaling at all.
        config.max_position_embeddings = new_context_length
    elif extension_method == "rope":
        config.max_position_embeddings = new_context_length
        if hasattr(config, "rope_theta"):
            original_theta = getattr(config, "rope_theta", 10000.0)
            if rope_type == "linear":
                # NOTE(review): this scales rope_theta instead of setting
                # config.rope_scaling = {"type": "linear", ...}; confirm this
                # is the intended meaning of "linear" extension.
                config.rope_theta = original_theta * rope_factor
            elif rope_type == "dynamic":
                # NOTE(review): algebraically equals original_theta * (2*rope_factor - 1);
                # the derivation of this formula is unclear — verify intent.
                config.rope_theta = original_theta * (rope_factor - 1) + original_theta * rope_factor
            elif rope_type == "yarn":
                config.rope_scaling = {"type": "yarn", "factor": rope_factor, "original_max_position_embeddings": original_context}
                config.rope_theta = original_theta
    
    # device is always "cpu" above, so this always selects float32;
    # the float16 branch is currently unreachable.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    
    model = AutoModelForCausalLM.from_pretrained(
        model_id, 
        config=config, 
        torch_dtype=torch_dtype, 
        device_map="cpu",  # Load on CPU, ZeroGPU handles GPU
        low_cpu_mem_usage=True, 
        trust_remote_code=True
    )
    model.eval()
    
    result = {"model": model, "tokenizer": tokenizer, "original_context": original_context, "applied_context": new_context_length}
    model_cache[cache_key] = result
    return result


@spaces.GPU(duration=300)
def generate(model_id, extension_method, new_context_length, rope_type, rope_factor, prompt, max_new_tokens, temperature, top_p):
    """Run one-shot generation and return the decoded text or an error string.

    Validates inputs, loads (or reuses) the configured model, then decodes
    the model's output; all failures are reported as readable strings.
    """
    # Guard clauses: reject blank model id / prompt up front.
    if not model_id.strip():
        return "Error: Please enter a model ID"
    if not prompt.strip():
        return "Error: Please enter a prompt"

    try:
        bundle = load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor)
    except Exception as e:
        return f"Error loading model: {str(e)}"

    model = bundle["model"]
    tokenizer = bundle["tokenizer"]

    try:
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Echoing the prompt verbatim usually means sampling went nowhere.
        if decoded.strip() == prompt.strip():
            return "Model generated same text as input. Try adjusting parameters."
        return decoded
    except Exception as e:
        return f"Error during generation: {str(e)}"


# Default model - recent Qwen3 series (also the first entry in the Examples list below)
DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507"

with gr.Blocks(title="Context Window Extender - Chat") as demo:
    gr.Markdown("""
    # 🧠 Context Window Extender - Chat Mode
    
    Load any model from Hugging Face Hub and extend its context window dynamically.
    Select a multiplier to expand context by 2x to 100x!
    """)
    # The model-selection UI (model_id textbox, examples, download/load
    # buttons) is created once, further below; a second identical textbox
    # here would be dead UI since all event wiring binds to the later one.
            
    # Define these first so they can be used in buttons
    with gr.Row():
        with gr.Column():
            extension_method = gr.Radio(
                ["none", "raw", "rope"], 
                value="rope", 
                label="Extension Method"
            )
        with gr.Column():
            # RoPE-specific controls; their visibility is toggled by the
            # extension_method.change handler wired further down.
            rope_type = gr.Dropdown(
                ["linear", "dynamic", "yarn"], 
                value="linear", 
                label="RoPE Type",
                visible=True
            )
            rope_factor = gr.Slider(
                minimum=1.0, 
                maximum=8.0, 
                value=2.0, 
                step=0.5, 
                label="RoPE Factor",
                visible=True
            )
    
    # Define context_multiplier BEFORE it's used in buttons
    context_multiplier = gr.Dropdown(
        choices=["2x", "5x", "10x", "20x", "50x", "100x"],
        value="2x",
        label="πŸ“ˆ Context Multiplier",
        info="Expand context window by this factor"
    )
    
    with gr.Row():
        with gr.Column(scale=2):
            # Model selection
            model_id = gr.Textbox(
                value=DEFAULT_MODEL, 
                label="πŸ€— Model ID",
                placeholder="Enter Hugging Face model ID..."
            )
            gr.Examples([
                ["Qwen/Qwen3-30B-A3B-Thinking-2507"],
                ["Qwen/Qwen2.5-1.5B-Instruct"],
                ["Qwen/Qwen2.5-3B-Instruct"],
                ["microsoft/phi-4-mini-instruct"],
                ["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"],
            ], inputs=model_id)
            
            with gr.Row():
                download_btn = gr.Button("πŸ“₯ Download Model", variant="secondary")
                load_btn = gr.Button("πŸš€ Load Model", variant="primary")
            
            model_status = gr.Textbox(label="Model Status", interactive=False)
            
            # Download model function (runs outside ZeroGPU)
            def download_model(mid):
                """Fetch tokenizer + config for *mid*; return a status string."""
                if not mid.strip():
                    return "Error: Please enter a model ID"
                try:
                    # Download tokenizer and config first
                    from transformers import AutoTokenizer, AutoConfig
                    tokenizer = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                    config = AutoConfig.from_pretrained(mid, trust_remote_code=True)
                    return f"βœ… Model downloaded: {mid}"
                except Exception as e:
                    return f"❌ Download failed: {str(e)}"
            
            download_btn.click(download_model, inputs=[model_id], outputs=[model_status])
            
            # Load model function (runs inside ZeroGPU)
            @spaces.GPU(duration=300)
            def load_model(mid, ext_method, ctx_mult, rt, rf):
                """Load the model with the chosen extension settings; return a status string."""
                if not mid.strip():
                    return "Error: Please enter a model ID"
                try:
                    base_ctx = 32768  # NOTE(review): hard-coded base; the model's real context may differ
                    new_ctx = calculate_context_length(base_ctx, ctx_mult)
                    model_data = load_model_with_extension(mid, ext_method, new_ctx, rt, rf)
                    return f"βœ… Model loaded: {mid} (context: {new_ctx})"
                except Exception as e:
                    return f"❌ Load failed: {str(e)}"
            
            load_btn.click(load_model, inputs=[model_id, extension_method, context_multiplier, rope_type, rope_factor], outputs=[model_status])
    
    # Show context info
    with gr.Row():
        base_ctx = gr.Number(value=32768, label="Base Context", interactive=False)
        extended_ctx = gr.Number(value=65536, label="Extended Context", interactive=False)
    
    # Update extended context when multiplier changes
    def update_extended_context(multiplier, base=32768):
        # NOTE(review): always uses the hard-coded default base (32768), not the
        # value currently displayed in base_ctx — confirm that is intended.
        return calculate_context_length(base, multiplier)
    
    context_multiplier.change(
        fn=update_extended_context,
        inputs=[context_multiplier],
        outputs=extended_ctx
    )
    
    # NOTE(review): get_model_info returns a string ("Unknown" or str(ctx))
    # while base_ctx is a gr.Number — verify Gradio coerces this as expected.
    model_id.change(
        fn=get_model_info,
        inputs=model_id,
        outputs=base_ctx
    )
    
    with gr.Row():
        max_new_tokens = gr.Slider(minimum=10, maximum=32768, value=256, step=10, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p")
    
    # Update max_new_tokens slider max based on context multiplier
    def update_max_tokens(multiplier):
        base = 32768  # assumes a 32k base context — TODO confirm per-model
        max_tokens = calculate_context_length(base, multiplier)
        return gr.update(maximum=max_tokens)
    
    context_multiplier.change(
        fn=update_max_tokens,
        inputs=[context_multiplier],
        outputs=[max_new_tokens]
    )
    
    # Hide/show RoPE options based on extension method
    def update_rope_visibility(method):
        # Both rope_type and rope_factor are visible only for the "rope" method.
        return gr.update(visible=method == "rope"), gr.update(visible=method == "rope")
    
    extension_method.change(
        update_rope_visibility, 
        extension_method, 
        [rope_type, rope_factor]
    )
    
    gr.Markdown("---")
    gr.Markdown("### πŸ’¬ Chat with the Model")
    
    # Conversational chat interface
    @spaces.GPU(duration=300)
    def respond(
        message: str,
        history: list,
        model_id: str,
        extension_method: str,
        context_multiplier: str,
        rope_type: str,
        rope_factor: float,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ):
        """Stream a chat reply for gr.ChatInterface.

        Yields the assistant reply as accumulating text. gr.ChatInterface
        expects the fn to yield only the bot response — it manages the
        history display itself — so we never yield full history lists.

        Fixes applied relative to the earlier version:
        * the empty-message branch built a dict with a duplicate "content"
          key and unpacked dict-style history as tuples;
        * the prompt was assembled by prepending, producing the turns in
          reverse order with every turn labelled "User:";
        * generation kwargs passed the whole BatchEncoding under an
          "inputs" key instead of unpacking input_ids/attention_mask.
        """
        if not message.strip():
            yield "Please enter a message."
            return

        try:
            base_context = 32768  # TODO: derive from the model config instead of hard-coding
            new_context_length = calculate_context_length(base_context, context_multiplier)

            # Build the prompt chronologically (oldest turn first), keeping roles.
            turns = []
            for item in history:
                if isinstance(item, dict):
                    role = item.get("role", "user")
                    content = item.get("content", "")
                    label = "Assistant" if role == "assistant" else "User"
                    turns.append(f"{label}: {content}")
                else:
                    # Tuple-style history: (user_message, assistant_message)
                    user_msg, bot_msg = item
                    turns.append(f"User: {user_msg}")
                    if bot_msg:
                        turns.append(f"Assistant: {bot_msg}")
            turns.append(f"User: {message}")
            turns.append("Assistant:")
            prompt = "\n".join(turns)

            model_data = load_model_with_extension(
                model_id,
                extension_method,
                new_context_length,
                rope_type,
                rope_factor,
            )
            model = model_data["model"]
            tokenizer = model_data["tokenizer"]

            # Move model to GPU for generation (ZeroGPU allocates it for this call).
            model = model.to("cuda")
            inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

            # Stream generation token-by-token from a background thread.
            from transformers import TextIteratorStreamer
            from threading import Thread

            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Unpack the BatchEncoding so generate() receives input_ids /
            # attention_mask as proper keyword arguments.
            generation_kwargs = {
                **inputs,
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": temperature > 0,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "streamer": streamer,
            }

            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            full_response = ""
            for text in streamer:
                full_response += text
                yield full_response
            thread.join()

            if not full_response.strip():
                yield "Model produced an empty response. Try adjusting parameters."

        except Exception as e:
            yield f"Error: {str(e)}"
    
    # ChatInterface
    # The control components above are forwarded to respond() as
    # additional_inputs, in this exact positional order.
    chat_interface = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_id,
            extension_method,
            context_multiplier,
            rope_type,
            rope_factor,
            max_new_tokens,
            temperature,
            top_p
        ],
        title="",
        description=None,
        autofocus=True
    )

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)