granite-abstract / test_soft_embedding_with_trigger.py
Gavin-Wang's picture
scripts
b1b2e62 verified
#!/usr/bin/env python3
"""
Test soft embedding with trigger-based mode switching.
"""
import argparse
import torch
import torch.nn.functional as F
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
class TriggerHead(torch.nn.Module):
    """Gated (SiLU-GLU style) head mapping each hidden vector to one logit.

    The input is projected twice into a ``hidden_dim``-wide space; the SiLU of
    the first projection gates the second, and ``w_out`` reduces the gated
    product to a single scalar.

    Args:
        hidden_size: Width of the incoming feature vectors (last dim of ``x``).
        hidden_dim: Width of the intermediate gated projection.
    """

    def __init__(self, hidden_size, hidden_dim=1024):
        super().__init__()
        # These attribute names are the state_dict keys ("w_gate.weight", ...)
        # that saved checkpoints are loaded against, so they must not change.
        self.w_gate = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
        self.w_value = torch.nn.Linear(hidden_size, hidden_dim, bias=True)
        self.w_out = torch.nn.Linear(hidden_dim, 1, bias=True)

    def forward(self, x):
        """Score ``x`` and return one raw (un-squashed) logit per input vector.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the shape of ``x`` minus its last dimension.
        """
        gated = F.silu(self.w_gate(x)) * self.w_value(x)
        # w_out yields a trailing size-1 dim; drop it.
        return self.w_out(gated).squeeze(-1)
def _generate_reply(model, tokenizer, embed_layer, trigger_head, input_ids, args):
    """Greedily decode one assistant reply, streaming each token to stdout.

    At every step the trigger head scores the L2-normalized last hidden state.
    If ``sigmoid(score) > args.threshold`` the step is switch mode ('S'): the
    next input embedding is the probability-weighted mixture of all token
    embeddings (a soft embedding), and the argmax token is printed prefixed
    with ``<abstract>``. Otherwise the step is natural mode ('N'): the next
    input embedding is the embedding of the argmax token, which is printed.
    Decoding stops when the argmax token is the tokenizer's EOS token or when
    ``args.max_length`` new tokens have been appended.

    Args:
        model: causal LM exposing ``.model`` (backbone returning
            ``last_hidden_state``) and ``.lm_head``.
        tokenizer: tokenizer used to decode tokens and supply ``eos_token_id``.
        embed_layer: the model's input embedding module.
        trigger_head: ``TriggerHead`` returning one logit per hidden state.
        input_ids: prompt token ids of shape (1, prompt_len).
        args: parsed CLI args (``max_length``, ``threshold``, ``temperature``).

    Returns:
        List of 'N'/'S' flags, one per decoding step (the EOS step included).
    """
    device = input_ids.device
    mode_sequence = []
    generated_tokens = 0
    with torch.no_grad():
        current_embeddings = embed_layer(input_ids).squeeze(0)
        # --max-length bounds the number of *new* tokens, as its help text
        # says; counting the prompt too left no room for replies to long prompts.
        while generated_tokens < args.max_length:
            # Full forward over the whole sequence each step (no KV cache).
            outputs = model.model(
                inputs_embeds=current_embeddings.unsqueeze(0),
                use_cache=False,
            )
            hidden_state = outputs.last_hidden_state[0, -1]

            hidden_state_normalized = F.normalize(hidden_state.float(), p=2, dim=-1)
            trigger_logits = trigger_head(hidden_state_normalized.unsqueeze(0))
            trigger_prob = torch.sigmoid(trigger_logits).item()
            switch = trigger_prob > args.threshold

            # Upcast before temperature/softmax: bf16 probabilities can round
            # distinct top candidates to the same value (making argmax pick the
            # wrong token) and coarsen the soft-embedding mixture weights.
            # NOTE(review): calling lm_head directly skips any logit
            # post-processing the full model forward may apply (e.g. a
            # logits scaling factor) — confirm this matches training.
            logits = model.lm_head(hidden_state).float() / args.temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.argmax(probs).item()
            token_text = tokenizer.decode([next_token])

            if switch:
                mode_sequence.append('S')
                embed_matrix = embed_layer.weight.float()
                # Expected embedding under the next-token distribution, cast
                # back to the embedding table's dtype.
                next_embedding = (probs @ embed_matrix).to(embed_layer.weight.dtype)
                print(f"<abstract>{token_text}", end="", flush=True)
            else:
                mode_sequence.append('N')
                next_embedding = embed_layer(
                    torch.tensor([[next_token]], device=device)
                ).squeeze(0).squeeze(0)
                print(token_text, end="", flush=True)

            if next_token == tokenizer.eos_token_id:
                break
            generated_tokens += 1
            current_embeddings = torch.cat(
                [current_embeddings, next_embedding.unsqueeze(0)], dim=0
            )
    return mode_sequence


def main():
    """Interactive chat loop testing trigger-based soft-embedding decoding.

    Loads the SFT model and the trigger head, then repeatedly reads a prompt,
    decodes a reply (switching per token between natural and soft-embedding
    mode), and prints per-turn mode counts. The session ends on
    'quit'/'exit'/'q', end of input, or Ctrl-C at the prompt, after which
    aggregate statistics are printed.
    """
    parser = argparse.ArgumentParser(description="Test Soft Embedding with Trigger")
    parser.add_argument('--sft-model', required=True, help='Path to SFT model')
    parser.add_argument('--trigger-head', required=True, help='Path to trigger head checkpoint dir')
    parser.add_argument('--max-length', type=int, default=256, help='Max generation length')
    parser.add_argument('--threshold', type=float, default=0.5, help='Trigger threshold (>threshold = abstract mode)')
    parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for softmax')
    args = parser.parse_args()
    # Temperature divides the logits: 0 gives inf/NaN, negatives invert them.
    if args.temperature <= 0:
        parser.error("--temperature must be > 0")
    # Validate the checkpoint path before the slow model load, and exit with
    # a non-zero status instead of printing an error and returning normally.
    checkpoint_path = Path(args.trigger_head) / "trigger_head.pt"
    if not checkpoint_path.exists():
        parser.error(f"Checkpoint not found at {checkpoint_path}")

    print("=" * 70)
    print("Testing Soft Embedding with Trigger-Based Mode Switching")
    print("=" * 70)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    print(f"\nLoading tokenizer from {args.sft_model}...")
    tokenizer = AutoTokenizer.from_pretrained(args.sft_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading SFT model from {args.sft_model}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.sft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map=None,
    ).to(device)
    model.eval()
    hidden_size = model.config.hidden_size
    embed_layer = model.get_input_embeddings()

    print(f"Loading trigger head from {args.trigger_head}...")
    trigger_head = TriggerHead(hidden_size).to(device)
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints (weights_only=True is safer on torch versions supporting it).
    trigger_state = torch.load(checkpoint_path, map_location=device)
    trigger_head.load_state_dict(trigger_state)
    trigger_head.eval()
    print("Models loaded.\n")

    # Running totals across all turns; 'abstract' counts switch ('S') steps.
    mode_stats = {'natural': 0, 'abstract': 0}

    while True:
        try:
            prompt = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C ends the session but still reports stats.
            print()
            break
        if prompt.lower() in ('quit', 'exit', 'q'):
            break
        if not prompt:
            continue

        messages = [{"role": "user", "content": prompt}]
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # add_special_tokens=False presumably because the chat template
        # already inserts them — verify for the tokenizer in use.
        input_ids = tokenizer(
            formatted,
            return_tensors='pt',
            add_special_tokens=False,
        )['input_ids'].to(device)

        print("Assistant: ", end="", flush=True)
        mode_sequence = _generate_reply(
            model, tokenizer, embed_layer, trigger_head, input_ids, args
        )
        print("\n")

        if mode_sequence:
            n_count = mode_sequence.count('N')
            s_count = mode_sequence.count('S')
            mode_stats['natural'] += n_count
            mode_stats['abstract'] += s_count
            print(f"[Tokens: Natural={n_count}, Switch={s_count}, switch_ratio={s_count/(n_count+s_count)*100:.1f}%]\n")

    print("\n" + "=" * 70)
    print("Session Statistics:")
    print(f" Natural mode tokens: {mode_stats['natural']}")
    print(f" Switch point tokens: {mode_stats['abstract']}")
    total = mode_stats['natural'] + mode_stats['abstract']
    if total > 0:
        print(f" Switch ratio: {mode_stats['abstract']/total*100:.1f}%")
if __name__ == "__main__":
    # Start the interactive session only when run as a script, not on import.
    main()