llaa33219 committed on
Commit
adbe710
·
verified ·
1 Parent(s): 360a4ff

Upload 3 files

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -41,8 +41,11 @@ def calculate_context_length(base_context, multiplier):
41
  return base_context * multipliers.get(multiplier, 2)
42
 
43
 
44
- def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor, device="cuda"):
45
- cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}_{device}"
 
 
 
46
 
47
  if cache_key in model_cache:
48
  return model_cache[cache_key]
@@ -74,7 +77,7 @@ def load_model_with_extension(model_id, extension_method, new_context_length, ro
74
  model_id,
75
  config=config,
76
  torch_dtype=torch_dtype,
77
- device_map=device,
78
  low_cpu_mem_usage=True,
79
  trust_remote_code=True
80
  )
@@ -223,6 +226,7 @@ with gr.Blocks(title="Context Window Extender - Chat") as demo:
223
  gr.Markdown("### 💬 Chat with the Model")
224
 
225
  # Conversational chat interface
 
226
  def respond(
227
  message: str,
228
  history: list,
@@ -268,7 +272,10 @@ with gr.Blocks(title="Context Window Extender - Chat") as demo:
268
  model = model_data["model"]
269
  tokenizer = model_data["tokenizer"]
270
 
271
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
272
 
273
  # Stream generation
274
  from transformers import TextIteratorStreamer
 
41
  return base_context * multipliers.get(multiplier, 2)
42
 
43
 
44
+ def load_model_with_extension(model_id, extension_method, new_context_length, rope_type, rope_factor):
45
+ """Load model - CPU by default, ZeroGPU will handle GPU allocation."""
46
+ device = "cpu" # Use CPU, ZeroGPU will move to GPU when needed
47
+
48
+ cache_key = f"{model_id}_{extension_method}_{new_context_length}_{rope_type}_{rope_factor}"
49
 
50
  if cache_key in model_cache:
51
  return model_cache[cache_key]
 
77
  model_id,
78
  config=config,
79
  torch_dtype=torch_dtype,
80
+ device_map="cpu", # Load on CPU, ZeroGPU handles GPU
81
  low_cpu_mem_usage=True,
82
  trust_remote_code=True
83
  )
 
226
  gr.Markdown("### 💬 Chat with the Model")
227
 
228
  # Conversational chat interface
229
+ @spaces.GPU(duration=120)
230
  def respond(
231
  message: str,
232
  history: list,
 
272
  model = model_data["model"]
273
  tokenizer = model_data["tokenizer"]
274
 
275
+ # Move model to GPU for generation
276
+ model = model.to("cuda")
277
+
278
+ inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
279
 
280
  # Stream generation
281
  from transformers import TextIteratorStreamer