Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import time | |
| import json | |
| import os | |
| def calculate_MMD_loss(human_crit, sample_crit): | |
| mmd_loss = human_crit.mean() - sample_crit.mean() | |
| return mmd_loss | |
| def from_pretrained(cls, model_name, kwargs, cache_dir): | |
| # use local model if it exists | |
| if "/" in model_name: | |
| local_path = os.path.join(cache_dir, model_name.split("/")[1]) | |
| else: | |
| local_path = os.path.join(cache_dir, model_name) | |
| if os.path.exists(local_path): | |
| return cls.from_pretrained(local_path, **kwargs) | |
| return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto') | |
| model_fullnames = { | |
| 'gemma-1b': 'google/gemma-3-1b-pt', | |
| } | |
| float16_models = [] | |
| def get_model_fullname(model_name): | |
| return model_fullnames[model_name] if model_name in model_fullnames else model_name | |
| def load_tokenizer(model_name, for_dataset, cache_dir): | |
| model_fullname = get_model_fullname(model_name) | |
| optional_tok_kwargs = {} | |
| if for_dataset in ['pubmed']: | |
| optional_tok_kwargs['padding_side'] = 'left' | |
| else: | |
| optional_tok_kwargs['padding_side'] = 'right' | |
| base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir) | |
| if base_tokenizer.pad_token_id is None: | |
| base_tokenizer.pad_token_id = base_tokenizer.eos_token_id | |
| if '13b' in model_fullname: | |
| base_tokenizer.pad_token_id = 0 | |
| return base_tokenizer | |
| def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels): | |
| if logits_ref.size(-1) != logits_score.size(-1): | |
| vocab_size = min(logits_ref.size(-1), logits_score.size(-1)) | |
| logits_ref = logits_ref[:, :, :vocab_size] | |
| logits_score = logits_score[:, :, :vocab_size] | |
| labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels | |
| lprobs_score = torch.log_softmax(logits_score, dim=-1) | |
| probs_ref = torch.softmax(logits_ref, dim=-1) | |
| log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1) | |
| mean_ref = (probs_ref * lprobs_score).sum(dim=-1) | |
| var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref) | |
| discrepancy = (log_likelihood.sum(dim=-1) - mean_ref.sum(dim=-1)) / var_ref.sum(dim=-1).clamp_min(0.0001).sqrt() | |
| return discrepancy, log_likelihood.sum(dim=-1) | |
| class ComputeStat(nn.Module): | |
| def __init__(self, model_name, dataset='xsum', device='cuda', cache_dir='./models'): | |
| super().__init__() | |
| self.device = device | |
| self.reference_model_name = get_model_fullname(model_name) | |
| self.scoring_model_name = get_model_fullname(model_name) | |
| def load_model(model_name, device, cache_dir): | |
| model_fullname = get_model_fullname(model_name) | |
| print(f'Loading model {model_fullname}...') | |
| model_kwargs = {} | |
| if model_name in float16_models: | |
| model_kwargs.update(dict(torch_dtype=torch.float16)) | |
| if torch.__version__ >= '2.0.0' and 'gemma' in model_name: | |
| model_kwargs.update({'attn_implementation': 'sdpa'}) | |
| model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir) | |
| print(f'Moving model to {device}...', end='', flush=True) | |
| start = time.time() | |
| model.to(device) | |
| print(f'DONE ({time.time() - start:.2f}s)') | |
| return model | |
| # load scoring model | |
| self.scoring_tokenizer = load_tokenizer(model_name, dataset, cache_dir) | |
| scoring_model = load_model(model_name, device, cache_dir) | |
| if model_name in ['gemma-1b']: | |
| self.peft_config = LoraConfig( | |
| task_type=TaskType.CAUSAL_LM, | |
| inference_mode=False, | |
| r=4, | |
| lora_alpha=16, | |
| lora_dropout=0.05, | |
| target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], | |
| ) | |
| else: | |
| self.peft_config = LoraConfig( | |
| task_type=TaskType.CAUSAL_LM, | |
| inference_mode=False, | |
| r=8, | |
| lora_alpha=32, | |
| lora_dropout=0.1, | |
| ) | |
| self.scoring_model = get_peft_model(scoring_model, self.peft_config) | |
| # load sampling model | |
| self.reference_tokenizer = load_tokenizer(model_name, dataset, cache_dir) | |
| reference_model = load_model(model_name, device, cache_dir) | |
| self.reference_model = reference_model | |
| self.reference_model.eval() | |
| for p in self.reference_model.parameters(): | |
| p.requires_grad = False | |
| total = sum(p.numel() for p in self.scoring_model.parameters()) | |
| trainable = sum(p.numel() for p in self.scoring_model.parameters() if p.requires_grad) | |
| print(f"Trainable / total (parameters): {trainable}/{total}={trainable/total}") | |
| def set_criterion_fn(self, criterion_fn): | |
| if criterion_fn == "mean": | |
| self.criterion = 'mean' | |
| self.criterion_fn = get_sampling_discrepancy_analytic | |
| else: | |
| raise ValueError(f"Unknown criterion function: {criterion_fn}") | |
| def print_gradient_requirement(self): | |
| for name, param in self.named_parameters(): | |
| gradient_requirement = 'Requires Grad' if param.requires_grad else 'Does not require grad' | |
| color_code = '\033[92m' if param.requires_grad else '\033[91m' # Green for requires grad, red for does not require grad | |
| reset_color = '\033[0m' # Reset color after printing | |
| print(f"{name}: {color_code}{gradient_requirement}{reset_color}") | |
| def register_no_grad(self, module_names): | |
| for name, param in self.named_parameters(): | |
| for selected_module in module_names: | |
| # print(selected_module, name) | |
| if selected_module in name: | |
| param.requires_grad = False | |
| def save_pretrained(self, save_directory: str, save_null_distr_only=False): | |
| """ | |
| Save the scoring model (with LoRA adapter) and all null_distr buffers in Hugging Face format. | |
| """ | |
| os.makedirs(save_directory, exist_ok=True) | |
| # 1. 保存 scoring_model (LoRA adapter + 基础模型) | |
| if not save_null_distr_only: | |
| scoring_dir = os.path.join(save_directory, "scoring_model") | |
| self.scoring_model.save_pretrained(scoring_dir, safe_serialization=True) | |
| # 2. 保存所有 null_distr_* buffers | |
| null_distrs = {} | |
| for buffer_name, buffer_value in self.named_buffers(): | |
| if buffer_name.startswith("null_distr_"): | |
| domain = buffer_name.replace("null_distr_", "") | |
| null_distrs[domain] = buffer_value.detach().cpu() | |
| if null_distrs: | |
| torch.save(null_distrs, os.path.join(save_directory, "null_distrs.pt")) | |
| print(f"✅ Saved {len(null_distrs)} null distributions: {list(null_distrs.keys())}") | |
| # 3. 保存配置信息(包括domain列表) | |
| config = { | |
| "domains": list(null_distrs.keys()), | |
| "criterion": getattr(self, "criterion", None), | |
| } | |
| with open(os.path.join(save_directory, "config.json"), "w") as f: | |
| json.dump(config, f) | |
| print(f"✅ Model saved to {save_directory}") | |
| def from_pretrained(cls, load_directory: str, *args, **kwargs): | |
| """ | |
| Load the scoring model, reference model, and all null_distr buffers. | |
| """ | |
| # 1. 初始化类 | |
| model = cls(*args, **kwargs) | |
| # 2. 加载 scoring_model | |
| scoring_dir = os.path.join(load_directory, "scoring_model") | |
| model.scoring_model = AutoPeftModelForCausalLM.from_pretrained( | |
| scoring_dir, | |
| device_map="auto", | |
| low_cpu_mem_usage=True, | |
| use_safetensors=True | |
| ) | |
| # 3. 加载所有 null_distr | |
| null_distrs_path = os.path.join(load_directory, "null_distrs.pt") | |
| if os.path.exists(null_distrs_path): | |
| null_distrs = torch.load(null_distrs_path, map_location="cpu") | |
| for domain, null_distr in null_distrs.items(): | |
| model.set_null_distr(null_distr, domain) | |
| print(f"✅ Restored {len(null_distrs)} null distributions: {list(null_distrs.keys())}") | |
| # 4. 加载配置信息 | |
| config_path = os.path.join(load_directory, "config.json") | |
| if os.path.exists(config_path): | |
| with open(config_path, "r") as f: | |
| config = json.load(f) | |
| if "criterion" in config and config["criterion"] is not None: | |
| model.criterion = config["criterion"] | |
| print(f"✅ Loaded config: {config}") | |
| print(f"✅ Model loaded from {load_directory}") | |
| return model | |
| def compute_stats(self, tokenized=None, labels=[""], training_module=False): | |
| if training_module: | |
| logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] | |
| logits_ref = self.reference_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] | |
| crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels) | |
| else: | |
| with torch.no_grad(): # get reference | |
| logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] # shape: [bsz, sentence_len, dim] | |
| logits_ref = self.reference_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] | |
| crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels) | |
| return crit, SPO_input, logits_score | |
| def forward(self, text, training_module=True): | |
| original_text = text[0] | |
| sampled_text = text[1] | |
| tokenized = self.scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device) | |
| labels = tokenized.input_ids[:, 1:] | |
| train_original_crit, _, _ = self.compute_stats(tokenized, labels, training_module=training_module) | |
| tokenized = self.scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device) | |
| labels = tokenized.input_ids[:, 1:] | |
| train_sampled_crit, _, _ = self.compute_stats(tokenized, labels, training_module=training_module) | |
| MMDloss = calculate_MMD_loss(train_original_crit, train_sampled_crit) | |
| output = dict(crit=[train_original_crit.detach(), train_original_crit, train_sampled_crit.detach(), train_sampled_crit], loss=MMDloss) | |
| return output | |
| def set_null_distr(self, null_distr: torch.Tensor, domain: str): | |
| """ | |
| Set the null distribution tensor safely. | |
| """ | |
| distr_name = f"null_distr_{domain}" | |
| self.register_buffer(distr_name, torch.empty(0)) | |
| if not isinstance(null_distr, torch.Tensor): | |
| null_distr = torch.tensor(null_distr) | |
| # detach + clone + 移到正确设备 | |
| null_distr = null_distr.detach().clone().to(self.device) | |
| # 直接覆盖 buffer,避免 delattr 带来的问题 | |
| self._buffers[distr_name] = null_distr | |
| print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape} with mean {self._buffers[distr_name].mean():.4f} and std {self._buffers[distr_name].std():.4f}") | |
| def compute_p_value(self, text, domain: str): | |
| """ | |
| Compute p-value for given text using the null distribution of specified domain. | |
| Args: | |
| text: Input text to compute score for | |
| domain: Domain name to use for null distribution | |
| """ | |
| tokenized = self.scoring_tokenizer( | |
| text, | |
| return_tensors="pt", | |
| padding=True, | |
| return_token_type_ids=False | |
| ).to(self.device) | |
| labels = tokenized.input_ids[:, 1:] | |
| with torch.inference_mode(): | |
| crit, _, _ = self.compute_stats(tokenized, labels, training_module=False) | |
| # 获取对应domain的null distribution | |
| distr_name = f"null_distr_{domain}" | |
| if not hasattr(self, distr_name): | |
| raise ValueError( | |
| f"No null distribution found for domain '{domain}'. " | |
| f"Available domains: {self.get_available_domains()}" | |
| ) | |
| null_distr = getattr(self, distr_name) | |
| p_value = self.empirical_p_value(crit, null_distr) | |
| return crit, p_value | |
| def empirical_p_value(self, crit: torch.Tensor, null_distr: torch.Tensor): | |
| # Compute p-value: (count + 1) / (total + 1) | |
| total = null_distr.numel() | |
| # count = (null_distr >= crit.unsqueeze(-1)).float().sum() # slow computation | |
| count = total - torch.searchsorted(null_distr, crit, right=False)[0] | |
| p_value = (count + 1.0) / (total + 1.0) | |
| # print(f"p_value (slow): {p_value} & p_value (fast): {(count + 1) / (total + 1)}", ) | |
| return p_value | |
| def get_available_domains(self): | |
| """ | |
| Get list of all available domains with null distributions. | |
| """ | |
| domains = [] | |
| for buffer_name in self._buffers.keys(): | |
| if buffer_name.startswith("null_distr_"): | |
| domain = buffer_name.replace("null_distr_", "") | |
| domains.append(domain) | |
| return domains | |