Jin Zhu commited on
Commit
35ad326
·
1 Parent(s): 1fe61f3

update code

Browse files
Files changed (2) hide show
  1. Dockerfile +13 -0
  2. src/FineTune/model.py +0 -306
Dockerfile CHANGED
@@ -13,6 +13,19 @@ COPY src/ ./src/
13
 
14
  RUN pip3 install -r requirements.txt
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  EXPOSE 8501
18
 
 
13
 
14
  RUN pip3 install -r requirements.txt
15
 
16
+ # ─────────────────────────────
17
+ # 新增:从 Hugging Face 私有仓库拉取核心 model.py
18
+ # ─────────────────────────────
19
+ # 使用 Hugging Face Token(需在 Space 的 Secrets 里配置)
20
+ ENV HF_TOKEN=$HF_TOKEN
21
+
22
+ # 下载私有仓库中的 model.py 并放入 FineTune 目录
23
+ RUN mkdir -p /app/src/FineTune && \
24
+ echo "Downloading private FineTune/model.py from Hugging Face..." && \
25
+ curl -H "Authorization: Bearer ${HF_TOKEN}" \
26
+ -L "https://huggingface.co/mamba413/ada-core/resolve/main/model.py" \
27
+ -o /app/src/FineTune/model.py && \
28
+ echo "✅ Successfully downloaded private model.py."
29
 
30
  EXPOSE 8501
31
 
src/FineTune/model.py DELETED
@@ -1,306 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from peft import get_peft_model, LoraConfig, TaskType, AutoPeftModelForCausalLM
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
- import time
6
- import json
7
-
8
- import os
9
-
10
- def from_pretrained(cls, model_name, kwargs, cache_dir):
11
- # use local model if it exists
12
- if "/" in model_name:
13
- local_path = os.path.join(cache_dir, model_name.split("/")[1])
14
- else:
15
- local_path = os.path.join(cache_dir, model_name)
16
-
17
- if os.path.exists(local_path):
18
- return cls.from_pretrained(local_path, **kwargs)
19
- return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir, device_map='auto')
20
-
21
- model_fullnames = {
22
- 'gemma-1b': 'google/gemma-3-1b-pt',
23
- }
24
- float16_models = []
25
-
26
- def get_model_fullname(model_name):
27
- return model_fullnames[model_name] if model_name in model_fullnames else model_name
28
-
29
- def load_tokenizer(model_name, for_dataset, cache_dir):
30
- model_fullname = get_model_fullname(model_name)
31
- optional_tok_kwargs = {}
32
- if "facebook/opt-" in model_fullname:
33
- print("Using non-fast tokenizer for OPT")
34
- optional_tok_kwargs['fast'] = False
35
- if for_dataset in ['pubmed']:
36
- optional_tok_kwargs['padding_side'] = 'left'
37
- else:
38
- optional_tok_kwargs['padding_side'] = 'right'
39
- base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
40
- if base_tokenizer.pad_token_id is None:
41
- base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
42
- if '13b' in model_fullname:
43
- base_tokenizer.pad_token_id = 0
44
- return base_tokenizer
45
-
46
- def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
47
- if logits_ref.size(-1) != logits_score.size(-1):
48
- vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
49
- logits_ref = logits_ref[:, :, :vocab_size]
50
- logits_score = logits_score[:, :, :vocab_size]
51
-
52
- labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels
53
- lprobs_score = torch.log_softmax(logits_score, dim=-1)
54
- probs_ref = torch.softmax(logits_ref, dim=-1)
55
-
56
- log_likelihood = lprobs_score.gather(dim=-1, index=labels).squeeze(-1)
57
- mean_ref = (probs_ref * lprobs_score).sum(dim=-1)
58
- var_ref = (probs_ref * torch.square(lprobs_score)).sum(dim=-1) - torch.square(mean_ref)
59
- discrepancy = (log_likelihood.sum(dim=-1) - mean_ref.sum(dim=-1)) / var_ref.sum(dim=-1).clamp_min(0.0001).sqrt()
60
-
61
- return discrepancy, log_likelihood.sum(dim=-1)
62
-
63
- class ComputeScore(nn.Module):
64
- def __init__(self, scoring_model_name, reference_model_name, dataset='xsum', device='cuda', cache_dir='./models'):
65
- super().__init__()
66
- self.device = device
67
- self.reference_model_name = get_model_fullname(reference_model_name)
68
- self.scoring_model_name = get_model_fullname(scoring_model_name)
69
-
70
- def load_model(model_name, device, cache_dir):
71
- model_fullname = get_model_fullname(model_name)
72
- print(f'Loading model {model_fullname}...')
73
- model_kwargs = {}
74
- if model_name in float16_models:
75
- model_kwargs.update(dict(torch_dtype=torch.float16))
76
- if torch.__version__ >= '2.0.0' and 'gemma' in model_name:
77
- model_kwargs.update({'attn_implementation': 'sdpa'})
78
- model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
79
- print('Moving model to GPU...', end='', flush=True)
80
- start = time.time()
81
- model.to(device)
82
- print(f'DONE ({time.time() - start:.2f}s)')
83
- return model
84
-
85
- # load scoring model
86
- self.scoring_tokenizer = load_tokenizer(scoring_model_name, dataset, cache_dir)
87
- scoring_model = load_model(scoring_model_name, device, cache_dir)
88
- if scoring_model_name in ['gemma-1b']:
89
- self.peft_config = LoraConfig(
90
- task_type=TaskType.CAUSAL_LM,
91
- inference_mode=False,
92
- r=8,
93
- lora_alpha=32,
94
- lora_dropout=0.1,
95
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
96
- )
97
- else:
98
- self.peft_config = LoraConfig(
99
- task_type=TaskType.CAUSAL_LM,
100
- inference_mode=False,
101
- r=8,
102
- lora_alpha=32,
103
- lora_dropout=0.1,
104
- )
105
- self.scoring_model = get_peft_model(scoring_model, self.peft_config)
106
-
107
- # load sampling model
108
- self.reference_tokenizer = load_tokenizer(reference_model_name, dataset, cache_dir)
109
- reference_model = load_model(reference_model_name, device, cache_dir)
110
- self.reference_model = reference_model
111
- self.reference_model.eval()
112
- for p in self.reference_model.parameters():
113
- p.requires_grad = False
114
-
115
- total = sum(p.numel() for p in self.scoring_model.parameters())
116
- trainable = sum(p.numel() for p in self.scoring_model.parameters() if p.requires_grad)
117
- print(f"Trainable / total (parameters): {trainable}/{total}={trainable/total}")
118
-
119
- def set_criterion_fn(self, criterion_fn):
120
- if criterion_fn == "mean":
121
- self.criterion = 'mean'
122
- self.criterion_fn = get_sampling_discrepancy_analytic
123
- else:
124
- raise ValueError(f"Unknown criterion function: {criterion_fn}")
125
-
126
- def print_gradient_requirement(self):
127
- for name, param in self.named_parameters():
128
- gradient_requirement = 'Requires Grad' if param.requires_grad else 'Does not require grad'
129
- color_code = '\033[92m' if param.requires_grad else '\033[91m' # Green for requires grad, red for does not require grad
130
- reset_color = '\033[0m' # Reset color after printing
131
- print(f"{name}: {color_code}{gradient_requirement}{reset_color}")
132
-
133
- def register_no_grad(self, module_names):
134
- for name, param in self.named_parameters():
135
- for selected_module in module_names:
136
- # print(selected_module, name)
137
- if selected_module in name:
138
- param.requires_grad = False
139
-
140
- def save_pretrained(self, save_directory: str):
141
- """
142
- Save the scoring model (with LoRA adapter) and all null_distr buffers in Hugging Face format.
143
- """
144
- os.makedirs(save_directory, exist_ok=True)
145
-
146
- # 1. 保存 scoring_model (LoRA adapter + 基础模型)
147
- scoring_dir = os.path.join(save_directory, "scoring_model")
148
- self.scoring_model.save_pretrained(scoring_dir, safe_serialization=True)
149
- self.scoring_tokenizer.save_pretrained(scoring_dir)
150
-
151
- # 2. 保存所有 null_distr_* buffers
152
- null_distrs = {}
153
- for buffer_name, buffer_value in self.named_buffers():
154
- if buffer_name.startswith("null_distr_"):
155
- domain = buffer_name.replace("null_distr_", "")
156
- null_distrs[domain] = buffer_value.detach().cpu()
157
-
158
- if null_distrs:
159
- torch.save(null_distrs, os.path.join(save_directory, "null_distrs.pt"))
160
- print(f"✅ Saved {len(null_distrs)} null distributions: {list(null_distrs.keys())}")
161
-
162
- # 3. 保存配置信息(包括domain列表)
163
- config = {
164
- "domains": list(null_distrs.keys()),
165
- "criterion": getattr(self, "criterion", None),
166
- }
167
- with open(os.path.join(save_directory, "config.json"), "w") as f:
168
- json.dump(config, f)
169
-
170
- print(f"✅ Model saved to {save_directory}")
171
-
172
- @classmethod
173
- def from_pretrained(cls, load_directory: str, *args, **kwargs):
174
- """
175
- Load the scoring model, reference model, and all null_distr buffers.
176
- """
177
- # 1. 初始化类
178
- model = cls(*args, **kwargs)
179
-
180
- # 2. 加载 scoring_model
181
- scoring_dir = os.path.join(load_directory, "scoring_model")
182
- model.scoring_model = AutoPeftModelForCausalLM.from_pretrained(
183
- scoring_dir,
184
- device_map="auto",
185
- low_cpu_mem_usage=True,
186
- use_safetensors=True
187
- )
188
- model.scoring_tokenizer = AutoTokenizer.from_pretrained(scoring_dir)
189
-
190
- # 3. 加载所有 null_distr
191
- null_distrs_path = os.path.join(load_directory, "null_distrs.pt")
192
- if os.path.exists(null_distrs_path):
193
- null_distrs = torch.load(null_distrs_path, map_location="cpu")
194
- for domain, null_distr in null_distrs.items():
195
- model.set_null_distr(null_distr, domain)
196
- print(f"✅ Restored {len(null_distrs)} null distributions: {list(null_distrs.keys())}")
197
-
198
- # 4. 加载配置信息
199
- config_path = os.path.join(load_directory, "config.json")
200
- if os.path.exists(config_path):
201
- with open(config_path, "r") as f:
202
- config = json.load(f)
203
- if "criterion" in config and config["criterion"] is not None:
204
- model.criterion = config["criterion"]
205
- print(f"✅ Loaded config: {config}")
206
-
207
- print(f"✅ Model loaded from {load_directory}")
208
- return model
209
-
210
- def get_SPO_input(self, tokenized=None, text=[""], labels=[""], training_module=False):
211
- if training_module:
212
- logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:]
213
- if self.reference_model_name != self.scoring_model_name:
214
- tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False, add_special_tokens=True, return_attention_mask=True).to(self.device)
215
- assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
216
- logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
217
- crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
218
- else:
219
- with torch.no_grad(): # get reference
220
- logits_score = self.scoring_model(tokenized.input_ids, attention_mask=tokenized.attention_mask).logits[:,:-1,:] # shape: [bsz, sentence_len, dim]
221
- if self.reference_model_name != self.scoring_model_name:
222
- tokenized = self.reference_tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False ,add_special_tokens=True, return_attention_mask=True).to(self.device)
223
- assert torch.all(tokenized.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
224
- logits_ref = self.reference_model(tokenized.input_ids).logits[:,:-1,:]
225
- crit, SPO_input = self.criterion_fn(logits_ref, logits_score, labels)
226
- return crit, SPO_input, logits_score
227
-
228
- def forward(self, text, training_module=True):
229
- original_text = text[0]
230
- sampled_text = text[1]
231
-
232
- tokenized = self.scoring_tokenizer(original_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
233
- labels = tokenized.input_ids[:, 1:]
234
- train_original_crit, _, _ = self.get_SPO_input(tokenized, original_text, labels,training_module=training_module)
235
-
236
- tokenized = self.scoring_tokenizer(sampled_text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)
237
- labels = tokenized.input_ids[:, 1:]
238
- train_sampled_crit, _, _ = self.get_SPO_input(tokenized, sampled_text, labels,training_module=training_module)
239
-
240
- output = dict(crit=[train_original_crit.detach(), train_original_crit, train_sampled_crit.detach(), train_sampled_crit])
241
- return output
242
-
243
- def set_null_distr(self, null_distr: torch.Tensor, domain: str):
244
- """
245
- Set the null distribution tensor safely.
246
- """
247
- distr_name = f"null_distr_{domain}"
248
- self.register_buffer(distr_name, torch.empty(0))
249
-
250
- if not isinstance(null_distr, torch.Tensor):
251
- null_distr = torch.tensor(null_distr)
252
-
253
- # detach + clone + 移到正确设备
254
- null_distr = null_distr.detach().clone().to(self.device)
255
-
256
- # 直接覆盖 buffer,避免 delattr 带来的问题
257
- self._buffers[distr_name] = null_distr
258
- print(f"✅ Null distribution on {domain} with shape: {self._buffers[distr_name].shape}")
259
-
260
- def compute_p_value(self, text, domain: str):
261
- """
262
- Compute p-value for given text using the null distribution of specified domain.
263
-
264
- Args:
265
- text: Input text to compute score for
266
- domain: Domain name to use for null distribution
267
- """
268
- tokenized = self.scoring_tokenizer(
269
- text,
270
- return_tensors="pt",
271
- padding=True,
272
- return_token_type_ids=False
273
- ).to(self.device)
274
- labels = tokenized.input_ids[:, 1:]
275
-
276
- with torch.no_grad():
277
- crit, _, _ = self.get_SPO_input(tokenized, text, labels, training_module=False)
278
-
279
- # 获取对应domain的null distribution
280
- distr_name = f"null_distr_{domain}"
281
- if not hasattr(self, distr_name):
282
- raise ValueError(
283
- f"No null distribution found for domain '{domain}'. "
284
- f"Available domains: {self.get_available_domains()}"
285
- )
286
-
287
- null_distr = getattr(self, distr_name)
288
-
289
- # Compute p-value: (count + 1) / (total + 1)
290
- total = null_distr.numel()
291
- count = (null_distr >= crit.unsqueeze(-1)).float().sum() # slow computation
292
- # count = total - torch.searchsorted(null_distr, crit, right=False)
293
- p_value = (count + 1) / (total + 1)
294
-
295
- return crit, p_value
296
-
297
- def get_available_domains(self):
298
- """
299
- Get list of all available domains with null distributions.
300
- """
301
- domains = []
302
- for buffer_name in self._buffers.keys():
303
- if buffer_name.startswith("null_distr_"):
304
- domain = buffer_name.replace("null_distr_", "")
305
- domains.append(domain)
306
- return domains