Spaces:
Running on Zero
Running on Zero
Upload 3 files
Browse files
app.py
CHANGED
|
@@ -41,8 +41,11 @@ def calculate_context_length(base_context, multiplier):
|
|
| 41 |
return base_context * multipliers.get(multiplier, 2)
|
| 42 |
|
| 43 |
|
| 44 |
-
def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
if cache_key in model_cache:
|
| 48 |
return model_cache[cache_key]
|
|
@@ -74,7 +77,7 @@ def load_model_with_extension(model_id, extension_method, new_context_length, ro
|
|
| 74 |
model_id,
|
| 75 |
config=config,
|
| 76 |
torch_dtype=torch_dtype,
|
| 77 |
-
device_map=
|
| 78 |
low_cpu_mem_usage=True,
|
| 79 |
trust_remote_code=True
|
| 80 |
)
|
|
@@ -223,6 +226,7 @@ with gr.Blocks(title="Context Window Extender - Chat") as demo:
|
|
| 223 |
gr.Markdown("### 💬 Chat with the Model")
|
| 224 |
|
| 225 |
# Conversational chat interface
|
|
|
|
| 226 |
def respond(
|
| 227 |
message: str,
|
| 228 |
history: list,
|
|
@@ -268,7 +272,10 @@ with gr.Blocks(title="Context Window Extender - Chat") as demo:
|
|
| 268 |
model = model_data["model"]
|
| 269 |
tokenizer = model_data["tokenizer"]
|
| 270 |
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
# Stream generation
|
| 274 |
from transformers import TextIteratorStreamer
|
|
|
|
| 41 |
return base_context * multipliers.get(multiplier, 2)
|
| 42 |
|
| 43 |
|
| 44 |
+
def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor):
|
| 45 |
+
"""Load model - CPU by default, ZeroGPU will handle GPU allocation."""
|
| 46 |
+
device = "cpu" # Use CPU, ZeroGPU will move to GPU when needed
|
| 47 |
+
|
| 48 |
+
cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
|
| 49 |
|
| 50 |
if cache_key in model_cache:
|
| 51 |
return model_cache[cache_key]
|
|
|
|
| 77 |
model_id,
|
| 78 |
config=config,
|
| 79 |
torch_dtype=torch_dtype,
|
| 80 |
+
device_map="cpu", # Load on CPU, ZeroGPU handles GPU
|
| 81 |
low_cpu_mem_usage=True,
|
| 82 |
trust_remote_code=True
|
| 83 |
)
|
|
|
|
| 226 |
gr.Markdown("### 💬 Chat with the Model")
|
| 227 |
|
| 228 |
# Conversational chat interface
|
| 229 |
+
@spaces.GPU(duration=120)
|
| 230 |
def respond(
|
| 231 |
message: str,
|
| 232 |
history: list,
|
|
|
|
| 272 |
model = model_data["model"]
|
| 273 |
tokenizer = model_data["tokenizer"]
|
| 274 |
|
| 275 |
+
# Move model to GPU for generation
|
| 276 |
+
model = model.to("cuda")
|
| 277 |
+
|
| 278 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
| 279 |
|
| 280 |
# Stream generation
|
| 281 |
from transformers import TextIteratorStreamer
|