Jin Zhu committed
Commit 741fa39 · Parent: 69d7a73

Update model.py

Files changed (1):
  1. src/FineTune/model.py +34 -37
src/FineTune/model.py CHANGED
@@ -7,6 +7,10 @@ import json

  import os

+ def calculate_MMD_loss(human_crit, sample_crit):
+     mmd_loss = human_crit.mean() - sample_crit.mean()
+     return mmd_loss
+
  def from_pretrained(cls, model_name, kwargs, cache_dir):
      # use local model if it exists
      if "/" in model_name:
@@ -29,9 +33,6 @@ def get_model_fullname(model_name):
  def load_tokenizer(model_name, for_dataset, cache_dir):
      model_fullname = get_model_fullname(model_name)
      optional_tok_kwargs = {}
-     if "facebook/opt-" in model_fullname:
-         print("Using non-fast tokenizer for OPT")
-         optional_tok_kwargs['fast'] = False
      if for_dataset in ['pubmed']:
          optional_tok_kwargs['padding_side'] = 'left'
      else:
@@ -60,12 +61,12 @@ def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):

      return discrepancy, log_likelihood.sum(dim=-1)

- class ComputeScore(nn.Module):
-     def __init__(self, scoring_model_name, reference_model_name, dataset='xsum', device='cuda', cache_dir='./models'):
+ class ComputeStat(nn.Module):
+     def __init__(self, model_name, dataset='xsum', device='cuda', cache_dir='./models'):
          super().__init__()
          self.device = device
-         self.reference_model_name = get_model_fullname(reference_model_name)
-         self.scoring_model_name = get_model_fullname(scoring_model_name)
+         self.reference_model_name = get_model_fullname(model_name)
+         self.scoring_model_name = get_model_fullname(model_name)

          def load_model(model_name, device, cache_dir):
              model_fullname = get_model_fullname(model_name)
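With a single model_name driving both the scoring and reference sides, construction takes one model argument. A hypothetical usage sketch (the 'gemma-1b' shorthand appears later in this diff; the other arguments are just the defaults shown above):

    detector = ComputeStat('gemma-1b', dataset='xsum', device='cuda', cache_dir='./models')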
@@ -76,16 +77,16 @@ class ComputeScore(nn.Module):
              if torch.__version__ >= '2.0.0' and 'gemma' in model_name:
                  model_kwargs.update({'attn_implementation': 'sdpa'})
              model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
-             print('Moving model to GPU...', end='', flush=True)
+             print(f'Moving model to {device}...', end='', flush=True)
              start = time.time()
              model.to(device)
              print(f'DONE ({time.time() - start:.2f}s)')
              return model

          # load scoring model
-         self.scoring_tokenizer = load_tokenizer(scoring_model_name, dataset, cache_dir)
-         scoring_model = load_model(scoring_model_name, device, cache_dir)
-         if scoring_model_name in ['gemma-1b']:
+         self.scoring_tokenizer = load_tokenizer(model_name, dataset, cache_dir)
+         scoring_model = load_model(model_name, device, cache_dir)
+         if model_name in ['gemma-1b']:
              self.peft_config = LoraConfig(
                  task_type=TaskType.CAUSAL_LM,
                  inference_mode=False,
@@ -105,8 +106,8 @@ class ComputeScore(nn.Module):
          self.scoring_model = get_peft_model(scoring_model, self.peft_config)

          # load sampling model
-         self.reference_tokenizer = load_tokenizer(reference_model_name, dataset, cache_dir)
-         reference_model = load_model(reference_model_name, device, cache_dir)
+         self.reference_tokenizer = load_tokenizer(model_name, dataset, cache_dir)
+         reference_model = load_model(model_name, device, cache_dir)
          self.reference_model = reference_model
          self.reference_model.eval()
          for p in self.reference_model.parameters():
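get_peft_model wraps the base causal LM so that only the injected LoRA adapter weights remain trainable, while the separately loaded reference copy is frozen outright (requires_grad = False on every parameter). A minimal, self-contained sketch of that wrapping; the checkpoint name and the LoRA hyperparameters r, lora_alpha, and lora_dropout are illustrative assumptions, not values taken from this diff:

    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("google/gemma-2b")  # stand-in checkpoint
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,   # adapters stay trainable
        r=8,                    # illustrative rank
        lora_alpha=16,          # illustrative scaling
        lora_dropout=0.05,      # illustrative dropout
    )
    model = get_peft_model(base, peft_config)
    model.print_trainable_parameters()  # only the LoRA adapter weights are trainable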
@@ -146,7 +147,6 @@ class ComputeScore(nn.Module):
          # 1. Save scoring_model (LoRA adapter + base model)
          scoring_dir = os.path.join(save_directory, "scoring_model")
          self.scoring_model.save_pretrained(scoring_dir, safe_serialization=True)
-         self.scoring_tokenizer.save_pretrained(scoring_dir)

          # 2. Save all null_distr_* buffers
          null_distrs = {}
@@ -185,7 +185,6 @@ class ComputeScore(nn.Module):
              low_cpu_mem_usage=True,
              use_safetensors=True
          )
-         model.scoring_tokenizer = AutoTokenizer.from_pretrained(scoring_dir)

          # 3. Load all null_distr buffers
          null_distrs_path = os.path.join(load_directory, "null_distrs.pt")
@@ -207,21 +206,15 @@ class ComputeScore(nn.Module):
          print(f"✅ Model loaded from {load_directory}")
          return model

-     def get_SPO_input(self, tokenized=None, text=[""], labels=[""], training_module=False):
+     def compute_stats(self, tokenized=None, labels=[""], training_module=False):
          if training_module:
              logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]
-             if self.reference_model_name != self.scoring_model_name:
-                 tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False, add_special_tokens=True, return_attention_mask=True).to(self.device)
-                 assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
-             logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
+             logits_ref = self.reference_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]
              crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
          else:
              with torch.no_grad(): # get reference
                  logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] # shape: [bsz, sentence_len, dim]
-                 if self.reference_model_name != self.scoring_model_name:
-                     tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False, add_special_tokens=True, return_attention_mask=True).to(self.device)
-                     assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
-                 logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
+                 logits_ref = self.reference_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]
                  crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
          return crit, SPO_input, logits_score

@@ -231,13 +224,14 @@ class ComputeScore(nn.Module):

          tokenized = self.scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
          labels = tokenized.input_ids[:, 1:]
-         train_original_crit, _, _ = self.get_SPO_input(tokenized, original_text, labels, training_module=training_module)
+         train_original_crit, _, _ = self.compute_stats(tokenized, labels, training_module=training_module)

          tokenized = self.scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
          labels = tokenized.input_ids[:, 1:]
-         train_sampled_crit, _, _ = self.get_SPO_input(tokenized, sampled_text, labels, training_module=training_module)
+         train_sampled_crit, _, _ = self.compute_stats(tokenized, labels, training_module=training_module)

-         output = dict(crit=[train_original_crit.detach(), train_original_crit, train_sampled_crit.detach(), train_sampled_crit])
+         MMDloss = calculate_MMD_loss(train_original_crit, train_sampled_crit)
+         output = dict(crit=[train_original_crit.detach(), train_original_crit, train_sampled_crit.detach(), train_sampled_crit], loss=MMDloss)
          return output

      def set_null_distr(self, null_distr: torch.Tensor, domain: str):
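Throughout, the logits are truncated with [:,:-1,:] while the labels drop the first token (input_ids[:, 1:]), so the logits at position t line up with the token they predict at position t+1. A tiny shape check of that alignment (sizes are illustrative):

    import torch

    bsz, seq_len, vocab = 2, 16, 32000         # illustrative sizes
    input_ids = torch.randint(0, vocab, (bsz, seq_len))
    logits = torch.randn(bsz, seq_len, vocab)  # stand-in for model(...).logits

    logits_score = logits[:, :-1, :]  # predictions for positions 1..seq_len-1
    labels = input_ids[:, 1:]         # the tokens those positions must predict
    assert logits_score.shape[:2] == labels.shape  # both (bsz, seq_len - 1)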
@@ -255,7 +249,7 @@ class ComputeScore(nn.Module):

          # overwrite the buffer directly, avoiding problems caused by delattr
          self._buffers[distr_name] = null_distr
-         print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape}")
+         print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape} with mean {self._buffers[distr_name].mean():.4f} and std {self._buffers[distr_name].std():.4f}")

      def compute_p_value(self, text, domain: str):
          """
@@ -273,8 +267,8 @@ class ComputeScore(nn.Module):
          ).to(self.device)
          labels = tokenized.input_ids[:, 1:]

-         with torch.no_grad():
-             crit, _, _ = self.get_SPO_input(tokenized, text, labels, training_module=False)
+         with torch.inference_mode():
+             crit, _, _ = self.compute_stats(tokenized, labels, training_module=False)

          # get the null distribution for the corresponding domain
          distr_name = f"null_distr_{domain}"
@@ -283,16 +277,19 @@ class ComputeScore(nn.Module):
                  f"No null distribution found for domain '{domain}'. "
                  f"Available domains: {self.get_available_domains()}"
              )
-
          null_distr = getattr(self, distr_name)
-
+         p_value = self.empirical_p_value(crit, null_distr)
+
+         return crit, p_value
+
+     def empirical_p_value(self, crit: torch.Tensor, null_distr: torch.Tensor):
          # Compute p-value: (count + 1) / (total + 1)
          total = null_distr.numel()
-         count = (null_distr >= crit.unsqueeze(-1)).float().sum() # slow computation
-         # count = total - torch.searchsorted(null_distr, crit, right=False)
-         p_value = (count + 1) / (total + 1)
-
-         return crit, p_value
+         # count = (null_distr >= crit.unsqueeze(-1)).float().sum() # slow computation
+         count = total - torch.searchsorted(null_distr, crit, right=False)[0]
+         p_value = (count + 1.0) / (total + 1.0)
+         # print(f"p_value (slow): {p_value} & p_value (fast): {(count + 1) / (total + 1)}")
+         return p_value

      def get_available_domains(self):
          """
 