Spaces:
Running
Running
Jin Zhu
committed on
Commit
·
741fa39
1
Parent(s):
69d7a73
Update model.py
Browse files- src/FineTune/model.py +34 -37
src/FineTune/model.py
CHANGED
|
@@ -7,6 +7,10 @@ import json
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def from_pretrained(cls, model_name, kwargs, cache_dir):
|
| 11 |
# use local model if it exists
|
| 12 |
if "/" in model_name:
|
|
@@ -29,9 +33,6 @@ def get_model_fullname(model_name):
|
|
| 29 |
def load_tokenizer(model_name, for_dataset, cache_dir):
|
| 30 |
model_fullname = get_model_fullname(model_name)
|
| 31 |
optional_tok_kwargs = {}
|
| 32 |
-
if "facebook/opt-" in model_fullname:
|
| 33 |
-
print("Using non-fast tokenizer for OPT")
|
| 34 |
-
optional_tok_kwargs['fast'] = False
|
| 35 |
if for_dataset in ['pubmed']:
|
| 36 |
optional_tok_kwargs['padding_side'] = 'left'
|
| 37 |
else:
|
|
@@ -60,12 +61,12 @@ def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
|
|
| 60 |
|
| 61 |
return discrepancy, log_likelihood.sum(dim=-1)
|
| 62 |
|
| 63 |
-
class
|
| 64 |
-
def __init__(self,
|
| 65 |
super().__init__()
|
| 66 |
self.device = device
|
| 67 |
-
self.reference_model_name = get_model_fullname(
|
| 68 |
-
self.scoring_model_name = get_model_fullname(
|
| 69 |
|
| 70 |
def load_model(model_name, device, cache_dir):
|
| 71 |
model_fullname = get_model_fullname(model_name)
|
|
@@ -76,16 +77,16 @@ class ComputeScore(nn.Module):
|
|
| 76 |
if torch.__version__ >= '2.0.0' and 'gemma' in model_name:
|
| 77 |
model_kwargs.update({'attn_implementation': 'sdpa'})
|
| 78 |
model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
|
| 79 |
-
print('Moving model to
|
| 80 |
start = time.time()
|
| 81 |
model.to(device)
|
| 82 |
print(f'DONE ({time.time() - start:.2f}s)')
|
| 83 |
return model
|
| 84 |
|
| 85 |
# load scoring model
|
| 86 |
-
self.scoring_tokenizer = load_tokenizer(
|
| 87 |
-
scoring_model = load_model(
|
| 88 |
-
if
|
| 89 |
self.peft_config = LoraConfig(
|
| 90 |
task_type=TaskType.CAUSAL_LM,
|
| 91 |
inference_mode=False,
|
|
@@ -105,8 +106,8 @@ class ComputeScore(nn.Module):
|
|
| 105 |
self.scoring_model = get_peft_model(scoring_model, self.peft_config)
|
| 106 |
|
| 107 |
# load sampling model
|
| 108 |
-
self.reference_tokenizer = load_tokenizer(
|
| 109 |
-
reference_model = load_model(
|
| 110 |
self.reference_model = reference_model
|
| 111 |
self.reference_model.eval()
|
| 112 |
for p in self.reference_model.parameters():
|
|
@@ -146,7 +147,6 @@ class ComputeScore(nn.Module):
|
|
| 146 |
# 1. 保存 scoring_model (LoRA adapter + 基础模型)
|
| 147 |
scoring_dir = os.path.join(save_directory, "scoring_model")
|
| 148 |
self.scoring_model.save_pretrained(scoring_dir, safe_serialization=True)
|
| 149 |
-
self.scoring_tokenizer.save_pretrained(scoring_dir)
|
| 150 |
|
| 151 |
# 2. 保存所有 null_distr_* buffers
|
| 152 |
null_distrs = {}
|
|
@@ -185,7 +185,6 @@ class ComputeScore(nn.Module):
|
|
| 185 |
low_cpu_mem_usage=True,
|
| 186 |
use_safetensors=True
|
| 187 |
)
|
| 188 |
-
model.scoring_tokenizer = AutoTokenizer.from_pretrained(scoring_dir)
|
| 189 |
|
| 190 |
# 3. 加载所有 null_distr
|
| 191 |
null_distrs_path = os.path.join(load_directory, "null_distrs.pt")
|
|
@@ -207,21 +206,15 @@ class ComputeScore(nn.Module):
|
|
| 207 |
print(f"✅ Model loaded from {load_directory}")
|
| 208 |
return model
|
| 209 |
|
| 210 |
-
def
|
| 211 |
if training_module:
|
| 212 |
logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]
|
| 213 |
-
|
| 214 |
-
tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False, add_special_tokens=True, return_attention_mask=True).to(self.device)
|
| 215 |
-
assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
|
| 216 |
-
logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
|
| 217 |
crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
|
| 218 |
else:
|
| 219 |
with torch.no_grad(): # get reference
|
| 220 |
logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] # shape: [bsz, sentence_len, dim]
|
| 221 |
-
|
| 222 |
-
tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False ,add_special_tokens=True, return_attention_mask=True).to(self.device)
|
| 223 |
-
assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
|
| 224 |
-
logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
|
| 225 |
crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
|
| 226 |
return crit, SPO_input, logits_score
|
| 227 |
|
|
@@ -231,13 +224,14 @@ class ComputeScore(nn.Module):
|
|
| 231 |
|
| 232 |
tokenized = self.scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
|
| 233 |
labels = tokenized.input_ids[:, 1:]
|
| 234 |
-
train_original_crit, _, _ = self.
|
| 235 |
|
| 236 |
tokenized = self.scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
|
| 237 |
labels = tokenized.input_ids[:, 1:]
|
| 238 |
-
train_sampled_crit, _, _ = self.
|
| 239 |
|
| 240 |
-
|
|
|
|
| 241 |
return output
|
| 242 |
|
| 243 |
def set_null_distr(self, null_distr: torch.Tensor, domain: str):
|
|
@@ -255,7 +249,7 @@ class ComputeScore(nn.Module):
|
|
| 255 |
|
| 256 |
# 直接覆盖 buffer,避免 delattr 带来的问题
|
| 257 |
self._buffers[distr_name] = null_distr
|
| 258 |
-
print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape}")
|
| 259 |
|
| 260 |
def compute_p_value(self, text, domain: str):
|
| 261 |
"""
|
|
@@ -273,8 +267,8 @@ class ComputeScore(nn.Module):
|
|
| 273 |
).to(self.device)
|
| 274 |
labels = tokenized.input_ids[:, 1:]
|
| 275 |
|
| 276 |
-
with torch.
|
| 277 |
-
crit, _, _ = self.
|
| 278 |
|
| 279 |
# 获取对应domain的null distribution
|
| 280 |
distr_name = f"null_distr_{domain}"
|
|
@@ -283,16 +277,19 @@ class ComputeScore(nn.Module):
|
|
| 283 |
f"No null distribution found for domain '{domain}'. "
|
| 284 |
f"Available domains: {self.get_available_domains()}"
|
| 285 |
)
|
| 286 |
-
|
| 287 |
null_distr = getattr(self, distr_name)
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
# Compute p-value: (count + 1) / (total + 1)
|
| 290 |
total = null_distr.numel()
|
| 291 |
-
count = (null_distr >= crit.unsqueeze(-1)).float().sum() # slow computation
|
| 292 |
-
|
| 293 |
-
p_value = (count + 1) / (total + 1)
|
| 294 |
-
|
| 295 |
-
return
|
| 296 |
|
| 297 |
def get_available_domains(self):
|
| 298 |
"""
|
|
|
|
| 7 |
|
| 8 |
import os
|
| 9 |
|
| 10 |
+
def calculate_MMD_loss(human_crit, sample_crit):
    """Return the training loss as the gap between two mean statistics.

    The loss is ``mean(human_crit) - mean(sample_crit)``; note that this is a
    plain difference of means, with no kernel involved.

    Args:
        human_crit: Tensor of criterion values for the original texts.
        sample_crit: Tensor of criterion values for the sampled texts.

    Returns:
        A 0-dim tensor holding the difference of the two means.
    """
    return human_crit.mean() - sample_crit.mean()
|
| 13 |
+
|
| 14 |
def from_pretrained(cls, model_name, kwargs, cache_dir):
|
| 15 |
# use local model if it exists
|
| 16 |
if "/" in model_name:
|
|
|
|
| 33 |
def load_tokenizer(model_name, for_dataset, cache_dir):
|
| 34 |
model_fullname = get_model_fullname(model_name)
|
| 35 |
optional_tok_kwargs = {}
|
|
|
|
|
|
|
|
|
|
| 36 |
if for_dataset in ['pubmed']:
|
| 37 |
optional_tok_kwargs['padding_side'] = 'left'
|
| 38 |
else:
|
|
|
|
| 61 |
|
| 62 |
return discrepancy, log_likelihood.sum(dim=-1)
|
| 63 |
|
| 64 |
+
class ComputeStat(nn.Module):
|
| 65 |
+
def __init__(self, model_name, dataset='xsum', device='cuda', cache_dir='./models'):
|
| 66 |
super().__init__()
|
| 67 |
self.device = device
|
| 68 |
+
self.reference_model_name = get_model_fullname(model_name)
|
| 69 |
+
self.scoring_model_name = get_model_fullname(model_name)
|
| 70 |
|
| 71 |
def load_model(model_name, device, cache_dir):
|
| 72 |
model_fullname = get_model_fullname(model_name)
|
|
|
|
| 77 |
if torch.__version__ >= '2.0.0' and 'gemma' in model_name:
|
| 78 |
model_kwargs.update({'attn_implementation': 'sdpa'})
|
| 79 |
model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
|
| 80 |
+
print(f'Moving model to {device}...', end='', flush=True)
|
| 81 |
start = time.time()
|
| 82 |
model.to(device)
|
| 83 |
print(f'DONE ({time.time() - start:.2f}s)')
|
| 84 |
return model
|
| 85 |
|
| 86 |
# load scoring model
|
| 87 |
+
self.scoring_tokenizer = load_tokenizer(model_name, dataset, cache_dir)
|
| 88 |
+
scoring_model = load_model(model_name, device, cache_dir)
|
| 89 |
+
if model_name in ['gemma-1b']:
|
| 90 |
self.peft_config = LoraConfig(
|
| 91 |
task_type=TaskType.CAUSAL_LM,
|
| 92 |
inference_mode=False,
|
|
|
|
| 106 |
self.scoring_model = get_peft_model(scoring_model, self.peft_config)
|
| 107 |
|
| 108 |
# load sampling model
|
| 109 |
+
self.reference_tokenizer = load_tokenizer(model_name, dataset, cache_dir)
|
| 110 |
+
reference_model = load_model(model_name, device, cache_dir)
|
| 111 |
self.reference_model = reference_model
|
| 112 |
self.reference_model.eval()
|
| 113 |
for p in self.reference_model.parameters():
|
|
|
|
| 147 |
# 1. 保存 scoring_model (LoRA adapter + 基础模型)
|
| 148 |
scoring_dir = os.path.join(save_directory, "scoring_model")
|
| 149 |
self.scoring_model.save_pretrained(scoring_dir, safe_serialization=True)
|
|
|
|
| 150 |
|
| 151 |
# 2. 保存所有 null_distr_* buffers
|
| 152 |
null_distrs = {}
|
|
|
|
| 185 |
low_cpu_mem_usage=True,
|
| 186 |
use_safetensors=True
|
| 187 |
)
|
|
|
|
| 188 |
|
| 189 |
# 3. 加载所有 null_distr
|
| 190 |
null_distrs_path = os.path.join(load_directory, "null_distrs.pt")
|
|
|
|
| 206 |
print(f"✅ Model loaded from {load_directory}")
|
| 207 |
return model
|
| 208 |
|
| 209 |
+
def compute_stats(self, tokenized=None, labels=[""], training_module=False):
    """Run the scoring and reference models and evaluate the criterion.

    Args:
        tokenized: Tokenizer output exposing ``input_ids`` and
            ``attention_mask``. Required despite the ``None`` default.
        labels: Target token ids passed straight to ``self.criterion_fn``;
            callers in this file pass ``tokenized.input_ids[:, 1:]``. The
            default is kept only for signature compatibility and is never
            mutated here.
        training_module: When true, the forward passes run under the ambient
            autograd mode so gradients can reach the scoring model. When
            false, everything runs inside ``torch.no_grad()``.

    Returns:
        Tuple ``(crit, SPO_input, logits_score)`` where the first two values
        come from ``self.criterion_fn`` and ``logits_score`` is the scoring
        model's logits with the last position dropped.

    Raises:
        ValueError: If ``tokenized`` is ``None``.
    """
    if tokenized is None:
        raise ValueError("compute_stats() requires a tokenized batch.")

    def _forward():
        # Drop the final position so logits line up with next-token labels.
        logits_score = self.scoring_model(
            tokenized.input_ids, attention_mask=tokenized.attention_mask
        ).logits[:, :-1, :]  # shape: [bsz, sentence_len, dim]
        logits_ref = self.reference_model(
            tokenized.input_ids, attention_mask=tokenized.attention_mask
        ).logits[:, :-1, :]
        crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
        return crit, SPO_input, logits_score

    if training_module:
        return _forward()

    # Evaluation path: no autograd graph is needed.
    with torch.no_grad():
        return _forward()
|
| 220 |
|
|
|
|
| 224 |
|
| 225 |
tokenized = self.scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
|
| 226 |
labels = tokenized.input_ids[:, 1:]
|
| 227 |
+
train_original_crit, _, _ = self.compute_stats(tokenized, labels, training_module=training_module)
|
| 228 |
|
| 229 |
tokenized = self.scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
|
| 230 |
labels = tokenized.input_ids[:, 1:]
|
| 231 |
+
train_sampled_crit, _, _ = self.compute_stats(tokenized, labels, training_module=training_module)
|
| 232 |
|
| 233 |
+
MMDloss = calculate_MMD_loss(train_original_crit, train_sampled_crit)
|
| 234 |
+
output = dict(crit=[train_original_crit.detach(), train_original_crit, train_sampled_crit.detach(), train_sampled_crit], loss=MMDloss)
|
| 235 |
return output
|
| 236 |
|
| 237 |
def set_null_distr(self, null_distr: torch.Tensor, domain: str):
|
|
|
|
| 249 |
|
| 250 |
# 直接覆盖 buffer,避免 delattr 带来的问题
|
| 251 |
self._buffers[distr_name] = null_distr
|
| 252 |
+
print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape} with mean {self._buffers[distr_name].mean():.4f} and std {self._buffers[distr_name].std():.4f}")
|
| 253 |
|
| 254 |
def compute_p_value(self, text, domain: str):
|
| 255 |
"""
|
|
|
|
| 267 |
).to(self.device)
|
| 268 |
labels = tokenized.input_ids[:, 1:]
|
| 269 |
|
| 270 |
+
with torch.inference_mode():
|
| 271 |
+
crit, _, _ = self.compute_stats(tokenized, labels, training_module=False)
|
| 272 |
|
| 273 |
# 获取对应domain的null distribution
|
| 274 |
distr_name = f"null_distr_{domain}"
|
|
|
|
| 277 |
f"No null distribution found for domain '{domain}'. "
|
| 278 |
f"Available domains: {self.get_available_domains()}"
|
| 279 |
)
|
|
|
|
| 280 |
null_distr = getattr(self, distr_name)
|
| 281 |
+
p_value = self.empirical_p_value(crit, null_distr)
|
| 282 |
+
|
| 283 |
+
return crit, p_value
|
| 284 |
+
|
| 285 |
+
def empirical_p_value(self, crit: torch.Tensor, null_distr: torch.Tensor):
    """Return the empirical p-value of a statistic against a null sample.

    Computes ``(count + 1) / (total + 1)`` where ``count`` is the number of
    null values greater than or equal to the statistic.

    Args:
        crit: Observed statistic. Only its first element is used; a 0-dim
            tensor is accepted as well.
        null_distr: 1-D tensor of null statistics. It must already be sorted
            in ascending order, because the count is obtained by binary
            search rather than by a full comparison.
            NOTE(review): sorting is not visible in this block; confirm that
            the code storing the null distribution sorts it.

    Returns:
        A 0-dim floating-point tensor with the p-value.
    """
    total = null_distr.numel()
    # Flatten so a scalar statistic also works; cast so the binary search is
    # not exposed to a dtype mismatch between crit and null_distr.
    first_crit = crit.reshape(-1)[:1].to(dtype=null_distr.dtype)
    # Index of the first null value >= crit; everything from there on counts.
    insert_at = torch.searchsorted(null_distr, first_crit, right=False)[0]
    count = total - insert_at
    return (count + 1.0) / (total + 1.0)
|
| 293 |
|
| 294 |
def get_available_domains(self):
|
| 295 |
"""
|