"""Abstract Model - robust inference with forbidden-token masking (fixed dimensions).

Wraps a supervised fine-tuned (SFT) causal LM with a continuous "abstract"
reasoning mode: inside a <think> span the model feeds back a top-k,
probability-weighted mixture of learned embeddings instead of a discrete token,
and structural/control tokens are masked out so they can never surface while
thinking. Emitting (or being forced to emit) </think> switches generation back
to ordinary greedy decoding.
"""
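
# Example usage (a minimal sketch; the model path below is hypothetical):
#
#   model = AbstractModel("path/to/sft-model")
#   ids = model.tokenizer("What is 2 + 2?", return_tensors="pt")["input_ids"][0]
#   out = model.forward(ids, max_length=64)
#   print(model.tokenizer.decode(out["generated_tokens"].tolist()))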

import json
import importlib
import inspect
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


class AbstractModel(nn.Module):
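    """Causal LM wrapper with a discrete (natural) and a continuous (abstract) decoding mode.

    The abstract mode uses a dedicated linear head and embedding table over the
    same vocabulary as the backbone, so each thinking step remains inspectable
    as its nearest discrete token.
    """
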
    def __init__(self, sft_model_path, device=None):
        super().__init__()
        self.sft_model_path = sft_model_path

        if device is None:
            self._target_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self._target_device = device

        print(f"Initializing AbstractModel on target device: {self._target_device}")

        self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading SFT model from {sft_model_path}...")
        sft_model = AutoModelForCausalLM.from_pretrained(
            sft_model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="sdpa",
        )
        sft_model = sft_model.to(self._target_device)
        sft_model.eval()

        # Assumes a decoder-only layout exposing `.model` and `.lm_head`
        # (Llama/Granite-style); other architectures may name these differently.
        self.model_backbone = sft_model.model
        self.lm_head = sft_model.lm_head
        self.embed_layer = sft_model.get_input_embeddings()
        self.config = sft_model.config

        self.hidden_size = sft_model.config.hidden_size
        self.vocab_size = sft_model.config.vocab_size

        # Trainable head/embedding pair for the continuous abstract mode.
        self.continuous_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.continuous_embed_layer = nn.Embedding(self.vocab_size, self.hidden_size)

        self.continuous_head = self.continuous_head.to(self._target_device).to(torch.bfloat16)
        self.continuous_embed_layer = self.continuous_embed_layer.to(self._target_device).to(torch.bfloat16)

        # Assumes <think> and </think> each tokenize to a single id.
        self.think_id = self.tokenizer.encode("<think>", add_special_tokens=False)[0]
        self.end_think_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]

        # Structural/control strings that must never be produced in abstract mode.
        forbidden_strings = [
            "<|end_of_text|>", "<|start_of_role|>", "<|end_of_role|>",
            "<|eot_id|>", "<|start_header_id|>", "user", "assistant", "system",
            "<tool_call>", "<tool_response>",
        ]

        self.banned_ids = []
        if self.tokenizer.eos_token_id is not None:
            self.banned_ids.append(self.tokenizer.eos_token_id)

        for s in forbidden_strings:
            ids = self.tokenizer.encode(s, add_special_tokens=False)
            if ids:
                self.banned_ids.extend(ids)

        self.banned_ids = sorted(set(self.banned_ids))
        print(f"Banned {len(self.banned_ids)} structural tokens from Abstract Mode.")

    @property
    def device(self):
        # Report the backbone's actual device rather than the requested target.
        return self.embed_layer.weight.device

    def _init_cache(self, batch_size, max_length):
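        """Build a cache suited to the backbone, falling back to DynamicCache."""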
        # Hybrid Mamba/attention backbones ship a dedicated cache class whose
        # constructor arguments vary across transformers versions, so build the
        # kwargs from its signature instead of hard-coding them.
        try:
            module = importlib.import_module(self.model_backbone.__module__)
            if hasattr(module, "HybridMambaAttentionDynamicCache"):
                CacheClass = getattr(module, "HybridMambaAttentionDynamicCache")
                sig = inspect.signature(CacheClass.__init__)
                kwargs = {}
                if 'config' in sig.parameters:
                    kwargs['config'] = self.config
                if 'batch_size' in sig.parameters:
                    kwargs['batch_size'] = batch_size
                elif 'max_batch_size' in sig.parameters:
                    kwargs['max_batch_size'] = batch_size
                if 'max_cache_len' in sig.parameters:
                    kwargs['max_cache_len'] = max_length
                elif 'max_length' in sig.parameters:
                    kwargs['max_length'] = max_length
                if 'device' in sig.parameters:
                    kwargs['device'] = self.device
                if 'dtype' in sig.parameters:
                    kwargs['dtype'] = self.embed_layer.weight.dtype
                return CacheClass(**kwargs)
        except Exception:
            pass  # fall through to the generic cache on any probing failure

        from transformers import DynamicCache
        cache = DynamicCache()
        cache.has_previous_state = False
        return cache

    def forward(
        self,
        input_ids,
        max_length=512,
        temperature=0.7,
        sample=False,  # reserved: decoding below is greedy regardless
        no_grad=True,
        sigma=0.0,
        max_thinking_steps=64,
    ):
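        """Generate autoregressively, starting in abstract mode and switching to natural mode at </think>.

        With no_grad=False the loop keeps the autograd graph (and, if sigma > 0,
        injects Gaussian exploration noise into the abstract embeddings), which
        is intended for training through the soft feedback path.

        Returns a dict with 'generated_tokens', 'logits' (empty unless
        no_grad=True), and 'mode_sequence' ('A' = abstract, 'T' = transition,
        'N' = natural).
        """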
        if input_ids.device != self.device:
            input_ids = input_ids.to(self.device)

        in_abstract_mode = True
        abstract_step_count = 0
        generated_tokens = []
        all_logits = []
        mode_sequence = []

        context = torch.no_grad() if no_grad else torch.enable_grad()

        with context:
            # Embed the prompt once, shape [1, prompt_len, hidden].
            initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
            current_step_input = initial_embeddings.unsqueeze(0)
            current_seq_len = initial_embeddings.shape[0]

            past_key_values = self._init_cache(batch_size=1, max_length=max_length + input_ids.shape[0] + 16)

            for step in range(max_length):
                # Step 0 consumes the whole prompt; later steps feed a single
                # embedding whose position index is the newly generated slot.
                if step == 0:
                    position_ids = torch.arange(0, current_seq_len, dtype=torch.long, device=self.device).unsqueeze(0)
                else:
                    position_ids = torch.tensor([[current_seq_len - 1]], dtype=torch.long, device=self.device)

                outputs = self.model_backbone(
                    inputs_embeds=current_step_input,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                )

                past_key_values = outputs.past_key_values
                last_hidden = outputs.last_hidden_state[0, -1, :]

                # Greedy choice from the original LM head (argmax over logits is
                # the same as argmax over their softmax).
                logits = self.lm_head(last_hidden)
                natural_next_token = torch.argmax(logits.float(), dim=-1).item()

                # Enforce the thinking budget: force </think> once it is spent.
                force_stop = False
                if in_abstract_mode:
                    abstract_step_count += 1
                    if abstract_step_count >= max_thinking_steps:
                        natural_next_token = self.end_think_id
                        force_stop = True

                if (natural_next_token == self.end_think_id or force_stop) and in_abstract_mode:
                    # Transition: leave abstract mode and feed a real </think> embedding.
                    in_abstract_mode = False
                    mode_sequence.append('T')
                    generated_tokens.append(self.end_think_id)
                    next_embedding = self.embed_layer(
                        torch.tensor([[self.end_think_id]], device=self.device)
                    ).squeeze(0).squeeze(0)

                elif in_abstract_mode:
                    # Abstract step: score with the continuous head, mask structural
                    # tokens, then feed back a top-k probability-weighted mixture of
                    # continuous embeddings rather than a single token embedding.
                    mode_sequence.append('A')
                    cont_logits = self.continuous_head(last_hidden)

                    if self.banned_ids:
                        cont_logits[self.banned_ids] = float('-inf')

                    cont_logits_f32 = cont_logits.float() / (temperature if temperature else 1.0)

                    # Recorded for visualization only; the mixture, not this token, is fed back.
                    abstract_vis_token = torch.argmax(cont_logits_f32, dim=-1).item()
                    generated_tokens.append(abstract_vis_token)

                    top_k = min(256, self.vocab_size // 4)
                    top_logits, top_indices = torch.topk(cont_logits_f32, top_k, dim=-1)
                    top_probs = F.softmax(top_logits, dim=-1).to(torch.bfloat16)
                    top_embeddings = self.continuous_embed_layer(top_indices)
                    next_embedding = top_probs @ top_embeddings

                    # Optional exploration noise, only meaningful while training.
                    if sigma > 0.0 and not no_grad:
                        next_embedding = next_embedding + (torch.randn_like(next_embedding) * sigma)
                else:
                    # Natural step: ordinary greedy decoding with discrete embeddings.
                    mode_sequence.append('N')
                    generated_tokens.append(natural_next_token)
                    next_embedding = self.embed_layer(
                        torch.tensor([[natural_next_token]], device=self.device)
                    ).squeeze(0).squeeze(0)

                # Collect logits only in no_grad mode to avoid holding the graph.
                if no_grad:
                    all_logits.append(logits.detach().cpu())

                if natural_next_token == self.tokenizer.eos_token_id and not in_abstract_mode:
                    break

                current_step_input = next_embedding.unsqueeze(0).unsqueeze(0)
                current_seq_len += 1

        return {
            'generated_tokens': torch.tensor(generated_tokens),
            'logits': torch.stack(all_logits) if all_logits else torch.tensor([]),
            'mode_sequence': mode_sequence,
        }

    def save_to_directory(self, output_dir):
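        """Persist the continuous head/embedding and a small config (the backbone is not saved)."""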
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        try:
            # Move tensors to CPU so the checkpoint loads regardless of device.
            head_state = {k: v.cpu() for k, v in self.continuous_head.state_dict().items()}
            embed_state = {k: v.cpu() for k, v in self.continuous_embed_layer.state_dict().items()}
            torch.save(head_state, output_path / "continuous_head.pt")
            torch.save(embed_state, output_path / "continuous_embed.pt")
            config = {
                'sft_model_path': str(self.sft_model_path),
                'hidden_size': self.hidden_size,
                'vocab_size': self.vocab_size,
            }
            with open(output_path / "config.json", 'w') as f:
                json.dump(config, f)
            print(f"Saved model to {output_dir}")
        except Exception as e:
            print(f"Error saving model: {e}")

    @staticmethod
    def load_from_directory(output_dir, sft_model_path=None, device='cuda:0'):
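        """Rebuild an AbstractModel from a directory written by save_to_directory."""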
        output_path = Path(output_dir)
        with open(output_path / "config.json", 'r') as f:
            config = json.load(f)
        if sft_model_path is None:
            sft_model_path = config['sft_model_path']
        model = AbstractModel(sft_model_path, device=device)
        print(f"Loading checkpoint to {model.device}...")
        head_state = torch.load(output_path / "continuous_head.pt", map_location=model.device)
        embed_state = torch.load(output_path / "continuous_embed.pt", map_location=model.device)
        model.continuous_head.load_state_dict(head_state)
        model.continuous_embed_layer.load_state_dict(embed_state)
        model.continuous_head = model.continuous_head.to(torch.bfloat16)
        model.continuous_embed_layer = model.continuous_embed_layer.to(torch.bfloat16)
        return model

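
# Round-trip sketch (hypothetical paths; assumes a trained instance `m`):
#   m.save_to_directory("checkpoints/abstract")
#   m2 = AbstractModel.load_from_directory("checkpoints/abstract")
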
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--sft-model', required=True)
    parser.add_argument('--load-model', default=None)
    parser.add_argument('--max-length', type=int, default=256)
    parser.add_argument('--temperature', type=float, default=0.7)
    args = parser.parse_args()

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # --load-model is optional: without it, start from a freshly initialized
    # continuous head/embedding on top of the SFT model.
    if args.load_model:
        model = AbstractModel.load_from_directory(args.load_model, sft_model_path=args.sft_model, device=device)
    else:
        model = AbstractModel(args.sft_model, device=device)

    print("\n" + "=" * 70)
    print("Abstract Model - Interactive Generation (Masked & Budgeted)")
    print("=" * 70 + "\n")

    while True:
        try:
            prompt = input("You: ").strip()
            if not prompt:
                continue
            if prompt.lower() in ['q', 'quit']:
                break

            sys_prompt = "You are a reasoning assistant. Think step by step before answering."
            messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}]

            formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(model.device).squeeze(0)

            print("Generating...", end="\r")

            result = model.forward(
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                sample=False,
                no_grad=True,
                sigma=0.0,
                max_thinking_steps=128,
            )

            generated_ids = result['generated_tokens'].tolist()
            modes = result['mode_sequence']

            print("Assistant: ", end="")
            for token_id, mode in zip(generated_ids, modes):
                token_text = model.tokenizer.decode([token_id])
                if mode == 'A':
                    # Abstract-mode tokens are highlighted in cyan (ANSI escapes).
                    print(f"\033[96m{token_text}\033[0m", end="", flush=True)
                else:
                    print(token_text, end="", flush=True)
            print("\n")
            print(f"[Stats] Abstract: {modes.count('A')} | Natural: {modes.count('N')}")
            print("-" * 70)

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"\nError: {e}")