Gavin-Wang committed on
Commit b1b2e62 · verified · 1 Parent(s): 8ff3654
abstract_model.py ADDED
@@ -0,0 +1,296 @@
+ #!/usr/bin/env python3
+ """
+ Abstract Model - Robust Inference with Forbidden Token Masking (Fixed Dimensions)
+ """
+
+ import json
+ import importlib
+ import inspect
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ class AbstractModel(nn.Module):
+     def __init__(self, sft_model_path, device=None):
+         super().__init__()
+         self.sft_model_path = sft_model_path
+
+         if device is None:
+             self._target_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+         else:
+             self._target_device = device
+
+         print(f"Initializing AbstractModel on target device: {self._target_device}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path, trust_remote_code=True)
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         print(f"Loading SFT model from {sft_model_path}...")
+         sft_model = AutoModelForCausalLM.from_pretrained(
+             sft_model_path,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             attn_implementation="sdpa",
+         )
+         sft_model = sft_model.to(self._target_device)
+         sft_model.eval()
+
+         self.model_backbone = sft_model.model
+         self.lm_head = sft_model.lm_head
+         self.embed_layer = sft_model.get_input_embeddings()
+         self.config = sft_model.config
+
+         self.hidden_size = sft_model.config.hidden_size
+         self.vocab_size = sft_model.config.vocab_size
+
+         # Trainable heads for the continuous ("abstract") mode.
+         self.continuous_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
+         self.continuous_embed_layer = nn.Embedding(self.vocab_size, self.hidden_size)
+
+         self.continuous_head = self.continuous_head.to(self._target_device).to(torch.bfloat16)
+         self.continuous_embed_layer = self.continuous_embed_layer.to(self._target_device).to(torch.bfloat16)
+
+         self.think_id = self.tokenizer.encode("<think>", add_special_tokens=False)[0]
+         self.end_think_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
+
+         # Structural tokens that must never be emitted while in abstract mode.
+         forbidden_strings = [
+             "<|end_of_text|>", "<|start_of_role|>", "<|end_of_role|>",
+             "<|eot_id|>", "<|start_header_id|>", "user", "assistant", "system",
+             "<tool_call>", "<tool_response>"
+         ]
+
+         self.banned_ids = []
+         if self.tokenizer.eos_token_id is not None:
+             self.banned_ids.append(self.tokenizer.eos_token_id)
+
+         for s in forbidden_strings:
+             ids = self.tokenizer.encode(s, add_special_tokens=False)
+             if ids:
+                 self.banned_ids.extend(ids)
+
+         self.banned_ids = sorted(set(self.banned_ids))
+         print(f"Banned {len(self.banned_ids)} structural tokens from Abstract Mode.")
+
+     @property
+     def device(self):
+         return self.embed_layer.weight.device
+
+     def _init_cache(self, batch_size, max_length):
+         # Prefer the backbone's own hybrid cache class (e.g. Mamba/attention
+         # hybrids) when it exists; otherwise fall back to a plain DynamicCache.
+         try:
+             module = importlib.import_module(self.model_backbone.__module__)
+             if hasattr(module, "HybridMambaAttentionDynamicCache"):
+                 CacheClass = getattr(module, "HybridMambaAttentionDynamicCache")
+                 sig = inspect.signature(CacheClass.__init__)
+                 kwargs = {}
+                 if 'config' in sig.parameters: kwargs['config'] = self.config
+                 if 'batch_size' in sig.parameters: kwargs['batch_size'] = batch_size
+                 elif 'max_batch_size' in sig.parameters: kwargs['max_batch_size'] = batch_size
+                 if 'max_cache_len' in sig.parameters: kwargs['max_cache_len'] = max_length
+                 elif 'max_length' in sig.parameters: kwargs['max_length'] = max_length
+                 if 'device' in sig.parameters: kwargs['device'] = self.device
+                 if 'dtype' in sig.parameters: kwargs['dtype'] = self.embed_layer.weight.dtype
+                 return CacheClass(**kwargs)
+         except Exception:
+             pass
+         from transformers import DynamicCache
+         cache = DynamicCache()
+         cache.has_previous_state = False
+         return cache
+
+     def forward(
+         self,
+         input_ids,
+         max_length=512,
+         temperature=0.7,
+         sample=False,
+         no_grad=True,
+         sigma=0.0,
+         max_thinking_steps=64
+     ):
+         if input_ids.device != self.device:
+             input_ids = input_ids.to(self.device)
+
+         if no_grad:
+             with torch.no_grad():
+                 initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
+         else:
+             initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
+
+         in_abstract_mode = True
+         abstract_step_count = 0
+         generated_tokens = []
+         all_logits = []
+         mode_sequence = []
+
+         past_key_values = self._init_cache(batch_size=1, max_length=max_length + input_ids.shape[0] + 16)
+
+         current_step_input = initial_embeddings.unsqueeze(0)
+         current_seq_len = initial_embeddings.shape[0]
+
+         context = torch.no_grad() if no_grad else torch.enable_grad()
+
+         with context:
+             for step in range(max_length):
+
+                 if step == 0:
+                     position_ids = torch.arange(0, current_seq_len, dtype=torch.long, device=self.device).unsqueeze(0)
+                 else:
+                     position_ids = torch.tensor([[current_seq_len - 1]], dtype=torch.long, device=self.device)
+
+                 outputs = self.model_backbone(
+                     inputs_embeds=current_step_input,
+                     position_ids=position_ids,
+                     past_key_values=past_key_values,
+                     use_cache=True
+                 )
+
+                 past_key_values = outputs.past_key_values
+                 last_hidden = outputs.last_hidden_state[0, -1, :]
+
+                 # 1. Natural head (used for the stopping condition)
+                 logits = self.lm_head(last_hidden)
+                 stop_probs = F.softmax(logits.float(), dim=-1)
+                 natural_next_token = torch.argmax(stop_probs, dim=-1).item()
+
+                 # Force-stop once the thinking budget is exhausted
+                 force_stop = False
+                 if in_abstract_mode:
+                     abstract_step_count += 1
+                     if abstract_step_count >= max_thinking_steps:
+                         natural_next_token = self.end_think_id
+                         force_stop = True
+
+                 # 2. Logic flow
+                 if (natural_next_token == self.end_think_id or force_stop) and in_abstract_mode:
+                     # Transition to natural mode
+                     in_abstract_mode = False
+                     mode_sequence.append('T')
+                     generated_tokens.append(self.end_think_id)
+                     next_embedding = self.embed_layer(torch.tensor([[self.end_think_id]], device=self.device)).squeeze(0).squeeze(0)
+
+                 elif in_abstract_mode:
+                     # Abstract generation: feed back a probability-weighted
+                     # mixture of top-k embeddings instead of a hard token.
+                     mode_sequence.append('A')
+                     cont_logits = self.continuous_head(last_hidden)
+
+                     if self.banned_ids:
+                         cont_logits[self.banned_ids] = float('-inf')
+
+                     cont_logits_f32 = cont_logits.float() / (temperature if temperature else 1.0)
+
+                     abstract_vis_token = torch.argmax(cont_logits_f32, dim=-1).item()
+                     generated_tokens.append(abstract_vis_token)
+
+                     top_k = min(256, self.vocab_size // 4)
+                     top_logits, top_indices = torch.topk(cont_logits_f32, top_k, dim=-1)
+                     top_probs = F.softmax(top_logits, dim=-1).to(torch.bfloat16)
+                     top_embeddings = self.continuous_embed_layer(top_indices)
+                     next_embedding = top_probs @ top_embeddings
+
+                     if sigma > 0.0 and not no_grad:
+                         next_embedding = next_embedding + (torch.randn_like(next_embedding) * sigma)
+                 else:
+                     # Natural generation
+                     mode_sequence.append('N')
+                     generated_tokens.append(natural_next_token)
+                     next_embedding = self.embed_layer(torch.tensor([[natural_next_token]], device=self.device)).squeeze(0).squeeze(0)
+
+                 if no_grad:
+                     all_logits.append(logits.detach().cpu())
+
+                 if natural_next_token == self.tokenizer.eos_token_id and not in_abstract_mode:
+                     break
+
+                 current_step_input = next_embedding.unsqueeze(0).unsqueeze(0)
+                 current_seq_len += 1
+
+         return {
+             'generated_tokens': torch.tensor(generated_tokens),
+             'logits': torch.stack(all_logits) if all_logits else torch.tensor([]),
+             'mode_sequence': mode_sequence,
+         }
+
+     def save_to_directory(self, output_dir):
+         output_path = Path(output_dir)
+         output_path.mkdir(parents=True, exist_ok=True)
+         try:
+             head_state = {k: v.cpu() for k, v in self.continuous_head.state_dict().items()}
+             embed_state = {k: v.cpu() for k, v in self.continuous_embed_layer.state_dict().items()}
+             torch.save(head_state, output_path / "continuous_head.pt")
+             torch.save(embed_state, output_path / "continuous_embed.pt")
+             config = {'sft_model_path': str(self.sft_model_path), 'hidden_size': self.hidden_size, 'vocab_size': self.vocab_size}
+             with open(output_path / "config.json", 'w') as f:
+                 json.dump(config, f)
+             print(f"Saved model to {output_dir}")
+         except Exception as e:
+             print(f"Error saving model: {e}")
+
+     @staticmethod
+     def load_from_directory(output_dir, sft_model_path=None, device='cuda:0'):
+         output_path = Path(output_dir)
+         with open(output_path / "config.json", 'r') as f:
+             config = json.load(f)
+         if sft_model_path is None:
+             sft_model_path = config['sft_model_path']
+         model = AbstractModel(sft_model_path, device=device)
+         print(f"Loading checkpoint to {model.device}...")
+         head_state = torch.load(output_path / "continuous_head.pt", map_location=model.device)
+         embed_state = torch.load(output_path / "continuous_embed.pt", map_location=model.device)
+         model.continuous_head.load_state_dict(head_state)
+         model.continuous_embed_layer.load_state_dict(embed_state)
+         model.continuous_head = model.continuous_head.to(torch.bfloat16)
+         model.continuous_embed_layer = model.continuous_embed_layer.to(torch.bfloat16)
+         return model
+
+
+ if __name__ == '__main__':
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--sft-model', required=True)
+     parser.add_argument('--load-model', default=None)
+     parser.add_argument('--max-length', type=int, default=256)
+     parser.add_argument('--temperature', type=float, default=0.7)
+     args = parser.parse_args()
+
+     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+     if args.load_model is not None:
+         model = AbstractModel.load_from_directory(args.load_model, sft_model_path=args.sft_model, device=device)
+     else:
+         # No checkpoint given: fall back to a freshly initialized model.
+         model = AbstractModel(args.sft_model, device=device)
+
+     print("\n" + "=" * 70)
+     print("Abstract Model - Interactive Generation (Masked & Budgeted)")
+     print("=" * 70 + "\n")
+
+     while True:
+         try:
+             prompt = input("You: ").strip()
+             if not prompt:
+                 continue
+             if prompt.lower() in ['q', 'quit']:
+                 break
+
+             sys_prompt = "You are a reasoning assistant. Think step by step before answering."
+             messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}]
+
+             formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(model.device).squeeze(0)
+
+             print("Generating...", end="\r")
+
+             result = model.forward(
+                 input_ids,
+                 max_length=args.max_length,
+                 temperature=args.temperature,
+                 sample=False,
+                 no_grad=True,
+                 sigma=0.0,
+                 max_thinking_steps=128
+             )
+
+             generated_ids = result['generated_tokens'].tolist()
+             modes = result['mode_sequence']
+
+             print("Assistant: ", end="")
+             for token_id, mode in zip(generated_ids, modes):
+                 token_text = model.tokenizer.decode([token_id])
+                 if mode == 'A':
+                     # Abstract-mode tokens are highlighted in cyan.
+                     print(f"\033[96m{token_text}\033[0m", end="", flush=True)
+                 else:
+                     print(token_text, end="", flush=True)
+             print("\n")
+             print(f"[Stats] Abstract: {modes.count('A')} | Natural: {modes.count('N')}")
+             print("-" * 70)
+
+         except KeyboardInterrupt:
+             break
+         except Exception as e:
+             print(f"\nError: {e}")
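The heart of abstract mode is the feedback step in `forward`: instead of committing to a hard token, the model feeds back a probability-weighted mixture of the top-k continuous embeddings. A minimal standalone sketch of that step, with toy sizes standing in for the real `hidden_size`/`vocab_size` and toy modules standing in for `continuous_head`/`continuous_embed_layer`:

import torch
import torch.nn.functional as F

# Toy stand-ins; the real model uses bfloat16 and a banned-id list built
# from structural tokens.
vocab_size, hidden_size, top_k = 1000, 64, 8
head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
embed = torch.nn.Embedding(vocab_size, hidden_size)
last_hidden = torch.randn(hidden_size)

logits = head(last_hidden)
logits[[0, 1, 2]] = float('-inf')                # mask banned structural ids
top_logits, top_indices = torch.topk(logits, top_k)
top_probs = F.softmax(top_logits, dim=-1)
next_embedding = top_probs @ embed(top_indices)  # soft token, shape (hidden_size,)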
create_initialized_abstract.py ADDED
@@ -0,0 +1,28 @@
+ #!/usr/bin/env python3
+ """
+ Create initialized Abstract model checkpoint.
+ """
+
+ import argparse
+ import os
+
+ import torch
+
+ from abstract_model import AbstractModel
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--sft-model', required=True, help='Path to SFT model')
+     parser.add_argument('--output', required=True, help='Output directory for initialized model')
+     args = parser.parse_args()
+
+     print(f"Loading SFT model from: {args.sft_model}")
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     model = AbstractModel(args.sft_model, device=device)
+
+     print(f"Saving initialized model to: {args.output}")
+     os.makedirs(args.output, exist_ok=True)
+     model.save_to_directory(args.output)
+
+
+ if __name__ == "__main__":
+     main()
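A quick round-trip check of what this script produces (a sketch with hypothetical paths; per `save_to_directory`, the output directory contains continuous_head.pt, continuous_embed.pt, and config.json):

from abstract_model import AbstractModel

model = AbstractModel("path/to/sft-model", device="cpu")   # hypothetical path
model.save_to_directory("out/abstract-init")
reloaded = AbstractModel.load_from_directory("out/abstract-init", device="cpu")
assert reloaded.hidden_size == model.hidden_size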
eval_simple.py ADDED
@@ -0,0 +1,206 @@
+ import json
+ import os
+ import random
+ import re
+ import string
+ import time
+
+ import torch
+ import torch.multiprocessing as mp
+ from tqdm import tqdm
+
+ from abstract_model import AbstractModel
+
+
+ # Placeholder paths: point these at the trained continuous-head checkpoint
+ # and the base SFT model before running.
+ RL_MODEL_PATH = "pathtocontinuoushead"
+ FALLBACK_SFT_PATH = "pathtobasemodel"
+
+ DATASET_FILES = [
+     "../bench/mmlu.jsonl",
+     "../bench/gsm8k.jsonl",
+     "../bench/drop.jsonl"
+ ]
+
+ SAMPLES_PER_BENCHMARK = 1024
+ MAX_THINKING_STEPS = 256
+ MAX_TOTAL_LENGTH = 1536
+ LOG_FILE = "eval_results_random.jsonl"
+
+
+ def normalize_text(s):
+     if s is None:
+         return ""
+     def remove_articles(text): return re.sub(r'\b(a|an|the)\b', ' ', text)
+     def white_space_fix(text): return ' '.join(text.split())
+     def remove_punc(text): return ''.join(ch for ch in text if ch not in set(string.punctuation))
+     return white_space_fix(remove_articles(remove_punc(str(s).lower())))
+
+
+ def extract_answer_content(text):
+     match = re.search(r"<ANSWER>(.*?)</ANSWER>", text, re.DOTALL)
+     if match:
+         return match.group(1).strip()
+     return None
+
+
+ def load_and_sample_data(files, samples_per_file):
+     """Load full datasets and randomly sample N items from each."""
+     final_data = []
+
+     for filename in files:
+         if not os.path.exists(filename):
+             print(f"Warning: File {filename} not found. Skipping.")
+             continue
+
+         # Detect benchmark type from the filename.
+         fname_lower = filename.lower()
+         if "mmlu" in fname_lower: bench_type = "mmlu"
+         elif "gsm8k" in fname_lower: bench_type = "gsm8k"
+         elif "drop" in fname_lower: bench_type = "drop"
+         else: bench_type = "unknown"
+
+         print(f"Loading {filename} ({bench_type})...")
+
+         file_data = []
+         with open(filename, 'r', encoding='utf-8') as f:
+             for line in f:
+                 try:
+                     entry = json.loads(line)
+                     if "benchmark" not in entry:
+                         entry["benchmark"] = bench_type
+                     file_data.append(entry)
+                 except json.JSONDecodeError:
+                     continue
+
+         total_lines = len(file_data)
+
+         if total_lines > samples_per_file:
+             random.shuffle(file_data)
+             selected_data = file_data[:samples_per_file]
+             print(f"  -> Randomly sampled {samples_per_file} from {total_lines} samples.")
+         else:
+             selected_data = file_data
+             print(f"  -> Took all {total_lines} samples (fewer than the requested limit).")
+
+         final_data.extend(selected_data)
+
+     return final_data
+
+
+ def score_sample(pred, truth, benchmark):
+     if benchmark == 'mmlu':
+         p = extract_answer_content(pred)
+         if not p: return False
+         m = re.search(r'([A-D])', p.upper())
+         return m.group(1) == truth.strip().upper() if m else False
+     elif benchmark == 'gsm8k':
+         p = extract_answer_content(pred)
+         if not p: return False
+         t = truth.split("####")[-1].strip() if "####" in truth else truth.strip()
+         return normalize_text(t) in normalize_text(p)
+     else:
+         p = extract_answer_content(pred)
+         if not p: return False
+         return normalize_text(p) == normalize_text(truth)
+
+
+ def gpu_worker(gpu_id, head_path, sft_path, dataset_chunk, results_queue):
+     torch.cuda.set_device(gpu_id)
+     device = f"cuda:{gpu_id}"
+
+     if not os.path.exists(os.path.join(head_path, "continuous_head.pt")):
+         print(f"[GPU {gpu_id}] Critical: continuous_head.pt not found in {head_path}")
+         return
+
+     print(f"[GPU {gpu_id}] Loading Model...")
+     try:
+         model = AbstractModel.load_from_directory(
+             head_path,
+             sft_model_path=sft_path,
+             device=device
+         )
+     except Exception as e:
+         print(f"[GPU {gpu_id}] Error loading model: {e}")
+         return
+
+     results = []
+     iterator = tqdm(dataset_chunk, desc=f"GPU {gpu_id}", position=gpu_id, leave=True)
+
+     for item in iterator:
+         try:
+             sys_prompt = "You are a reasoning assistant. Think step by step before answering."
+             messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": item['question']}]
+
+             formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device).squeeze(0)
+
+             out = model.forward(
+                 input_ids,
+                 max_length=MAX_TOTAL_LENGTH,
+                 temperature=0.0,
+                 sample=False,
+                 no_grad=True,
+                 sigma=0.0,
+                 max_thinking_steps=MAX_THINKING_STEPS
+             )
+
+             full_text = "".join(model.tokenizer.decode([t]) for t in out['generated_tokens'].tolist())
+
+             is_correct = score_sample(full_text, item['answer'], item['benchmark'])
+
+             results.append({
+                 "benchmark": item['benchmark'],
+                 "correct": is_correct,
+                 "think_steps": out['mode_sequence'].count('A'),
+                 "prediction": full_text
+             })
+         except Exception as e:
+             print(f"[GPU {gpu_id}] Error: {e}")
+             continue
+
+     results_queue.put(results)
+
+
+ def run_evaluation():
+     all_data = load_and_sample_data(DATASET_FILES, SAMPLES_PER_BENCHMARK)
+
+     if not all_data:
+         print("No data loaded. Exiting.")
+         return
+
+     print(f"Total Evaluation Set: {len(all_data)} samples.")
+
+     # Split the workload across two GPUs.
+     mid = len(all_data) // 2
+     queue = mp.Queue()
+
+     p1 = mp.Process(target=gpu_worker, args=(0, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[:mid], queue))
+     p2 = mp.Process(target=gpu_worker, args=(1, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[mid:], queue))
+
+     start_time = time.time()
+     p1.start(); p2.start()
+
+     # Drain the queue before joining to avoid deadlocking on large payloads.
+     final_results = []
+     for _ in range(2):
+         final_results.extend(queue.get())
+     p1.join(); p2.join()
+
+     print(f"Saving detailed logs to {LOG_FILE}...")
+     with open(LOG_FILE, 'w') as f:
+         for r in final_results:
+             f.write(json.dumps(r) + '\n')
+
+     metrics = {}
+     for res in final_results:
+         b = res['benchmark']
+         if b not in metrics:
+             metrics[b] = {'correct': [], 'steps': []}
+         metrics[b]['correct'].append(res['correct'])
+         metrics[b]['steps'].append(res['think_steps'])
+
+     print("\n" + "=" * 50)
+     print(f"FINAL SCORES (Random Sample N={SAMPLES_PER_BENCHMARK})")
+     print("=" * 50)
+
+     for b, d in metrics.items():
+         acc = sum(d['correct']) / len(d['correct']) * 100
+         avg_steps = sum(d['steps']) / len(d['steps'])
+         print(f"{b.upper():<10} | Acc: {acc:.2f}% | Avg Steps: {avg_steps:.1f} | N: {len(d['correct'])}")
+
+     print(f"Total time: {time.time() - start_time:.2f}s")
+
+
+ if __name__ == "__main__":
+     mp.set_start_method('spawn', force=True)
+     run_evaluation()
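The harness assumes each JSONL line carries `question` and `answer` fields (`benchmark` is inferred from the filename when absent), and that the model wraps its final answer in <ANSWER> tags. A sketch of the expected shapes, with illustrative values only:

from eval_simple import score_sample

# Hypothetical gsm8k-style record; the "#### 4" suffix is the gold answer.
item = {"question": "2 + 2 = ?", "answer": "Add the numbers. #### 4", "benchmark": "gsm8k"}
pred = "<think>...</think> The result is <ANSWER>4</ANSWER>"
assert score_sample(pred, item["answer"], item["benchmark"])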
test_soft_embedding_with_trigger.py ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/env python3
+ """
+ Test soft embedding with trigger-based mode switching.
+ """
+
+ import argparse
+ from pathlib import Path
+
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ class TriggerHead(torch.nn.Module):
+     """SwiGLU-style MLP mapping a hidden state to a scalar trigger logit."""
+
+     def __init__(self, hidden_size, hidden_dim=1024):
+         super().__init__()
+         self.w_gate = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
+         self.w_value = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
+         self.w_out = torch.nn.Linear(hidden_dim, 1, bias=True)
+
+     def forward(self, x):
+         gate = self.w_gate(x)
+         value = self.w_value(x)
+         activated = F.silu(gate) * value
+         x = self.w_out(activated)
+         return x.squeeze(-1)
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Test Soft Embedding with Trigger")
+     parser.add_argument('--sft-model', required=True, help='Path to SFT model')
+     parser.add_argument('--trigger-head', required=True, help='Path to trigger head checkpoint dir')
+     parser.add_argument('--max-length', type=int, default=256, help='Max generation length')
+     parser.add_argument('--threshold', type=float, default=0.5, help='Trigger threshold (>threshold = abstract mode)')
+     parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for softmax')
+     args = parser.parse_args()
+
+     print("=" * 70)
+     print("Testing Soft Embedding with Trigger-Based Mode Switching")
+     print("=" * 70)
+
+     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
+     print(f"\nLoading tokenizer from {args.sft_model}...")
+     tokenizer = AutoTokenizer.from_pretrained(args.sft_model, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     print(f"Loading SFT model from {args.sft_model}...")
+     model = AutoModelForCausalLM.from_pretrained(
+         args.sft_model,
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True,
+         device_map=None
+     ).to(device)
+     model.eval()
+
+     hidden_size = model.config.hidden_size
+     embed_layer = model.get_input_embeddings()
+
+     print(f"Loading trigger head from {args.trigger_head}...")
+     trigger_head = TriggerHead(hidden_size).to(device)
+     checkpoint_path = Path(args.trigger_head) / "trigger_head.pt"
+
+     if not checkpoint_path.exists():
+         print(f"Error: Checkpoint not found at {checkpoint_path}")
+         return
+
+     trigger_state = torch.load(checkpoint_path, map_location=device)
+     trigger_head.load_state_dict(trigger_state)
+     trigger_head.eval()
+
+     print("Models loaded.\n")
+
+     mode_stats = {'natural': 0, 'abstract': 0}
+
+     while True:
+         prompt = input("You: ").strip()
+         if prompt.lower() in ['quit', 'exit', 'q']:
+             break
+         if not prompt:
+             continue
+
+         messages = [{"role": "user", "content": prompt}]
+         formatted = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         input_ids = tokenizer(
+             formatted,
+             return_tensors='pt',
+             add_special_tokens=False
+         )['input_ids'].to(device)
+
+         print("Assistant: ", end="", flush=True)
+
+         generated_tokens = []
+         mode_sequence = []
+
+         with torch.no_grad():
+             current_embeddings = embed_layer(input_ids).squeeze(0)
+
+             # No KV cache here: the full prefix is re-encoded each step, which
+             # is O(n^2) but keeps the soft-embedding test simple.
+             while len(generated_tokens) + len(input_ids[0]) < args.max_length:
+                 outputs = model.model(
+                     inputs_embeds=current_embeddings.unsqueeze(0),
+                     use_cache=False
+                 )
+                 hidden_state = outputs.last_hidden_state[0, -1]
+
+                 hidden_state_normalized = F.normalize(hidden_state.float(), p=2, dim=-1)
+
+                 trigger_logits = trigger_head(hidden_state_normalized.unsqueeze(0))
+                 trigger_prob = torch.sigmoid(trigger_logits).item()
+                 next_mode = 'S' if trigger_prob > args.threshold else 'N'
+
+                 logits = model.lm_head(hidden_state)
+                 if args.temperature > 0:
+                     logits = logits / args.temperature
+                 probs = F.softmax(logits, dim=-1)
+
+                 if next_mode == 'S':
+                     # Soft embedding: feed back the probability-weighted
+                     # mixture over the full embedding matrix.
+                     mode_sequence.append('S')
+                     embed_matrix = embed_layer.weight.float()
+                     next_embedding = probs.float() @ embed_matrix
+                     next_embedding = next_embedding.to(torch.bfloat16)
+                     next_token = torch.argmax(probs).item()
+                     token_text = tokenizer.decode([next_token])
+                     print(f"<abstract>{token_text}", end="", flush=True)
+                 else:
+                     mode_sequence.append('N')
+                     next_token = torch.argmax(probs).item()
+                     next_embedding = embed_layer(torch.tensor([[next_token]], device=device)).squeeze(0).squeeze(0)
+                     token_text = tokenizer.decode([next_token])
+                     print(token_text, end="", flush=True)
+
+                 if next_token == tokenizer.eos_token_id:
+                     break
+
+                 generated_tokens.append(next_token)
+                 current_embeddings = torch.cat([current_embeddings, next_embedding.unsqueeze(0)], dim=0)
+
+         print("\n")
+
+         if mode_sequence:
+             n_count = mode_sequence.count('N')
+             s_count = mode_sequence.count('S')
+             mode_stats['natural'] += n_count
+             mode_stats['abstract'] += s_count
+             print(f"[Tokens: Natural={n_count}, Switch={s_count}, switch_ratio={s_count/(n_count+s_count)*100:.1f}%]\n")
+
+     print("\n" + "=" * 70)
+     print("Session Statistics:")
+     print(f"  Natural mode tokens: {mode_stats['natural']}")
+     print(f"  Switch point tokens: {mode_stats['abstract']}")
+     if mode_stats['natural'] + mode_stats['abstract'] > 0:
+         total = mode_stats['natural'] + mode_stats['abstract']
+         print(f"  Switch ratio: {mode_stats['abstract']/total*100:.1f}%")
+
+
+ if __name__ == '__main__':
+     main()
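TriggerHead gates the mode switch with a SwiGLU block over the L2-normalized hidden state. A standalone sketch of one switching decision (the hidden size here is hypothetical; the real value comes from model.config.hidden_size):

import torch
import torch.nn.functional as F
from test_soft_embedding_with_trigger import TriggerHead

head = TriggerHead(hidden_size=4096)
hidden = F.normalize(torch.randn(1, 4096), p=2, dim=-1)  # mirrors the loop above
prob = torch.sigmoid(head(hidden)).item()
mode = 'S' if prob > 0.5 else 'N'                        # 0.5 = default --threshold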