#!/usr/bin/env python3
"""
Test soft embedding with trigger-based mode switching.

At each decoding step a trigger head reads the decoder's last hidden state and
decides whether the next position is fed a hard token embedding ('N', natural
mode) or a probability-weighted soft embedding ('S', switch/abstract mode).
"""
import argparse
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


class TriggerHead(torch.nn.Module):
    """SwiGLU-style gated MLP mapping a hidden state to a scalar trigger logit."""

    def __init__(self, hidden_size, hidden_dim=1024):
        super().__init__()
        self.w_gate = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
        self.w_value = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
        self.w_out = torch.nn.Linear(hidden_dim, 1, bias=True)

    def forward(self, x):
        gate = self.w_gate(x)
        value = self.w_value(x)
        activated = F.silu(gate) * value
        x = self.w_out(activated)
        return x.squeeze(-1)


def main():
    parser = argparse.ArgumentParser(description="Test Soft Embedding with Trigger")
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--trigger-head', required=True,
                        help='Path to trigger head checkpoint dir')
    parser.add_argument('--max-length', type=int, default=256,
                        help='Max generation length')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Trigger threshold (>threshold = abstract mode)')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Temperature for softmax')
    args = parser.parse_args()

    print("=" * 70)
    print("Testing Soft Embedding with Trigger-Based Mode Switching")
    print("=" * 70)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print(f"\nLoading tokenizer from {args.sft_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.sft_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading SFT model from {args.sft_model}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.sft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map=None
    ).to(device)
    model.eval()

    hidden_size = model.config.hidden_size
    embed_layer = model.get_input_embeddings()

    print(f"Loading trigger head from {args.trigger_head}...")
    trigger_head = TriggerHead(hidden_size).to(device)
    checkpoint_path = Path(args.trigger_head) / "trigger_head.pt"
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        return
    trigger_state = torch.load(checkpoint_path, map_location=device)
    trigger_head.load_state_dict(trigger_state)
    trigger_head.eval()

    print("Models loaded.\n")

    mode_stats = {'natural': 0, 'abstract': 0}

    # Interactive loop: one user prompt per turn; 'quit'/'exit'/'q' ends the session.
    while True:
        prompt = input("You: ").strip()
        if prompt.lower() in ['quit', 'exit', 'q']:
            break
        if not prompt:
            continue

        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        input_ids = tokenizer(
            formatted, return_tensors='pt', add_special_tokens=False
        )['input_ids'].to(device)

        print("Assistant: ", end="", flush=True)

        generated_tokens = []
        mode_sequence = []

        with torch.no_grad():
            current_embeddings = embed_layer(input_ids).squeeze(0)

            # Manual token-by-token decoding. The full prefix is re-encoded each
            # step (use_cache=False) because the sequence grows by embeddings --
            # possibly soft mixtures -- rather than by token ids.
            while len(generated_tokens) + len(input_ids[0]) < args.max_length:
                # Run the decoder backbone (model.model on Llama-style
                # architectures) to get the hidden state at the last position.
                outputs = model.model(
                    inputs_embeds=current_embeddings.unsqueeze(0),
                    use_cache=False
                )
                hidden_state = outputs.last_hidden_state[0, -1]

                # Trigger decision: the head scores the L2-normalized hidden
                # state; above the threshold, the next step uses a soft embedding.
                hidden_state_normalized = F.normalize(hidden_state.float(), p=2, dim=-1)
                trigger_logits = trigger_head(hidden_state_normalized.unsqueeze(0))
                trigger_prob = torch.sigmoid(trigger_logits).item()
                next_mode = 'S' if trigger_prob > args.threshold else 'N'

                logits = model.lm_head(hidden_state)
                logits = logits / args.temperature
                probs = F.softmax(logits, dim=-1)

                if next_mode == 'S':
                    mode_sequence.append('S')
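                    # Soft-embedding step: the next position is fed a
                    # probability-weighted mixture of the whole embedding matrix
                    # (probs @ E) rather than a single token's embedding; the
                    # argmax token is decoded only so readable text can be printed.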
                    embed_matrix = embed_layer.weight.float()
                    next_embedding = probs.float() @ embed_matrix
                    next_embedding = next_embedding.to(torch.bfloat16)
                    next_token = torch.argmax(probs).item()
                    token_text = tokenizer.decode([next_token])
                    print(token_text, end="", flush=True)
                else:
                    mode_sequence.append('N')
                    # Natural step: commit to the argmax token and feed its
                    # ordinary embedding to the next position.
                    next_token = torch.argmax(probs).item()
                    next_embedding = embed_layer(
                        torch.tensor([[next_token]], device=device)
                    ).squeeze(0).squeeze(0)
                    token_text = tokenizer.decode([next_token])
                    print(token_text, end="", flush=True)

                if next_token == tokenizer.eos_token_id:
                    break

                generated_tokens.append(next_token)
                current_embeddings = torch.cat(
                    [current_embeddings, next_embedding.unsqueeze(0)], dim=0
                )

        print("\n")
        if mode_sequence:
            n_count = mode_sequence.count('N')
            s_count = mode_sequence.count('S')
            mode_stats['natural'] += n_count
            mode_stats['abstract'] += s_count
            print(f"[Tokens: Natural={n_count}, Switch={s_count}, "
                  f"switch_ratio={s_count / (n_count + s_count) * 100:.1f}%]\n")

    print("\n" + "=" * 70)
    print("Session Statistics:")
    print(f"  Natural mode tokens: {mode_stats['natural']}")
    print(f"  Switch point tokens: {mode_stats['abstract']}")
    total = mode_stats['natural'] + mode_stats['abstract']
    if total > 0:
        print(f"  Switch ratio: {mode_stats['abstract'] / total * 100:.1f}%")


if __name__ == '__main__':
    main()
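
# Example invocation (script name and paths are placeholders, not real checkpoints):
#   python test_soft_embedding_trigger.py \
#       --sft-model ./checkpoints/sft-model \
#       --trigger-head ./checkpoints/trigger-head \
#       --threshold 0.5 --temperature 1.0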