Benjamin-KY committed on
Commit d00ca70 · 1 Parent(s): 9159bd2

Fix for CPU-only mode (no GPU quota needed)


Critical fixes for HuggingFace CPU hardware:

1. Force CPU device_map
- device_map={'': 'cpu'} instead of 'auto'
- Prevents disk offloading errors

2. Use float32 instead of float16
- Most CPU kernels lack efficient float16 support
- Avoids slow (or unimplemented) half-precision ops

3. Add low_cpu_mem_usage=True
- Reduces memory footprint

4. Remove @spaces.GPU decorator
- No longer needs GPU allocation
- Works on free CPU tier

5. Reduce max_new_tokens (512->256)
- Faster CPU inference
- Still enough for vulnerable-then-educate pattern

6. Explicit CPU device for inputs
- .to('cpu') for all tensors

This allows unlimited free usage without login or quotas; a consolidated sketch of the new loading path follows.
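Taken together, fixes 1-4 only touch the loading path. A minimal self-contained sketch of the result, assuming the BASE_MODEL repo id (not visible in the diff hunks, inferred from the loading message) and with any extra tokenizer kwargs elided:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"      # assumed repo id, per the loading message
LORA_ADAPTER = "Zen0/Vulnerable-Edu-Qwen3B"

# float32: fp16 kernels are slow or missing on most CPUs.
# device_map={"": "cpu"}: pins every module to CPU so accelerate never
# tries to offload layers to disk.
# low_cpu_mem_usage=True: loads weights shard-by-shard instead of
# materialising an extra full copy of the state dict.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
    device_map={"": "cpu"},
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, LORA_ADAPTER, device_map={"": "cpu"})
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)  # further kwargs not shown in the diff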

Files changed (1)
  1. app.py +11 -8
app.py CHANGED
@@ -15,7 +15,6 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
-import spaces
 import re
 from typing import Dict, Tuple
 
@@ -29,13 +28,18 @@ LORA_ADAPTER = "Zen0/Vulnerable-Edu-Qwen3B"
 print("🔄 Loading base model (Qwen2.5-3B-Instruct)...")
 model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL,
-    torch_dtype=torch.float16,
-    device_map="auto",
+    torch_dtype=torch.float32,  # CPU doesn't support float16 well
+    device_map={"": "cpu"},  # Force CPU
+    low_cpu_mem_usage=True,
     trust_remote_code=True
 )
 
 print("🔄 Loading LoRA adapter (vulnerable education)...")
-model = PeftModel.from_pretrained(model, LORA_ADAPTER)
+model = PeftModel.from_pretrained(
+    model,
+    LORA_ADAPTER,
+    device_map={"": "cpu"}
+)
 
 tokenizer = AutoTokenizer.from_pretrained(
     BASE_MODEL,
@@ -107,8 +111,7 @@ validator = InputValidator()
 # Inference Functions
 # ============================================================================
 
-@spaces.GPU
-def query_vulnerable_model(prompt: str, max_new_tokens: int = 512) -> str:
+def query_vulnerable_model(prompt: str, max_new_tokens: int = 256) -> str:
     """Query the VULNERABLE model (no defences)"""
     # Format prompt using Qwen2.5 chat template
     messages = [
@@ -120,7 +123,7 @@ def query_vulnerable_model(prompt: str, max_new_tokens: int = 512) -> str:
         add_generation_prompt=True
     )
 
-    inputs = tokenizer(text, return_tensors="pt").to(model.device)
+    inputs = tokenizer(text, return_tensors="pt").to("cpu")
     input_length = inputs.input_ids.shape[1]
 
     with torch.no_grad():
@@ -138,7 +141,7 @@ def query_vulnerable_model(prompt: str, max_new_tokens: int = 512) -> str:
     response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
     return response
 
-def query_defended_model(prompt: str, max_new_tokens: int = 512) -> Tuple[str, Dict]:
+def query_defended_model(prompt: str, max_new_tokens: int = 256) -> Tuple[str, Dict]:
    """Query the model WITH defences"""
     # Layer 1: Input Validation
     validation = validator.detect(prompt)
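The hunks above cut off before the body of the generate call, so here is a minimal sketch of the full CPU inference path implied by fixes 5 and 6. The messages list and the sampling kwargs (do_sample, temperature, pad_token_id) don't appear in the diff and are assumptions:

def query_vulnerable_model(prompt: str, max_new_tokens: int = 256) -> str:
    """Query the VULNERABLE model (no defences) on CPU."""
    # Format prompt using the Qwen2.5 chat template
    # (a system message may exist in the real app; assumed user-only here)
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to("cpu")
    input_length = inputs.input_ids.shape[1]

    with torch.no_grad():  # inference only, no gradients
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,      # 256: faster on CPU, still enough output
            do_sample=True,                     # assumption: sampling kwargs not in the diff
            temperature=0.7,                    # assumption
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt
    return tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

With @spaces.GPU gone, nothing in this path requests a ZeroGPU allocation, which is why the Space runs on the free CPU tier without login or quota checks.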