|
|
|
|
|
""" |
|
|
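Test soft embedding with trigger-based mode switching.

Interactive chat script: at each generation step a trigger head scores the
last hidden state. Above the threshold the next input is the
probability-weighted mix of all token embeddings (abstract mode); otherwise
it is the embedding of the greedy token (natural mode).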
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from pathlib import Path |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
class TriggerHead(torch.nn.Module):
    """Gated two-branch MLP that reduces a feature vector to one logit.

    The input is projected twice (a gate branch and a value branch), the
    SiLU-activated gate multiplies the value element-wise, and a final linear
    layer collapses the result to a single scalar per input vector.

    Args:
        hidden_size: Size of the last dimension of the input features.
        hidden_dim: Width of the gate/value projections (default 1024).
    """

    def __init__(self, hidden_size, hidden_dim=1024):
        super().__init__()
        make_linear = torch.nn.Linear
        # Attribute names are also the state-dict key prefixes, so they must
        # stay as they are for saved checkpoints to load.
        self.w_gate = make_linear(hidden_size, hidden_dim, bias=True)
        self.w_value = make_linear(hidden_size, hidden_dim, bias=True)
        self.w_out = make_linear(hidden_dim, 1, bias=True)

    def forward(self, x):
        """Return one logit per input vector; the trailing size-1 dim is dropped."""
        gated = F.silu(self.w_gate(x)) * self.w_value(x)
        return self.w_out(gated).squeeze(-1)
|
|
|
|
|
|
|
|
def _parse_args():
    """Parse and validate the command-line arguments for the test session."""
    parser = argparse.ArgumentParser(description="Test Soft Embedding with Trigger")
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--trigger-head', required=True, help='Path to trigger head checkpoint dir')
    parser.add_argument('--max-length', type=int, default=256, help='Max generation length')
    parser.add_argument('--threshold', type=float, default=0.5, help='Trigger threshold (>threshold = abstract mode)')
    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for softmax')

    args = parser.parse_args()

    # The logits are divided by the temperature, so zero/negative/NaN values
    # would produce inf/NaN probabilities. `not (> 0)` also rejects NaN.
    if not args.temperature > 0:
        parser.error("--temperature must be a positive number")
    return args


def main():
    """Run an interactive chat session with trigger-based soft embeddings.

    Loads the tokenizer and causal LM from ``--sft-model`` and the trigger
    head weights from ``<--trigger-head>/trigger_head.pt``, then reads prompts
    from stdin until the user types ``quit``/``exit``/``q`` or stdin ends.

    For every generation step the trigger head scores the L2-normalised last
    hidden state. If the sigmoid of that score exceeds ``--threshold`` the
    next input embedding is the probability-weighted sum of all token
    embeddings ("abstract"/switch step); otherwise it is the embedding of the
    greedy token ("natural" step). Per-reply and per-session counts of both
    step kinds are printed.

    The loaded model must expose ``.model`` (accepting ``inputs_embeds``) and
    ``.lm_head``, since both are called directly.

    Returns:
        None. Returns early, after printing an error, when the trigger head
        checkpoint file does not exist.
    """
    args = _parse_args()

    print("=" * 70)
    print("Testing Soft Embedding with Trigger-Based Mode Switching")
    print("=" * 70)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print(f"\nLoading tokenizer from {args.sft_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.sft_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading SFT model from {args.sft_model}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.sft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map=None
    ).to(device)
    model.eval()

    hidden_size = model.config.hidden_size
    embed_layer = model.get_input_embeddings()

    print(f"Loading trigger head from {args.trigger_head}...")
    trigger_head = TriggerHead(hidden_size).to(device)
    checkpoint_path = Path(args.trigger_head) / "trigger_head.pt"

    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        return

    # The checkpoint is fed straight to load_state_dict, so it only needs
    # tensors; weights_only=True stops torch.load from unpickling arbitrary
    # objects out of the file.
    trigger_state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    trigger_head.load_state_dict(trigger_state)
    trigger_head.eval()

    print("Models loaded.\n")

    mode_stats = {'natural': 0, 'abstract': 0}

    # float32 copy of the embedding matrix used for soft embeddings. Created
    # lazily on the first abstract step and then reused, instead of re-casting
    # the whole matrix on every abstract step.
    embed_matrix_fp32 = None

    while True:
        try:
            prompt = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt ends the session the same
            # way as typing "quit", so the session statistics still print.
            print()
            break

        if prompt.lower() in ['quit', 'exit', 'q']:
            break

        if not prompt:
            continue

        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The chat template already inserted its special tokens.
        input_ids = tokenizer(
            formatted,
            return_tensors='pt',
            add_special_tokens=False
        )['input_ids'].to(device)
        prompt_length = input_ids.shape[1]

        print("Assistant: ", end="", flush=True)

        generated_tokens = []
        mode_sequence = []

        with torch.no_grad():
            # (seq_len, hidden) running sequence of input embeddings; soft
            # embeddings have no token id, so generation works on embeddings.
            current_embeddings = embed_layer(input_ids).squeeze(0)

            while len(generated_tokens) + prompt_length < args.max_length:
                # No KV cache: the full embedding sequence is re-encoded
                # on every step.
                outputs = model.model(
                    inputs_embeds=current_embeddings.unsqueeze(0),
                    use_cache=False
                )
                hidden_state = outputs.last_hidden_state[0, -1]

                # The trigger head sees the L2-normalised state in float32.
                hidden_state_normalized = F.normalize(hidden_state.float(), p=2, dim=-1)
                trigger_logits = trigger_head(hidden_state_normalized.unsqueeze(0))
                trigger_prob = torch.sigmoid(trigger_logits).item()
                next_mode = 'S' if trigger_prob > args.threshold else 'N'

                logits = model.lm_head(hidden_state)
                logits = logits / args.temperature
                probs = F.softmax(logits, dim=-1)

                # The displayed token is always the greedy one; only the
                # embedding fed back into the model differs between modes.
                next_token = torch.argmax(probs).item()
                token_text = tokenizer.decode([next_token])

                if next_mode == 'S':
                    mode_sequence.append('S')
                    if embed_matrix_fp32 is None:
                        embed_matrix_fp32 = embed_layer.weight.float()
                    # Soft embedding: expectation of the token embeddings
                    # under the predicted distribution.
                    next_embedding = probs.float() @ embed_matrix_fp32
                    next_embedding = next_embedding.to(torch.bfloat16)
                    print(f"<abstract>{token_text}", end="", flush=True)
                else:
                    mode_sequence.append('N')
                    next_embedding = embed_layer(
                        torch.tensor([[next_token]], device=device)
                    ).squeeze(0).squeeze(0)
                    print(token_text, end="", flush=True)

                if next_token == tokenizer.eos_token_id:
                    break

                generated_tokens.append(next_token)
                current_embeddings = torch.cat(
                    [current_embeddings, next_embedding.unsqueeze(0)], dim=0
                )

        print("\n")

        if mode_sequence:
            n_count = mode_sequence.count('N')
            s_count = mode_sequence.count('S')
            mode_stats['natural'] += n_count
            mode_stats['abstract'] += s_count
            print(f"[Tokens: Natural={n_count}, Switch={s_count}, switch_ratio={s_count/(n_count+s_count)*100:.1f}%]\n")

    print("\n" + "=" * 70)
    print("Session Statistics:")
    print(f" Natural mode tokens: {mode_stats['natural']}")
    print(f" Switch point tokens: {mode_stats['abstract']}")
    total = mode_stats['natural'] + mode_stats['abstract']
    if total > 0:
        print(f" Switch ratio: {mode_stats['abstract']/total*100:.1f}%")
|
|
|
|
|
|
|
|
# Start the interactive session only when run as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|