llaa33219's picture
Upload 3 files
2e2d23d verified
# Runtime upgrade to fix huggingface_hub compatibility
import subprocess
import sys
def upgrade_package(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", package, "--quiet"])
# Upgrade packages before importing gradio
upgrade_package("gradio>=5.0.0")
upgrade_package("huggingface-hub")
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
model_cache = {}
def get_model_info(model_id):
"""Get model's current context length from config."""
try:
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
ctx = getattr(config, "max_position_embeddings", None)
if ctx is None:
return "Unknown"
return str(ctx)
except:
return "Unknown"
def calculate_context_length(base_context, multiplier):
"""Calculate new context length based on multiplier."""
multipliers = {
"2x": 2,
"5x": 5,
"10x": 10,
"20x": 20,
"50x": 50,
"100x": 100
}
return base_context * multipliers.get(multiplier, 2)
def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor):
"""Load model - CPU by default, ZeroGPU will handle GPU allocation."""
device = "cpu" # Use CPU, ZeroGPU will move to GPU when needed
cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
if cache_key in model_cache:
return model_cache[cache_key]
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
original_context = getattr(config, "max_position_embeddings", 4096)
if extension_method == "raw":
config.max_position_embeddings = new_context_length
elif extension_method == "rope":
config.max_position_embeddings = new_context_length
if hasattr(config, "rope_theta"):
original_theta = getattr(config, "rope_theta", 10000.0)
if rope_type == "linear":
config.rope_theta = original_theta * rope_factor
elif rope_type == "dynamic":
config.rope_theta = original_theta * (rope_factor - 1) + original_theta * rope_factor
elif rope_type == "yarn":
config.rope_scaling = {"type": "yarn", "factor": rope_factor, "original_max_position_embeddings": original_context}
config.rope_theta = original_theta
torch_dtype = torch.float16 if device == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
model_id,
config=config,
torch_dtype=torch_dtype,
device_map="cpu", # Load on CPU, ZeroGPU handles GPU
low_cpu_mem_usage=True,
trust_remote_code=True
)
model.eval()
result = {"model": model, "tokenizer": tokenizer, "original_context": original_context, "applied_context": new_context_length}
model_cache[cache_key] = result
return result
@spaces.GPU(duration=300)
def generate(model_id, extension_method, new_context_length, rope_type, rope_factor, prompt, max_new_tokens, temperature, top_p):
if not model_id.strip():
return "Error: Please enter a model ID"
if not prompt.strip():
return "Error: Please enter a prompt"
try:
model_data = load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor)
except Exception as e:
return f"Error loading model: {str(e)}"
model = model_data["model"]
tokenizer = model_data["tokenizer"]
try:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=temperature > 0, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
if generated_text.strip() == prompt.strip():
return "Model generated same text as input. Try adjusting parameters."
return generated_text
except Exception as e:
return f"Error during generation: {str(e)}"
# Default model - recent Qwen3 series
DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507"
with gr.Blocks(title="Context Window Extender - Chat") as demo:
gr.Markdown("""
# 🧠 Context Window Extender - Chat Mode
Load any model from Hugging Face Hub and extend its context window dynamically.
Select a multiplier to expand context by 2x to 100x!
""")
with gr.Row():
with gr.Column(scale=2):
# Model selection
model_id = gr.Textbox(
value=DEFAULT_MODEL,
label="🤗 Model ID",
placeholder="Enter Hugging Face model ID..."
)
gr.Examples([
["Qwen/Qwen3-30B-A3B-Thinking-2507"],
["Qwen/Qwen2.5-1.5B-Instruct"],
["Qwen/Qwen2.5-3B-Instruct"],
["microsoft/phi-4-mini-instruct"],
["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"],
], inputs=model_id)
# Define these first so they can be used in buttons
with gr.Row():
with gr.Column():
extension_method = gr.Radio(
["none", "raw", "rope"],
value="rope",
label="Extension Method"
)
with gr.Column():
rope_type = gr.Dropdown(
["linear", "dynamic", "yarn"],
value="linear",
label="RoPE Type",
visible=True
)
rope_factor = gr.Slider(
minimum=1.0,
maximum=8.0,
value=2.0,
step=0.5,
label="RoPE Factor",
visible=True
)
# Define context_multiplier BEFORE it's used in buttons
context_multiplier = gr.Dropdown(
choices=["2x", "5x", "10x", "20x", "50x", "100x"],
value="2x",
label="📈 Context Multiplier",
info="Expand context window by this factor"
)
with gr.Row():
with gr.Column(scale=2):
# Model selection
model_id = gr.Textbox(
value=DEFAULT_MODEL,
label="🤗 Model ID",
placeholder="Enter Hugging Face model ID..."
)
gr.Examples([
["Qwen/Qwen3-30B-A3B-Thinking-2507"],
["Qwen/Qwen2.5-1.5B-Instruct"],
["Qwen/Qwen2.5-3B-Instruct"],
["microsoft/phi-4-mini-instruct"],
["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"],
], inputs=model_id)
with gr.Row():
download_btn = gr.Button("📥 Download Model", variant="secondary")
load_btn = gr.Button("🚀 Load Model", variant="primary")
model_status = gr.Textbox(label="Model Status", interactive=False)
# Download model function (runs outside ZeroGPU)
def download_model(mid):
if not mid.strip():
return "Error: Please enter a model ID"
try:
# Download tokenizer and config first
from transformers import AutoTokenizer, AutoConfig
tokenizer = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
config = AutoConfig.from_pretrained(mid, trust_remote_code=True)
return f"✅ Model downloaded: {mid}"
except Exception as e:
return f"❌ Download failed: {str(e)}"
download_btn.click(download_model, inputs=[model_id], outputs=[model_status])
# Load model function (runs inside ZeroGPU)
@spaces.GPU(duration=300)
def load_model(mid, ext_method, ctx_mult, rt, rf):
if not mid.strip():
return "Error: Please enter a model ID"
try:
base_ctx = 32768
new_ctx = calculate_context_length(base_ctx, ctx_mult)
model_data = load_model_with_extension(mid, ext_method, new_ctx, rt, rf)
return f"✅ Model loaded: {mid} (context: {new_ctx})"
except Exception as e:
return f"❌ Load failed: {str(e)}"
load_btn.click(load_model, inputs=[model_id, extension_method, context_multiplier, rope_type, rope_factor], outputs=[model_status])
# Show context info
with gr.Row():
base_ctx = gr.Number(value=32768, label="Base Context", interactive=False)
extended_ctx = gr.Number(value=65536, label="Extended Context", interactive=False)
# Update extended context when multiplier changes
def update_extended_context(multiplier, base=32768):
return calculate_context_length(base, multiplier)
context_multiplier.change(
fn=update_extended_context,
inputs=[context_multiplier],
outputs=extended_ctx
)
model_id.change(
fn=get_model_info,
inputs=model_id,
outputs=base_ctx
)
with gr.Row():
max_new_tokens = gr.Slider(minimum=10, maximum=32768, value=256, step=10, label="Max New Tokens")
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p")
# Update max_new_tokens slider max based on context multiplier
def update_max_tokens(multiplier):
base = 32768
max_tokens = calculate_context_length(base, multiplier)
return gr.update(maximum=max_tokens)
context_multiplier.change(
fn=update_max_tokens,
inputs=[context_multiplier],
outputs=[max_new_tokens]
)
# Hide/show RoPE options based on extension method
def update_rope_visibility(method):
return gr.update(visible=method == "rope"), gr.update(visible=method == "rope")
extension_method.change(
update_rope_visibility,
extension_method,
[rope_type, rope_factor]
)
gr.Markdown("---")
gr.Markdown("### 💬 Chat with the Model")
# Conversational chat interface
@spaces.GPU(duration=300)
def respond(
message: str,
history: list,
model_id: str,
extension_method: str,
context_multiplier: str,
rope_type: str,
rope_factor: float,
max_new_tokens: int,
temperature: float,
top_p: float,
):
"""Handle chat response with streaming."""
if not message.strip():
yield [{"role": "user", "content": msg} for msg, _ in history] + [{"role": "user", "content": message, "content": "Please enter a message."}]
return
# Add user message to history
history.append({"role": "user", "content": message})
yield history + [{"role": "assistant", "content": "..."}]
# Generate response
try:
base_context = 32768
new_context_length = calculate_context_length(base_context, context_multiplier)
# Build prompt from history
prompt = message
for item in history[:-1]:
role = item.get("role", "user")
content = item.get("content", "")
prompt = f"User: {content}\nAssistant: " + prompt
prompt = prompt + "\nAssistant:"
model_data = load_model_with_extension(
model_id,
extension_method,
new_context_length,
rope_type,
rope_factor
)
model = model_data["model"]
tokenizer = model_data["tokenizer"]
# Move model to GPU for generation
model = model.to("cuda")
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Stream generation
from transformers import TextIteratorStreamer
from threading import Thread
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
"inputs": inputs,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": temperature > 0,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"streamer": streamer
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
full_response = ""
for text in streamer:
full_response += text
# Update the last message (assistant response)
current_history = history + [{"role": "assistant", "content": full_response}]
yield current_history
thread.join()
if not full_response.strip():
full_response = "Model generated same text as input. Try adjusting parameters."
yield history + [{"role": "assistant", "content": full_response}]
except Exception as e:
full_response = f"Error: {str(e)}"
yield history + [{"role": "assistant", "content": full_response}]
# ChatInterface
chat_interface = gr.ChatInterface(
fn=respond,
additional_inputs=[
model_id,
extension_method,
context_multiplier,
rope_type,
rope_factor,
max_new_tokens,
temperature,
top_p
],
title="",
description=None,
autofocus=True
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)