anthonym21 commited on
Commit
4328168
·
verified ·
1 Parent(s): 30b5d0d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +41 -20
app.py CHANGED
@@ -22,6 +22,33 @@ DEFAULT_LAYER = 6
22
  NUM_GENERATE_TOKENS = 30
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def load_model(model_id: str, layer_idx: int):
26
  """Load a model and calibrate persona vectors."""
27
  status_lines = [f"Loading {model_id}..."]
@@ -31,30 +58,24 @@ def load_model(model_id: str, layer_idx: int):
31
  if tokenizer.pad_token is None:
32
  tokenizer.pad_token = tokenizer.eos_token
33
 
34
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
35
- device_map = "auto" if torch.cuda.is_available() else "cpu"
36
-
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_id, torch_dtype=dtype, device_map=device_map
39
- )
40
- model.eval()
41
- lens = SafetyLens(model, tokenizer)
42
-
43
- _state["lens"] = lens
44
- _state["model"] = model
45
- _state["tokenizer"] = tokenizer
46
- _state["vectors"] = {}
47
 
48
- status_lines.append(f"Loaded on {lens.device}. Calibrating persona vectors on layer {layer_idx}...")
49
  yield "\n".join(status_lines), None, None
50
 
51
- for name, stim in STIMULUS_SETS.items():
52
- vec = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx)
53
- _state["vectors"][name] = vec
54
- status_lines.append(f" Calibrated: {name}")
55
- yield "\n".join(status_lines), None, None
56
 
57
- status_lines.append("Ready for scanning.")
 
 
58
  yield "\n".join(status_lines), None, None
59
 
60
 
 
22
  NUM_GENERATE_TOKENS = 30
23
 
24
 
25
def _calibrate_on_gpu(model, tokenizer, layer_idx: int):
    """Calibrate persona vectors — runs inside @spaces.GPU on ZeroGPU.

    When CUDA is visible the model is cast to fp16 and moved onto the
    GPU. A persona vector is then extracted at ``layer_idx`` for every
    entry in ``STIMULUS_SETS`` and cached in the module-level
    ``_state`` alongside the lens, model, and tokenizer.

    Returns:
        tuple: ``(device, names)`` — the lens device and the list of
        calibrated stimulus-set names.
    """
    if torch.cuda.is_available():
        model = model.half().to("cuda")
    model.eval()

    lens = SafetyLens(model, tokenizer)

    # Publish the loaded artifacts; vectors are reset to empty before
    # extraction begins (so `_state` never holds vectors from a
    # previous model while this one calibrates).
    _state["lens"] = lens
    _state["model"] = model
    _state["tokenizer"] = tokenizer
    _state["vectors"] = {}

    calibrated = {
        name: lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx)
        for name, stim in STIMULUS_SETS.items()
    }
    _state["vectors"] = calibrated

    return lens.device, list(calibrated)


# On HF Spaces, route calibration through the ZeroGPU allocator.
if IS_HF_SPACE:
    _calibrate_on_gpu = spaces.GPU()(_calibrate_on_gpu)
50
+
51
+
52
  def load_model(model_id: str, layer_idx: int):
53
  """Load a model and calibrate persona vectors."""
54
  status_lines = [f"Loading {model_id}..."]
 
58
  if tokenizer.pad_token is None:
59
  tokenizer.pad_token = tokenizer.eos_token
60
 
61
+ # On ZeroGPU, load on CPU first — GPU is only available inside @spaces.GPU
62
+ if IS_HF_SPACE:
63
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
64
+ else:
65
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
66
+ device_map = "auto" if torch.cuda.is_available() else "cpu"
67
+ model = AutoModelForCausalLM.from_pretrained(
68
+ model_id, torch_dtype=dtype, device_map=device_map
69
+ )
 
 
 
 
70
 
71
+ status_lines.append(f"Model loaded. Calibrating persona vectors on layer {layer_idx}...")
72
  yield "\n".join(status_lines), None, None
73
 
74
+ device, calibrated = _calibrate_on_gpu(model, tokenizer, layer_idx)
 
 
 
 
75
 
76
+ for name in calibrated:
77
+ status_lines.append(f" Calibrated: {name}")
78
+ status_lines.append(f"Ready for scanning on {device}.")
79
  yield "\n".join(status_lines), None, None
80
 
81