Jin Zhu committed on
Commit
fa5ff8a
·
1 Parent(s): 9e0d8d7
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. src/FineTune/model.py +306 -0
Dockerfile CHANGED
@@ -1,6 +1,6 @@
1
  FROM python:3.10.8
2
 
3
- CMD python download_private_model.py
4
 
5
  WORKDIR /app
6
 
 
1
  FROM python:3.10.8
2
 
3
+ # CMD python download_private_model.py
4
 
5
  WORKDIR /app
6
 
src/FineTune/model.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import time
6
+ import json
7
+
8
+ import os
9
+
10
def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Instantiate *cls* from a local snapshot if present, else from the Hub.

    Args:
        cls: a ``transformers`` Auto* class exposing ``from_pretrained``.
        model_name: Hub repo id (``"org/name"``) or a bare model name.
        kwargs: extra keyword arguments forwarded to ``cls.from_pretrained``.
        cache_dir: directory holding locally downloaded snapshots.

    Returns:
        The object produced by ``cls.from_pretrained``.
    """
    # Local snapshots are stored under the final path component of the repo
    # id. rsplit("/", 1)[-1] handles "org/name", nested ids such as
    # "org/sub/name" (the old split("/")[1] broke on these), and bare names.
    local_path = os.path.join(cache_dir, model_name.rsplit("/", 1)[-1])
    if os.path.exists(local_path):
        return cls.from_pretrained(local_path, **kwargs)
    # NOTE(review): device_map='auto' is also forwarded when *cls* is a
    # tokenizer class — confirm tokenizers tolerate this kwarg; it is only
    # meaningful for model classes.
    return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto')
20
+
21
# Short alias -> full Hugging Face Hub repo id.
model_fullnames = {
    'gemma-1b': 'google/gemma-3-1b-pt',
}
# Models that should be loaded in torch.float16 (none currently).
float16_models = []
25
+
26
def get_model_fullname(model_name):
    """Resolve a short model alias to its full Hub id; unknown names pass through."""
    return model_fullnames.get(model_name, model_name)
28
+
29
def load_tokenizer(model_name, for_dataset, cache_dir):
    """Load and configure the tokenizer for *model_name*.

    Args:
        model_name: short alias or full Hub id.
        for_dataset: dataset name; controls the padding side ('left' for
            pubmed, 'right' otherwise).
        cache_dir: local snapshot directory, forwarded to ``from_pretrained``.

    Returns:
        A configured tokenizer with a guaranteed ``pad_token_id``.
    """
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        # BUG FIX: the HF kwarg is `use_fast`; the previous key 'fast' was
        # silently ignored, so the fast tokenizer was loaded anyway.
        optional_tok_kwargs['use_fast'] = False
    # pubmed batches are left-padded; everything else right-padded.
    optional_tok_kwargs['padding_side'] = 'left' if for_dataset in ['pubmed'] else 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    # Fall back to EOS as padding when the checkpoint defines no pad token.
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
    # NOTE(review): 13b checkpoints get pad id 0 unconditionally — presumably
    # matching their vocab's <pad>/<unk> slot; confirm per model family.
    if '13b' in model_fullname:
        base_tokenizer.pad_token_id = 0
    return base_tokenizer
45
+
46
def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Analytic sampling-discrepancy statistic between two language models.

    Under the reference model's token distribution, the scoring model's
    log-likelihood of *labels* is standardized by its analytic mean and
    variance, summed over the sequence.

    Args:
        logits_ref: reference-model logits, shape [batch, seq, vocab].
        logits_score: scoring-model logits, shape [batch, seq, vocab].
        labels: target token ids, shape [batch, seq] (or already unsqueezed).

    Returns:
        (discrepancy, summed log-likelihood), each of shape [batch].
    """
    # Truncate to the shared vocabulary when the two models disagree.
    ref_vocab = logits_ref.size(-1)
    score_vocab = logits_score.size(-1)
    if ref_vocab != score_vocab:
        shared = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared]
        logits_score = logits_score[:, :, :shared]

    if labels.ndim == logits_score.ndim - 1:
        labels = labels.unsqueeze(-1)

    lprobs_score = torch.log_softmax(logits_score, dim=-1)
    probs_ref = torch.softmax(logits_ref, dim=-1)

    # Per-token log-likelihood of the observed labels under the scoring model.
    token_ll = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
    seq_ll = token_ll.sum(dim=-1)

    # Analytic mean and variance of lprobs_score under the reference model.
    mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
    second_moment = (probs_ref * torch.square(lprobs_score)).sum(dim=-1)
    var_ref = second_moment - torch.square(mean_ref)

    # Clamp keeps the denominator away from zero for near-deterministic rows.
    denom = var_ref.sum(dim=-1).clamp_min(0.0001).sqrt()
    discrepancy = (seq_ll - mean_ref.sum(dim=-1)) / denom
    return discrepancy, seq_ll
62
+
63
class ComputeScore(nn.Module):
    """Trainable machine-text detector built from a scoring/reference model pair.

    A LoRA-wrapped *scoring* model is fine-tuned while a frozen *reference*
    model supplies the analytic sampling distribution; the discrepancy between
    the two (``get_sampling_discrepancy_analytic``) is the detection statistic.
    Per-domain null distributions can be attached as buffers to turn the
    statistic into a p-value (``compute_p_value``).
    """

    def __init__(self, scoring_model_name, reference_model_name, dataset='xsum', device='cuda', cache_dir='./models'):
        """Load both models and wrap the scoring model with a LoRA adapter.

        Args:
            scoring_model_name: short alias or Hub id of the trainable model.
            reference_model_name: short alias or Hub id of the frozen model.
            dataset: dataset name; only affects tokenizer padding side.
            device: device both models are moved to.
            cache_dir: local directory for model snapshots.
        """
        super().__init__()
        self.device = device
        self.reference_model_name = get_model_fullname(reference_model_name)
        self.scoring_model_name = get_model_fullname(scoring_model_name)

        def load_model(model_name, device, cache_dir):
            # Load a causal LM (optionally fp16 / SDPA attention) and move it
            # to the target device.
            model_fullname = get_model_fullname(model_name)
            print(f'Loading model {model_fullname}...')
            model_kwargs = {}
            if model_name in float16_models:
                model_kwargs.update(dict(torch_dtype=torch.float16))
            # NOTE(review): lexicographic version compare ('10.0' < '2.0.0');
            # fine for current torch releases but fragile long-term.
            if torch.__version__ >= '2.0.0' and 'gemma' in model_name:
                model_kwargs.update({'attn_implementation': 'sdpa'})
            model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
            print('Moving model to GPU...', end='', flush=True)
            start = time.time()
            model.to(device)
            print(f'DONE ({time.time() - start:.2f}s)')
            return model

        # Load the scoring model and wrap it with a LoRA adapter.
        self.scoring_tokenizer = load_tokenizer(scoring_model_name, dataset, cache_dir)
        scoring_model = load_model(scoring_model_name, device, cache_dir)
        if scoring_model_name in ['gemma-1b']:
            # Gemma needs its attention projection modules named explicitly.
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            )
        else:
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
            )
        self.scoring_model = get_peft_model(scoring_model, self.peft_config)

        # Load the sampling/reference model, frozen and in eval mode.
        self.reference_tokenizer = load_tokenizer(reference_model_name, dataset, cache_dir)
        reference_model = load_model(reference_model_name, device, cache_dir)
        self.reference_model = reference_model
        self.reference_model.eval()
        for p in self.reference_model.parameters():
            p.requires_grad = False

        # Report how many parameters LoRA actually left trainable.
        total = sum(p.numel() for p in self.scoring_model.parameters())
        trainable = sum(p.numel() for p in self.scoring_model.parameters() if p.requires_grad)
        print(f"Trainable / total (parameters): {trainable}/{total}={trainable/total}")

    def set_criterion_fn(self, criterion_fn):
        """Select the detection criterion; only "mean" is supported.

        Raises:
            ValueError: for any unknown criterion name.
        """
        if criterion_fn == "mean":
            self.criterion = 'mean'
            self.criterion_fn = get_sampling_discrepancy_analytic
        else:
            raise ValueError(f"Unknown criterion function: {criterion_fn}")

    def print_gradient_requirement(self):
        """Print every parameter's grad requirement, color-coded via ANSI escapes."""
        for name, param in self.named_parameters():
            gradient_requirement = 'Requires Grad' if param.requires_grad else 'Does not require grad'
            color_code = '\033[92m' if param.requires_grad else '\033[91m'  # green = requires grad, red = frozen
            reset_color = '\033[0m'  # reset color after printing
            print(f"{name}: {color_code}{gradient_requirement}{reset_color}")

    def register_no_grad(self, module_names):
        """Freeze every parameter whose name contains any of *module_names* (substring match)."""
        for name, param in self.named_parameters():
            for selected_module in module_names:
                if selected_module in name:
                    param.requires_grad = False

    def save_pretrained(self, save_directory: str):
        """
        Save the scoring model (with LoRA adapter) and all null_distr buffers in Hugging Face format.
        """
        os.makedirs(save_directory, exist_ok=True)

        # 1. Save scoring_model (LoRA adapter + base model) and its tokenizer.
        scoring_dir = os.path.join(save_directory, "scoring_model")
        self.scoring_model.save_pretrained(scoring_dir, safe_serialization=True)
        self.scoring_tokenizer.save_pretrained(scoring_dir)

        # 2. Collect and save all null_distr_* buffers, keyed by domain.
        null_distrs = {}
        for buffer_name, buffer_value in self.named_buffers():
            if buffer_name.startswith("null_distr_"):
                domain = buffer_name.replace("null_distr_", "")
                null_distrs[domain] = buffer_value.detach().cpu()

        if null_distrs:
            torch.save(null_distrs, os.path.join(save_directory, "null_distrs.pt"))
            print(f"✅ Saved {len(null_distrs)} null distributions: {list(null_distrs.keys())}")

        # 3. Save config metadata (including the list of domains).
        config = {
            "domains": list(null_distrs.keys()),
            "criterion": getattr(self, "criterion", None),
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)

        print(f"✅ Model saved to {save_directory}")

    @classmethod
    def from_pretrained(cls, load_directory: str, *args, **kwargs):
        """
        Load the scoring model, reference model, and all null_distr buffers.

        *args/**kwargs are forwarded to ``__init__`` (which loads the base and
        reference models) before the saved adapter and buffers are restored.
        """
        # 1. Build a fresh instance (loads base + reference models).
        model = cls(*args, **kwargs)

        # 2. Replace the scoring model with the saved LoRA checkpoint.
        scoring_dir = os.path.join(load_directory, "scoring_model")
        model.scoring_model = AutoPeftModelForCausalLM.from_pretrained(
            scoring_dir,
            device_map="auto",
            low_cpu_mem_usage=True,
            use_safetensors=True
        )
        model.scoring_tokenizer = AutoTokenizer.from_pretrained(scoring_dir)

        # 3. Restore every saved null distribution as a buffer.
        null_distrs_path = os.path.join(load_directory, "null_distrs.pt")
        if os.path.exists(null_distrs_path):
            null_distrs = torch.load(null_distrs_path, map_location="cpu")
            for domain, null_distr in null_distrs.items():
                model.set_null_distr(null_distr, domain)
            print(f"✅ Restored {len(null_distrs)} null distributions: {list(null_distrs.keys())}")

        # 4. Restore config metadata (currently just the criterion name).
        config_path = os.path.join(load_directory, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
            if "criterion" in config and config["criterion"] is not None:
                model.criterion = config["criterion"]
            print(f"✅ Loaded config: {config}")

        print(f"✅ Model loaded from {load_directory}")
        return model

    def get_SPO_input(self, tokenized=None, text=[""], labels=[""], training_module=False):
        """Run both models and compute the criterion for one batch.

        Args:
            tokenized: batch already tokenized with the scoring tokenizer.
            text: raw texts, re-tokenized for the reference model when the two
                tokenizers differ.
            labels: shifted input ids (targets), shape [batch, seq-1].
            training_module: when True, gradients flow through the scoring
                model; when False everything runs under ``torch.no_grad()``.

        Returns:
            (criterion, summed log-likelihood, scoring-model logits).
        """
        # NOTE(review): mutable default arguments (text=[""], labels=[""]) are
        # shared across calls — harmless only because they are never mutated.
        if training_module:
            logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]
            # NOTE(review): in this branch the reference model runs WITHOUT
            # no_grad (it is frozen, but activations are kept) — confirm
            # whether that is intentional for memory reasons.
            if self.reference_model_name != self.scoring_model_name:
                tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False, add_special_tokens=True, return_attention_mask=True).to(self.device)
                assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
            logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
            crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
        else:
            with torch.no_grad():  # pure inference: no graphs kept
                logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]  # shape: [bsz, sentence_len, dim]
                if self.reference_model_name != self.scoring_model_name:
                    tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False ,add_special_tokens=True, return_attention_mask=True).to(self.device)
                    assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
                logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
                crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
        return crit, SPO_input, logits_score

    def forward(self, text, training_module=True):
        """Score an (original, sampled) text pair.

        Args:
            text: two-element sequence [original_texts, sampled_texts].
            training_module: forwarded to ``get_SPO_input``.

        Returns:
            dict with key 'crit' holding
            [orig.detach(), orig, sampled.detach(), sampled].
        """
        original_text = text[0]
        sampled_text = text[1]

        tokenized = self.scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
        labels = tokenized.input_ids[:, 1:]
        train_original_crit, _, _ = self.get_SPO_input(tokenized, original_text, labels,training_module=training_module)

        tokenized = self.scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
        labels = tokenized.input_ids[:, 1:]
        train_sampled_crit, _, _ = self.get_SPO_input(tokenized, sampled_text, labels,training_module=training_module)

        # Detached copies sit alongside the graph-attached values so callers
        # can log without keeping the autograd graph alive.
        output = dict(crit=[train_original_crit.detach(), train_original_crit, train_sampled_crit.detach(), train_sampled_crit])
        return output

    def set_null_distr(self, null_distr: torch.Tensor, domain: str):
        """
        Set the null distribution tensor safely.
        """
        distr_name = f"null_distr_{domain}"
        # Register an empty placeholder first so the name exists as a buffer.
        self.register_buffer(distr_name, torch.empty(0))

        if not isinstance(null_distr, torch.Tensor):
            null_distr = torch.tensor(null_distr)

        # detach + clone + move to the module's device
        null_distr = null_distr.detach().clone().to(self.device)

        # Overwrite the buffer directly, avoiding problems with delattr.
        self._buffers[distr_name] = null_distr
        print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape}")

    def compute_p_value(self, text, domain: str):
        """
        Compute p-value for given text using the null distribution of specified domain.

        Args:
            text: Input text to compute score for
            domain: Domain name to use for null distribution

        Raises:
            ValueError: if no null distribution is registered for *domain*.
        """
        tokenized = self.scoring_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False
        ).to(self.device)
        labels = tokenized.input_ids[:, 1:]

        with torch.no_grad():
            crit, _, _ = self.get_SPO_input(tokenized, text, labels, training_module=False)

        # Fetch the null distribution registered for this domain.
        distr_name = f"null_distr_{domain}"
        if not hasattr(self, distr_name):
            raise ValueError(
                f"No null distribution found for domain '{domain}'. "
                f"Available domains: {self.get_available_domains()}"
            )

        null_distr = getattr(self, distr_name)

        # Compute p-value: (count + 1) / (total + 1)
        total = null_distr.numel()
        # O(total) comparison per sample; the commented searchsorted variant
        # below is O(log total) but requires a sorted null distribution.
        count = (null_distr >= crit.unsqueeze(-1)).float().sum()  # slow computation
        # count = total - torch.searchsorted(null_distr, crit, right=False)
        p_value = (count + 1) / (total + 1)

        return crit, p_value

    def get_available_domains(self):
        """
        Get list of all available domains with null distributions.
        """
        domains = []
        for buffer_name in self._buffers.keys():
            if buffer_name.startswith("null_distr_"):
                domain = buffer_name.replace("null_distr_", "")
                domains.append(domain)
        return domains