scripts
- abstract_model.py +296 -0
- create_initialized_abstract.py +28 -0
- eval_simple.py +206 -0
- test_soft_embedding_with_trigger.py +164 -0
abstract_model.py
ADDED
@@ -0,0 +1,296 @@
#!/usr/bin/env python3
"""
Abstract Model - Robust Inference with Forbidden Token Masking (Fixed Dimensions)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import importlib
import inspect
from pathlib import Path


class AbstractModel(nn.Module):
    def __init__(self, sft_model_path, device=None):
        super().__init__()
        self.sft_model_path = sft_model_path

        if device is None:
            self._target_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self._target_device = device

        print(f"Initializing AbstractModel on target device: {self._target_device}")

        self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading SFT model from {sft_model_path}...")
        sft_model = AutoModelForCausalLM.from_pretrained(
            sft_model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="sdpa",
        )
        sft_model = sft_model.to(self._target_device)
        sft_model.eval()

        # Reuse the SFT backbone, LM head, and input embeddings directly.
        self.model_backbone = sft_model.model
        self.lm_head = sft_model.lm_head
        self.embed_layer = sft_model.get_input_embeddings()
        self.config = sft_model.config

        self.hidden_size = sft_model.config.hidden_size
        self.vocab_size = sft_model.config.vocab_size

        # Trainable head/embedding pair used only in abstract (latent) mode.
        self.continuous_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.continuous_embed_layer = nn.Embedding(self.vocab_size, self.hidden_size)

        self.continuous_head = self.continuous_head.to(self._target_device).to(torch.bfloat16)
        self.continuous_embed_layer = self.continuous_embed_layer.to(self._target_device).to(torch.bfloat16)

        self.think_id = self.tokenizer.encode("<think>", add_special_tokens=False)[0]
        self.end_think_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]

        # Structural/control tokens that must never be emitted in abstract mode.
        forbidden_strings = [
            "<|end_of_text|>", "<|start_of_role|>", "<|end_of_role|>",
            "<|eot_id|>", "<|start_header_id|>", "user", "assistant", "system",
            "<tool_call>", "<tool_response>"
        ]

        self.banned_ids = []
        if self.tokenizer.eos_token_id is not None:
            self.banned_ids.append(self.tokenizer.eos_token_id)

        for s in forbidden_strings:
            ids = self.tokenizer.encode(s, add_special_tokens=False)
            if ids:
                self.banned_ids.extend(ids)

        self.banned_ids = sorted(set(self.banned_ids))
        print(f"Banned {len(self.banned_ids)} structural tokens from Abstract Mode.")

    @property
    def device(self):
        return self.embed_layer.weight.device

    def _init_cache(self, batch_size, max_length):
        # Hybrid (Mamba/attention) backbones ship their own cache class; fall
        # back to a plain DynamicCache when it is unavailable.
        try:
            module = importlib.import_module(self.model_backbone.__module__)
            if hasattr(module, "HybridMambaAttentionDynamicCache"):
                CacheClass = getattr(module, "HybridMambaAttentionDynamicCache")
                sig = inspect.signature(CacheClass.__init__)
                kwargs = {}
                if 'config' in sig.parameters:
                    kwargs['config'] = self.config
                if 'batch_size' in sig.parameters:
                    kwargs['batch_size'] = batch_size
                elif 'max_batch_size' in sig.parameters:
                    kwargs['max_batch_size'] = batch_size
                if 'max_cache_len' in sig.parameters:
                    kwargs['max_cache_len'] = max_length
                elif 'max_length' in sig.parameters:
                    kwargs['max_length'] = max_length
                if 'device' in sig.parameters:
                    kwargs['device'] = self.device
                if 'dtype' in sig.parameters:
                    kwargs['dtype'] = self.embed_layer.weight.dtype
                return CacheClass(**kwargs)
        except Exception:
            pass
        from transformers import DynamicCache
        cache = DynamicCache()
        cache.has_previous_state = False
        return cache

    def forward(
        self,
        input_ids,
        max_length=512,
        temperature=0.7,
        sample=False,          # reserved; decoding below is always greedy
        no_grad=True,
        sigma=0.0,
        max_thinking_steps=64
    ):
        if input_ids.device != self.device:
            input_ids = input_ids.to(self.device)

        if no_grad:
            with torch.no_grad():
                initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
        else:
            initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)

        in_abstract_mode = True
        abstract_step_count = 0
        generated_tokens = []
        all_logits = []
        mode_sequence = []

        past_key_values = self._init_cache(batch_size=1, max_length=max_length + input_ids.shape[0] + 16)

        current_step_input = initial_embeddings.unsqueeze(0)
        current_seq_len = initial_embeddings.shape[0]

        context = torch.no_grad() if no_grad else torch.enable_grad()

        with context:
            for step in range(max_length):

                if step == 0:
                    position_ids = torch.arange(0, current_seq_len, dtype=torch.long, device=self.device).unsqueeze(0)
                else:
                    position_ids = torch.tensor([[current_seq_len - 1]], dtype=torch.long, device=self.device)

                outputs = self.model_backbone(
                    inputs_embeds=current_step_input,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )

                past_key_values = outputs.past_key_values
                last_hidden = outputs.last_hidden_state[0, -1, :]

                # 1. Natural head (used for the stopping condition)
                logits = self.lm_head(last_hidden)
                stop_probs = F.softmax(logits.float(), dim=-1)
                natural_next_token = torch.argmax(stop_probs, dim=-1).item()

                # Force-stop condition: cap the thinking budget
                force_stop = False
                if in_abstract_mode:
                    abstract_step_count += 1
                    if abstract_step_count >= max_thinking_steps:
                        natural_next_token = self.end_think_id
                        force_stop = True

                # 2. Logic flow
                if (natural_next_token == self.end_think_id or force_stop) and in_abstract_mode:
                    # Transition to natural mode
                    in_abstract_mode = False
                    mode_sequence.append('T')
                    generated_tokens.append(self.end_think_id)
                    next_embedding = self.embed_layer(torch.tensor([[self.end_think_id]], device=self.device)).squeeze(0).squeeze(0)

                elif in_abstract_mode:
                    # Abstract generation: feed back a probability-weighted
                    # mixture of the top-k continuous embeddings.
                    mode_sequence.append('A')
                    cont_logits = self.continuous_head(last_hidden)

                    if self.banned_ids:
                        cont_logits[self.banned_ids] = float('-inf')

                    cont_logits_f32 = cont_logits.float() / (temperature if temperature else 1.0)

                    abstract_vis_token = torch.argmax(cont_logits_f32, dim=-1).item()
                    generated_tokens.append(abstract_vis_token)

                    top_k = min(256, self.vocab_size // 4)
                    top_logits, top_indices = torch.topk(cont_logits_f32, top_k, dim=-1)
                    top_probs = F.softmax(top_logits, dim=-1).to(torch.bfloat16)
                    top_embeddings = self.continuous_embed_layer(top_indices)
                    next_embedding = top_probs @ top_embeddings

                    # Exploration noise (training only)
                    if sigma > 0.0 and not no_grad:
                        next_embedding = next_embedding + (torch.randn_like(next_embedding) * sigma)
                else:
                    # Natural generation
                    mode_sequence.append('N')
                    generated_tokens.append(natural_next_token)
                    next_embedding = self.embed_layer(torch.tensor([[natural_next_token]], device=self.device)).squeeze(0).squeeze(0)

                if no_grad:
                    all_logits.append(logits.detach().cpu())

                if natural_next_token == self.tokenizer.eos_token_id and not in_abstract_mode:
                    break

                current_step_input = next_embedding.unsqueeze(0).unsqueeze(0)
                current_seq_len += 1

        return {
            'generated_tokens': torch.tensor(generated_tokens),
            'logits': torch.stack(all_logits) if all_logits else torch.tensor([]),
            'mode_sequence': mode_sequence,
        }

    def save_to_directory(self, output_dir):
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        try:
            head_state = {k: v.cpu() for k, v in self.continuous_head.state_dict().items()}
            embed_state = {k: v.cpu() for k, v in self.continuous_embed_layer.state_dict().items()}
            torch.save(head_state, output_path / "continuous_head.pt")
            torch.save(embed_state, output_path / "continuous_embed.pt")
            config = {'sft_model_path': str(self.sft_model_path), 'hidden_size': self.hidden_size, 'vocab_size': self.vocab_size}
            with open(output_path / "config.json", 'w') as f:
                json.dump(config, f)
            print(f"Saved model to {output_dir}")
        except Exception as e:
            print(f"Error saving model: {e}")

    @staticmethod
    def load_from_directory(output_dir, sft_model_path=None, device='cuda:0'):
        output_path = Path(output_dir)
        with open(output_path / "config.json", 'r') as f:
            config = json.load(f)
        if sft_model_path is None:
            sft_model_path = config['sft_model_path']
        model = AbstractModel(sft_model_path, device=device)
        print(f"Loading checkpoint to {model.device}...")
        head_state = torch.load(output_path / "continuous_head.pt", map_location=model.device)
        embed_state = torch.load(output_path / "continuous_embed.pt", map_location=model.device)
        model.continuous_head.load_state_dict(head_state)
        model.continuous_embed_layer.load_state_dict(embed_state)
        model.continuous_head = model.continuous_head.to(torch.bfloat16)
        model.continuous_embed_layer = model.continuous_embed_layer.to(torch.bfloat16)
        return model


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--sft-model', required=True)
    # Required: load_from_directory cannot handle a missing checkpoint dir.
    parser.add_argument('--load-model', required=True)
    parser.add_argument('--max-length', type=int, default=256)
    parser.add_argument('--temperature', type=float, default=0.7)
    args = parser.parse_args()

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = AbstractModel.load_from_directory(args.load_model, sft_model_path=args.sft_model, device=device)

    print("\n" + "=" * 70)
    print("Abstract Model - Interactive Generation (Masked & Budgeted)")
    print("=" * 70 + "\n")

    while True:
        try:
            prompt = input("You: ").strip()
            if not prompt:
                continue
            if prompt.lower() in ['q', 'quit']:
                break

            sys_prompt = "You are a reasoning assistant. Think step by step before answering."
            messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}]

            formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(model.device).squeeze(0)

            print("Generating...", end="\r")

            result = model.forward(
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                sample=False,
                no_grad=True,
                sigma=0.0,
                max_thinking_steps=128
            )

            generated_ids = result['generated_tokens'].tolist()
            modes = result['mode_sequence']

            print("Assistant: ", end="")
            for token_id, mode in zip(generated_ids, modes):
                token_text = model.tokenizer.decode([token_id])
                if mode == 'A':
                    # Abstract-mode tokens are rendered in cyan
                    print(f"\033[96m{token_text}\033[0m", end="", flush=True)
                else:
                    print(token_text, end="", flush=True)
            print("\n")
            print(f"[Stats] Abstract: {modes.count('A')} | Natural: {modes.count('N')}")
            print("-" * 70)

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"\nError: {e}")
create_initialized_abstract.py
ADDED
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
"""
Create an initialized Abstract model checkpoint.
"""

import argparse
import torch
import os
from abstract_model import AbstractModel


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--output', required=True, help='Output directory for initialized model')
    args = parser.parse_args()

    print(f"Loading SFT model from: {args.sft_model}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = AbstractModel(args.sft_model, device=device)

    print(f"Saving initialized model to: {args.output}")
    os.makedirs(args.output, exist_ok=True)
    model.save_to_directory(args.output)


if __name__ == "__main__":
    main()
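
For reference, a plausible round trip pairing this script's output with `AbstractModel.load_from_directory` (both paths below are placeholders, not paths from this repo):

# Hypothetical paths; substitute your own SFT checkpoint and output dir.
from abstract_model import AbstractModel

model = AbstractModel("path/to/sft-model", device="cpu")
model.save_to_directory("checkpoints/abstract-init")

# Later: restore the continuous head/embedding on top of the same backbone.
restored = AbstractModel.load_from_directory(
    "checkpoints/abstract-init",
    sft_model_path="path/to/sft-model",
    device="cpu",
)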
eval_simple.py
ADDED
@@ -0,0 +1,206 @@
import torch
import json
import re
import os
import time
import random
import torch.multiprocessing as mp
from tqdm import tqdm
from abstract_model import AbstractModel


RL_MODEL_PATH = "pathtocontinuoushead"
FALLBACK_SFT_PATH = "pathtobasemodel"

DATASET_FILES = [
    "../bench/mmlu.jsonl",
    "../bench/gsm8k.jsonl",
    "../bench/drop.jsonl"
]

SAMPLES_PER_BENCHMARK = 1024
MAX_THINKING_STEPS = 256
MAX_TOTAL_LENGTH = 1536
LOG_FILE = "eval_results_random.jsonl"


def normalize_text(s):
    import string
    if s is None:
        return ""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))
    return white_space_fix(remove_articles(remove_punc(str(s).lower())))


def extract_answer_content(text):
    match = re.search(r"<ANSWER>(.*?)</ANSWER>", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return None


def load_and_sample_data(files, samples_per_file):
    """
    Load the full datasets and randomly sample N items from each.
    """
    final_data = []

    for filename in files:
        if not os.path.exists(filename):
            print(f"Warning: File {filename} not found. Skipping.")
            continue

        # Detect benchmark type from the filename
        fname_lower = filename.lower()
        if "mmlu" in fname_lower:
            bench_type = "mmlu"
        elif "gsm8k" in fname_lower:
            bench_type = "gsm8k"
        elif "drop" in fname_lower:
            bench_type = "drop"
        else:
            bench_type = "unknown"

        print(f"Loading {filename} ({bench_type})...")

        file_data = []
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    entry = json.loads(line)
                    if "benchmark" not in entry:
                        entry["benchmark"] = bench_type
                    file_data.append(entry)
                except json.JSONDecodeError:
                    continue

        total_lines = len(file_data)

        if total_lines > samples_per_file:
            random.shuffle(file_data)
            selected_data = file_data[:samples_per_file]
            print(f" -> Randomly sampled {samples_per_file} from {total_lines} samples.")
        else:
            selected_data = file_data
            print(f" -> Took all {total_lines} samples (fewer than the requested limit).")

        final_data.extend(selected_data)

    return final_data


def score_sample(pred, truth, benchmark):
    if benchmark == 'mmlu':
        p = extract_answer_content(pred)
        if not p:
            return False
        m = re.search(r'([A-D])', p.upper())
        return m.group(1) == truth.strip().upper() if m else False
    elif benchmark == 'gsm8k':
        p = extract_answer_content(pred)
        if not p:
            return False
        t = truth.split("####")[-1].strip() if "####" in truth else truth.strip()
        return normalize_text(t) in normalize_text(p)
    else:
        p = extract_answer_content(pred)
        if not p:
            return False
        return normalize_text(p) == normalize_text(truth)


def gpu(gpu_id, head_path, sft_path, dataset_chunk, results_queue):
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    if not os.path.exists(os.path.join(head_path, "continuous_head.pt")):
        print(f"[GPU {gpu_id}] Critical: continuous_head.pt not found in {head_path}")
        return

    print(f"[GPU {gpu_id}] Loading Model...")
    try:
        model = AbstractModel.load_from_directory(
            head_path,
            sft_model_path=sft_path,
            device=device
        )
    except Exception as e:
        print(f"[GPU {gpu_id}] Error loading model: {e}")
        return

    results = []
    iterator = tqdm(dataset_chunk, desc=f"GPU {gpu_id}", position=gpu_id, leave=True)

    for item in iterator:
        try:
            sys_prompt = "You are a reasoning assistant. Think step by step before answering."
            messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": item['question']}]

            formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device).squeeze(0)

            out = model.forward(
                input_ids,
                max_length=MAX_TOTAL_LENGTH,
                temperature=0.0,
                sample=False,
                no_grad=True,
                sigma=0.0,
                max_thinking_steps=MAX_THINKING_STEPS
            )

            full_text = ""
            for token_id in out['generated_tokens'].tolist():
                full_text += model.tokenizer.decode([token_id])

            is_correct = score_sample(full_text, item['answer'], item['benchmark'])

            results.append({
                "benchmark": item['benchmark'],
                "correct": is_correct,
                "think_steps": out['mode_sequence'].count('A'),
                "prediction": full_text
            })
        except Exception as e:
            print(f"[GPU {gpu_id}] Error: {e}")
            continue

    results_queue.put(results)


def run_evaluation():
    all_data = load_and_sample_data(DATASET_FILES, SAMPLES_PER_BENCHMARK)

    if not all_data:
        print("No data loaded. Exiting.")
        return

    print(f"Total Evaluation Set: {len(all_data)} samples.")

    # Split the workload across two GPUs
    mid = len(all_data) // 2
    queue = mp.Queue()

    p1 = mp.Process(target=gpu, args=(0, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[:mid], queue))
    p2 = mp.Process(target=gpu, args=(1, RL_MODEL_PATH, FALLBACK_SFT_PATH, all_data[mid:], queue))

    start_time = time.time()
    p1.start()
    p2.start()

    # Drain the queue before joining so large result lists cannot deadlock
    final_results = []
    for _ in range(2):
        final_results.extend(queue.get())
    p1.join()
    p2.join()

    print(f"Saving detailed logs to {LOG_FILE}...")
    with open(LOG_FILE, 'w') as f:
        for r in final_results:
            f.write(json.dumps(r) + '\n')

    metrics = {}
    for res in final_results:
        b = res['benchmark']
        if b not in metrics:
            metrics[b] = {'correct': [], 'steps': []}
        metrics[b]['correct'].append(res['correct'])
        metrics[b]['steps'].append(res['think_steps'])

    print("\n" + "=" * 50)
    print(f"FINAL SCORES (Random Sample N={SAMPLES_PER_BENCHMARK})")
    print("=" * 50)

    for b, d in metrics.items():
        acc = sum(d['correct']) / len(d['correct']) * 100
        avg_steps = sum(d['steps']) / len(d['steps'])
        print(f"{b.upper():<10} | Acc: {acc:.2f}% | Avg Steps: {avg_steps:.1f} | N: {len(d['correct'])}")

    print(f"Total time: {time.time() - start_time:.2f}s")


if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    run_evaluation()
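
To make the scoring conventions concrete, a small self-contained check of `score_sample` with invented example strings (assumes the scripts directory is on the import path; importing `eval_simple` only defines functions and constants):

# Invented examples showing the answer formats score_sample expects.
from eval_simple import score_sample

# MMLU: first A-D letter inside <ANSWER>...</ANSWER> vs. the gold letter.
assert score_sample("... <ANSWER>B</ANSWER>", "b", "mmlu")

# GSM8K: gold text after '####' must appear in the normalized prediction.
assert score_sample("... <ANSWER>The total is 42.</ANSWER>",
                    "Some reasoning #### 42", "gsm8k")

# DROP (default branch): exact match after normalization.
assert score_sample("<ANSWER>the Eiffel Tower</ANSWER>", "Eiffel Tower", "drop")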
test_soft_embedding_with_trigger.py
ADDED
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Test soft embedding with trigger-based mode switching.
"""

import argparse
import torch
import torch.nn.functional as F
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM


class TriggerHead(torch.nn.Module):
    """SwiGLU-style binary classifier over a single hidden state."""
    def __init__(self, hidden_size, hidden_dim=1024):
        super().__init__()
        self.w_gate = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
        self.w_value = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
        self.w_out = torch.nn.Linear(hidden_dim, 1, bias=True)

    def forward(self, x):
        gate = self.w_gate(x)
        value = self.w_value(x)
        activated = F.silu(gate) * value
        x = self.w_out(activated)
        return x.squeeze(-1)


def main():
    parser = argparse.ArgumentParser(description="Test Soft Embedding with Trigger")
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--trigger-head', required=True, help='Path to trigger head checkpoint dir')
    parser.add_argument('--max-length', type=int, default=256, help='Max generation length')
    parser.add_argument('--threshold', type=float, default=0.5, help='Trigger threshold (>threshold = abstract mode)')
    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for softmax')

    args = parser.parse_args()

    print("=" * 70)
    print("Testing Soft Embedding with Trigger-Based Mode Switching")
    print("=" * 70)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print(f"\nLoading tokenizer from {args.sft_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.sft_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading SFT model from {args.sft_model}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.sft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map=None
    ).to(device)
    model.eval()

    hidden_size = model.config.hidden_size
    embed_layer = model.get_input_embeddings()

    print(f"Loading trigger head from {args.trigger_head}...")
    trigger_head = TriggerHead(hidden_size).to(device)
    checkpoint_path = Path(args.trigger_head) / "trigger_head.pt"

    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        return

    trigger_state = torch.load(checkpoint_path, map_location=device)
    trigger_head.load_state_dict(trigger_state)
    trigger_head.eval()

    print("Models loaded.\n")

    mode_stats = {'natural': 0, 'abstract': 0}

    while True:
        prompt = input("You: ").strip()
        if prompt.lower() in ['quit', 'exit', 'q']:
            break

        if not prompt:
            continue

        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        input_ids = tokenizer(
            formatted,
            return_tensors='pt',
            add_special_tokens=False
        )['input_ids'].to(device)

        print("Assistant: ", end="", flush=True)

        generated_tokens = []
        mode_sequence = []

        with torch.no_grad():
            current_embeddings = embed_layer(input_ids).squeeze(0)
            next_mode = 'N'

            # use_cache=False re-encodes the full prefix every step (O(n^2));
            # acceptable for a short interactive test.
            while len(generated_tokens) + len(input_ids[0]) < args.max_length:
                outputs = model.model(
                    inputs_embeds=current_embeddings.unsqueeze(0),
                    use_cache=False
                )
                hidden_state = outputs.last_hidden_state[0, -1]

                # The trigger head operates on L2-normalized hidden states
                hidden_state_normalized = F.normalize(hidden_state.float(), p=2, dim=-1)

                trigger_logits = trigger_head(hidden_state_normalized.unsqueeze(0))
                trigger_prob = torch.sigmoid(trigger_logits).item()
                next_mode = 'S' if trigger_prob > args.threshold else 'N'

                logits = model.lm_head(hidden_state)
                logits = logits / args.temperature
                probs = F.softmax(logits, dim=-1)

                if next_mode == 'S':
                    # Soft mode: feed back the expected embedding under the
                    # full output distribution instead of a hard token.
                    mode_sequence.append('S')
                    embed_matrix = embed_layer.weight.float()
                    next_embedding = probs.float() @ embed_matrix
                    next_embedding = next_embedding.to(torch.bfloat16)
                    next_token = torch.argmax(probs).item()
                    token_text = tokenizer.decode([next_token])
                    print(f"<abstract>{token_text}", end="", flush=True)
                else:
                    mode_sequence.append('N')
                    next_token = torch.argmax(probs).item()
                    next_embedding = embed_layer(torch.tensor([[next_token]], device=device)).squeeze(0).squeeze(0)
                    token_text = tokenizer.decode([next_token])
                    print(token_text, end="", flush=True)

                if next_token == tokenizer.eos_token_id:
                    break

                generated_tokens.append(next_token)
                current_embeddings = torch.cat([current_embeddings, next_embedding.unsqueeze(0)], dim=0)

        print("\n")

        if mode_sequence:
            n_count = mode_sequence.count('N')
            s_count = mode_sequence.count('S')
            mode_stats['natural'] += n_count
            mode_stats['abstract'] += s_count
            print(f"[Tokens: Natural={n_count}, Switch={s_count}, switch_ratio={s_count/(n_count+s_count)*100:.1f}%]\n")

    print("\n" + "=" * 70)
    print("Session Statistics:")
    print(f"  Natural mode tokens: {mode_stats['natural']}")
    print(f"  Switch point tokens: {mode_stats['abstract']}")
    if mode_stats['natural'] + mode_stats['abstract'] > 0:
        total = mode_stats['natural'] + mode_stats['abstract']
        print(f"  Switch ratio: {mode_stats['abstract']/total*100:.1f}%")


if __name__ == '__main__':
    main()
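
A quick sanity sketch of the `TriggerHead` gating shape, using random weights rather than a trained checkpoint (sizes are arbitrary):

# Shape check for the SwiGLU-style TriggerHead (untrained, random weights).
import torch
from test_soft_embedding_with_trigger import TriggerHead

head = TriggerHead(hidden_size=64, hidden_dim=128)
x = torch.randn(4, 64)                 # batch of 4 hidden states
logits = head(x)                       # -> shape (4,)
probs = torch.sigmoid(logits)          # per-position switch probabilities
print(probs.shape, (probs > 0.5).tolist())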