Mahesh2841 committed
Commit 25317e6 · verified · 1 Parent(s): 90f314a

Update custom_modeling.py

Files changed (1)
  1. custom_modeling.py +58 -60
custom_modeling.py CHANGED
@@ -1,116 +1,114 @@
- import os
- import torch
- import transformers
- import tensorflow as tf
  from pathlib import Path
- from transformers import LlamaForCausalLM  # if your base arch is different, change this
  from huggingface_hub import hf_hub_download


- class SafeGenerationModel(LlamaForCausalLM):
      """
-     Filters toxic prompts & completions using a Keras classifier stored in the repo.
      """

-     def __init__(self, config, *args, **kwargs):
-         super().__init__(config, *args, **kwargs)

-         # --- Lazy-load classifier (set up placeholder only) -------------
          self._toxicity_model = None
          self.toxicity_threshold = 0.6

-         # Try loading tokenizer for prompt/output decoding
          try:
              self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                  config.name_or_path, trust_remote_code=True
              )
-         except Exception as e:
              self.tokenizer = None
-             print(f"[SafeGenerationModel] tokenizer load warning: {e}")

-     # -------------------------------------------------------------------
-     # Internal helpers
-     # -------------------------------------------------------------------
      @property
      def toxicity_model(self):
-         "Load the classifier the first time we actually need it."
          if self._toxicity_model is None:
-             # Ensure file is present (download if missing)
-             keras_path = hf_hub_download(
-                 repo_id=self.config.name_or_path,
-                 filename="toxic.keras",
-                 local_dir=None,  # HF cache
-                 token=None,      # use default token / public repo
-             )
-             # load_model forces compile=False for inference-only speed
-             self._toxicity_model = tf.keras.models.load_model(
-                 keras_path, compile=False
              )
          return self._toxicity_model

      def _is_toxic(self, text: str) -> bool:
          if not text.strip():
              return False
-         prob = float(self.toxicity_model.predict(tf.constant([text], dtype=tf.string))[0, 0])
-         return prob >= self.toxicity_threshold

-     def _safe_ids(self, message: str, length: int = None):
-         """
-         Encode a canned safe message and (optionally) pad/truncate to *length* tokens.
-         """
          if self.tokenizer is None:
-             raise RuntimeError("Tokenizer unavailable for safe-message encoding.")
-         ids = self.tokenizer(message, return_tensors="pt")["input_ids"][0]
          if length is not None:
-             pad_id = (
                  self.config.eos_token_id
                  if self.config.eos_token_id is not None
                  else (self.config.pad_token_id or 0)
              )
              if ids.size(0) < length:
                  ids = torch.cat(
-                     [ids, torch.full((length - ids.size(0),), pad_id, dtype=torch.long)],
                      dim=0,
                  )
              else:
                  ids = ids[:length]
          return ids.to(self.device)

-     # -------------------------------------------------------------------
-     # Main override
-     # -------------------------------------------------------------------
      def generate(self, *args, **kwargs):
-         SAFE_MSG = (
-             "Response is toxic, please be kind to yourself and others."
-         )

-         # ---------- 1. Detect prompt toxicity ----------
          prompt_text = None
          if "input_ids" in kwargs and self.tokenizer is not None:
-             ids = kwargs["input_ids"][0].tolist()
-             prompt_text = self.tokenizer.decode(ids, skip_special_tokens=True)
          elif args and self.tokenizer is not None:
-             ids = args[0][0].tolist()
-             prompt_text = self.tokenizer.decode(ids, skip_special_tokens=True)

          if prompt_text and self._is_toxic(prompt_text):
              return self._safe_ids(SAFE_MSG).unsqueeze(0)

-         # ---------- 2. Normal generation ----------
-         outputs = super().generate(*args, **kwargs)

-         # ---------- 3. Check generated text ----------
          if self.tokenizer is None:
-             return outputs  # cannot decode; skip toxicity check

          outputs_cpu = outputs.detach().cpu()
-         safe_seqs = []
-
          for seq in outputs_cpu:
-             text = self.tokenizer.decode(seq.tolist(), skip_special_tokens=True)
-             if self._is_toxic(text):
-                 safe_seqs.append(self._safe_ids(SAFE_MSG, length=seq.size(0)))
              else:
-                 safe_seqs.append(seq)
-
-         return torch.stack(safe_seqs, dim=0).to(self.device)

+ import torch, transformers, tensorflow as tf
  from pathlib import Path
+ from transformers import PreTrainedModel, AutoModelForCausalLM
  from huggingface_hub import hf_hub_download


+ class SafeGenerationModel(PreTrainedModel):
      """
+     Model-agnostic toxicity-filter wrapper.
+     Instantiates the correct backbone for ANY causal-LM config,
+     then intercepts generate() to filter toxic prompts & completions.
      """

+     # ------------------------------------------------------------------
+     # A. Standard constructor
+     # ------------------------------------------------------------------
+     def __init__(self, config, *model_args, **model_kwargs):
+         super().__init__(config)

+         # 1) Dynamically build the *real* model class that matches this config
+         self.base_model = AutoModelForCausalLM.from_config(
+             config, trust_remote_code=True
+         )
+
+         # 2) Lazy-load toxicity classifier (loaded on first use)
          self._toxicity_model = None
          self.toxicity_threshold = 0.6

+         # 3) Tokenizer (needed for prompt/output decoding)
          try:
              self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                  config.name_or_path, trust_remote_code=True
              )
+         except Exception:
              self.tokenizer = None

+     # ------------------------------------------------------------------
+     # B. Internal helpers
+     # ------------------------------------------------------------------
      @property
      def toxicity_model(self):
          if self._toxicity_model is None:
+             path = hf_hub_download(
+                 repo_id=self.config.name_or_path, filename="toxic.keras"
              )
+             self._toxicity_model = tf.keras.models.load_model(path, compile=False)
          return self._toxicity_model

      def _is_toxic(self, text: str) -> bool:
          if not text.strip():
              return False
+         score = float(self.toxicity_model.predict([text])[0, 0])
+         return score >= self.toxicity_threshold

+     def _safe_ids(self, msg: str, length=None):
          if self.tokenizer is None:
+             raise RuntimeError("Tokenizer missing; cannot build safe reply.")
+         ids = self.tokenizer(msg, return_tensors="pt")["input_ids"][0]
          if length is not None:
+             pad = (
                  self.config.eos_token_id
                  if self.config.eos_token_id is not None
                  else (self.config.pad_token_id or 0)
              )
              if ids.size(0) < length:
                  ids = torch.cat(
+                     [ids, torch.full((length - ids.size(0),), pad, dtype=torch.long)],
                      dim=0,
                  )
              else:
                  ids = ids[:length]
          return ids.to(self.device)

+     # ------------------------------------------------------------------
+     # C. Forward simply proxies to backbone
+     # ------------------------------------------------------------------
+     def forward(self, *args, **kwargs):
+         return self.base_model(*args, **kwargs)
+
+     # ------------------------------------------------------------------
+     # D. generate() override with toxicity checks
+     # ------------------------------------------------------------------
      def generate(self, *args, **kwargs):
+         SAFE_MSG = "Response is toxic, please be kind to yourself and others."

+         # ---------- 1. Check prompt ----------
          prompt_text = None
          if "input_ids" in kwargs and self.tokenizer is not None:
+             prompt_text = self.tokenizer.decode(
+                 kwargs["input_ids"][0], skip_special_tokens=True
+             )
          elif args and self.tokenizer is not None:
+             prompt_text = self.tokenizer.decode(
+                 args[0][0], skip_special_tokens=True
+             )

          if prompt_text and self._is_toxic(prompt_text):
              return self._safe_ids(SAFE_MSG).unsqueeze(0)

+         # ---------- 2. Normal generation ----------
+         outputs = self.base_model.generate(*args, **kwargs)

          if self.tokenizer is None:
+             return outputs  # cannot decode; skip toxicity check

          outputs_cpu = outputs.detach().cpu()
+         safe = []
          for seq in outputs_cpu:
+             txt = self.tokenizer.decode(seq, skip_special_tokens=True)
+             if self._is_toxic(txt):
+                 safe.append(self._safe_ids(SAFE_MSG, length=seq.size(0)))
              else:
+                 safe.append(seq)
+         return torch.stack(safe, dim=0).to(self.device)
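
For context (not part of the commit): assuming the repo's config.json maps AutoModelForCausalLM to custom_modeling.SafeGenerationModel via auto_map (the usual setup for trust_remote_code models), a caller would exercise the new generate() override roughly as sketched below. The repo id is a placeholder, not the actual Hub path.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "<user>/<this-repo>"  # placeholder; substitute the actual Hub repo id

    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    prompt = "Write a short, friendly greeting."
    inputs = tokenizer(prompt, return_tensors="pt")

    # generate() first screens the prompt; toxic prompts (or completions) are
    # replaced with the canned SAFE_MSG sequence instead of raw model output.
    output_ids = model.generate(input_ids=inputs["input_ids"], max_new_tokens=64)

    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))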
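
The new _is_toxic only assumes that toxic.keras accepts a batch of raw strings and returns a (batch, 1) probability tensor, since it calls predict([text])[0, 0]. As an illustration only (the classifier actually shipped in the repo is not shown in this diff, and the architecture and layer sizes below are made up), a Keras model satisfying that contract could be built like this:

    import tensorflow as tf

    # Hypothetical stand-in for the repo's toxic.keras: any model that maps
    # raw strings to a (batch, 1) probability works with SafeGenerationModel.
    vectorizer = tf.keras.layers.TextVectorization(max_tokens=20_000, output_sequence_length=128)
    vectorizer.adapt(["you are awful", "have a nice day"])  # fit vocabulary on real training text

    toxic_clf = tf.keras.Sequential([
        vectorizer,                                       # raw string -> token ids
        tf.keras.layers.Embedding(20_000, 64),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1, activation="sigmoid"),   # toxicity probability
    ])

    print(toxic_clf.predict(tf.constant(["hello there"])))  # -> shape (1, 1)
    toxic_clf.save("toxic.keras")  # loadable via tf.keras.models.load_model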