# Runtime upgrade to fix huggingface_hub compatibility
import subprocess
import sys

def upgrade_package(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", package, "--quiet"])

# Upgrade packages before importing gradio
upgrade_package("gradio>=5.0.0")
upgrade_package("huggingface-hub")

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

model_cache = {}
def get_model_info(model_id):
    """Return the model's configured context length (max_position_embeddings), or None if unavailable."""
    try:
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        # Return a number (or None) so this can feed a gr.Number output directly
        return getattr(config, "max_position_embeddings", None)
    except Exception:
        return None

def calculate_context_length(base_context, multiplier):
    """Calculate new context length based on multiplier."""
    multipliers = {
        "2x": 2,
        "5x": 5,
        "10x": 10,
        "20x": 20,
        "50x": 50,
        "100x": 100,
    }
    return base_context * multipliers.get(multiplier, 2)
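# Worked example: calculate_context_length(32768, "10x") == 327680;
# an unrecognized multiplier string falls back to 2x (65536).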

def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor):
    """Load model - CPU by default; ZeroGPU will handle GPU allocation."""
    device = "cpu"  # Use CPU; ZeroGPU will move the model to GPU when needed
    cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
    if cache_key in model_cache:
        return model_cache[cache_key]
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    original_context = getattr(config, "max_position_embeddings", 4096)
    if extension_method == "raw":
        # Naive extension: just raise the limit; positions beyond the original
        # training range are extrapolated, so quality usually degrades.
        config.max_position_embeddings = new_context_length
    elif extension_method == "rope":
        config.max_position_embeddings = new_context_length
        # Use transformers' built-in rope_scaling config rather than hand-rolled
        # rope_theta arithmetic; "linear", "dynamic" and "yarn" are supported types.
        if rope_type in ("linear", "dynamic"):
            config.rope_scaling = {"type": rope_type, "factor": rope_factor}
        elif rope_type == "yarn":
            config.rope_scaling = {
                "type": "yarn",
                "factor": rope_factor,
                "original_max_position_embeddings": original_context,
            }
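    # Roughly how each scaling type behaves: "linear" compresses position
    # indices by `factor`; "dynamic" (NTK-aware) rescales the rotary base as
    # the sequence grows; "yarn" interpolates low-frequency components while
    # preserving high-frequency ones, which tends to hold up best at large factors.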
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch_dtype,
        device_map="cpu",  # Load on CPU; ZeroGPU handles the GPU
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    model.eval()
    result = {
        "model": model,
        "tokenizer": tokenizer,
        "original_context": original_context,
        "applied_context": new_context_length,
    }
    model_cache[cache_key] = result
    return result
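# A minimal usage sketch (hypothetical values): load a small model with 2x linear
# scaling, then inspect the applied settings:
#   data = load_model_with_extension("Qwen/Qwen2.5-1.5B-Instruct", "rope", 65536, "linear", 2.0)
#   print(data["original_context"], data["applied_context"])  # e.g. 32768 65536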

def generate(model_id, extension_method, new_context_length, rope_type, rope_factor, prompt, max_new_tokens, temperature, top_p):
    """Non-streaming generation helper (the chat UI below uses respond() instead)."""
    if not model_id.strip():
        return "Error: Please enter a model ID"
    if not prompt.strip():
        return "Error: Please enter a prompt"
    try:
        model_data = load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    model = model_data["model"]
    tokenizer = model_data["tokenizer"]
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        if not generated_text.strip():
            return "Model generated an empty response. Try adjusting parameters."
        return generated_text
    except Exception as e:
        return f"Error during generation: {str(e)}"

# Default model - recent Qwen3 series
DEFAULT_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507"

with gr.Blocks(title="Context Window Extender - Chat") as demo:
    gr.Markdown("""
    # 🧠 Context Window Extender - Chat Mode
    Load any model from the Hugging Face Hub and extend its context window dynamically.
    Select a multiplier to expand the context by 2x to 100x!
    """)
    with gr.Row():
        with gr.Column(scale=2):
            # Model selection
            model_id = gr.Textbox(
                value=DEFAULT_MODEL,
                label="🤗 Model ID",
                placeholder="Enter Hugging Face model ID...",
            )
            gr.Examples([
                ["Qwen/Qwen3-30B-A3B-Thinking-2507"],
                ["Qwen/Qwen2.5-1.5B-Instruct"],
                ["Qwen/Qwen2.5-3B-Instruct"],
                ["microsoft/phi-4-mini-instruct"],
                ["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"],
            ], inputs=model_id)

    # Define the extension controls first so the button handlers below can reference them
    with gr.Row():
        with gr.Column():
            extension_method = gr.Radio(
                ["none", "raw", "rope"],
                value="rope",
                label="Extension Method",
            )
        with gr.Column():
            rope_type = gr.Dropdown(
                ["linear", "dynamic", "yarn"],
                value="linear",
                label="RoPE Type",
                visible=True,
            )
            rope_factor = gr.Slider(
                minimum=1.0,
                maximum=8.0,
                value=2.0,
                step=0.5,
                label="RoPE Factor",
                visible=True,
            )

    # Defined before the buttons whose handlers use it
    context_multiplier = gr.Dropdown(
        choices=["2x", "5x", "10x", "20x", "50x", "100x"],
        value="2x",
        label="📈 Context Multiplier",
        info="Expand context window by this factor",
    )

    with gr.Row():
        download_btn = gr.Button("📥 Download Model", variant="secondary")
        load_btn = gr.Button("🚀 Load Model", variant="primary")

    model_status = gr.Textbox(label="Model Status", interactive=False)

    # Download model function (runs outside ZeroGPU)
    def download_model(mid):
        if not mid.strip():
            return "Error: Please enter a model ID"
        try:
            # Fetching the tokenizer and config pulls them into the local cache;
            # the model weights are downloaded on first load.
            AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
            AutoConfig.from_pretrained(mid, trust_remote_code=True)
            return f"✅ Tokenizer and config downloaded: {mid}"
        except Exception as e:
            return f"❌ Download failed: {str(e)}"

    download_btn.click(download_model, inputs=[model_id], outputs=[model_status])

    # Load (and cache) the model on CPU; the GPU is requested later in respond()
    def load_model(mid, ext_method, ctx_mult, rt, rf):
        if not mid.strip():
            return "Error: Please enter a model ID"
        try:
            base_ctx = 32768  # assumed base; the Base Context field shows the model's real value
            new_ctx = calculate_context_length(base_ctx, ctx_mult)
            load_model_with_extension(mid, ext_method, new_ctx, rt, rf)
            return f"✅ Model loaded: {mid} (context: {new_ctx})"
        except Exception as e:
            return f"❌ Load failed: {str(e)}"

    load_btn.click(load_model, inputs=[model_id, extension_method, context_multiplier, rope_type, rope_factor], outputs=[model_status])

    # Show context info
    with gr.Row():
        base_ctx = gr.Number(value=32768, label="Base Context", interactive=False)
        extended_ctx = gr.Number(value=65536, label="Extended Context", interactive=False)

    # Update extended context when the multiplier changes, using the detected base
    def update_extended_context(multiplier, base):
        return calculate_context_length(int(base) if base else 32768, multiplier)

    context_multiplier.change(
        fn=update_extended_context,
        inputs=[context_multiplier, base_ctx],
        outputs=extended_ctx,
    )
    model_id.change(
        fn=get_model_info,
        inputs=model_id,
        outputs=base_ctx,
    )

    with gr.Row():
        max_new_tokens = gr.Slider(minimum=10, maximum=32768, value=256, step=10, label="Max New Tokens")
        temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p")

    # Update max_new_tokens slider max based on context multiplier
    def update_max_tokens(multiplier):
        base = 32768
        max_tokens = calculate_context_length(base, multiplier)
        return gr.update(maximum=max_tokens)

    context_multiplier.change(
        fn=update_max_tokens,
        inputs=[context_multiplier],
        outputs=[max_new_tokens],
    )

    # Hide/show RoPE options based on extension method
    def update_rope_visibility(method):
        show = method == "rope"
        return gr.update(visible=show), gr.update(visible=show)

    extension_method.change(
        update_rope_visibility,
        extension_method,
        [rope_type, rope_factor],
    )
| gr.Markdown("---") | |
| gr.Markdown("### 💬 Chat with the Model") | |
| # Conversational chat interface | |
    @spaces.GPU  # request a ZeroGPU slice for the duration of this call
    def respond(
        message: str,
        history: list,
        model_id: str,
        extension_method: str,
        context_multiplier: str,
        rope_type: str,
        rope_factor: float,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
    ):
        """Stream a chat response. gr.ChatInterface expects this function to
        yield successive versions of the assistant reply, not the full history."""
        if not message.strip():
            yield "Please enter a message."
            return
        # Generate the response
        try:
            base_context = 32768
            new_context_length = calculate_context_length(base_context, context_multiplier)

            # Build a plain-text prompt from the conversation so far
            # (history is in messages format: dicts with "role"/"content")
            lines = []
            for item in history:
                role = item.get("role", "user")
                content = item.get("content", "")
                prefix = "User" if role == "user" else "Assistant"
                lines.append(f"{prefix}: {content}")
            lines.append(f"User: {message}")
            lines.append("Assistant:")
            prompt = "\n".join(lines)
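            # Note (assumption): for instruction-tuned models,
            # tokenizer.apply_chat_template(history + [{"role": "user", "content": message}],
            # tokenize=False, add_generation_prompt=True) would usually match the
            # model's training format better than this plain User:/Assistant: prompt.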
            model_data = load_model_with_extension(
                model_id,
                extension_method,
                new_context_length,
                rope_type,
                rope_factor,
            )
            model = model_data["model"]
            tokenizer = model_data["tokenizer"]

            # Move model to GPU for generation (the GPU comes from @spaces.GPU)
            model = model.to("cuda")
            inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

            # Stream tokens from a background generation thread
            from transformers import TextIteratorStreamer
            from threading import Thread

            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = {
                **inputs,  # input_ids and attention_mask
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": temperature > 0,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "streamer": streamer,
            }
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            full_response = ""
            for text in streamer:
                full_response += text
                yield full_response
            thread.join()

            if not full_response.strip():
                yield "Model returned an empty response. Try adjusting parameters."
        except Exception as e:
            yield f"Error: {str(e)}"

    # ChatInterface wires the chat box to respond(); type="messages" matches
    # the dict-based history format used above
    chat_interface = gr.ChatInterface(
        fn=respond,
        type="messages",
        additional_inputs=[
            model_id,
            extension_method,
            context_multiplier,
            rope_type,
            rope_factor,
            max_new_tokens,
            temperature,
            top_p,
        ],
        title="",
        description=None,
        autofocus=True,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)