#!/usr/bin/env python3
"""
Abstract Model - Robust Inference with Forbidden Token Masking (Fixed Dimensions)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import importlib
import inspect
from pathlib import Path


class AbstractModel(nn.Module):
    def __init__(self, sft_model_path, device=None):
        super().__init__()
        self.sft_model_path = sft_model_path

        if device is None:
            self._target_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self._target_device = device
        print(f"Initializing AbstractModel on target device: {self._target_device}")

        self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loading SFT model from {sft_model_path}...")
        sft_model = AutoModelForCausalLM.from_pretrained(
            sft_model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="sdpa",
        )
        sft_model = sft_model.to(self._target_device)
        sft_model.eval()

        self.model_backbone = sft_model.model
        self.lm_head = sft_model.lm_head
        self.embed_layer = sft_model.get_input_embeddings()
        self.config = sft_model.config
        self.hidden_size = sft_model.config.hidden_size
        self.vocab_size = sft_model.config.vocab_size

        self.continuous_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        self.continuous_embed_layer = nn.Embedding(self.vocab_size, self.hidden_size)
        self.continuous_head = self.continuous_head.to(self._target_device).to(torch.bfloat16)
        self.continuous_embed_layer = self.continuous_embed_layer.to(self._target_device).to(torch.bfloat16)

        # Reasoning-span delimiters. Assumption: the tokenizer uses literal
        # "<think>"/"</think>" tags; adjust if your tokenizer uses different markers.
        self.think_id = self.tokenizer.encode("<think>", add_special_tokens=False)[0]
        self.end_think_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]

        forbidden_strings = [
            "<|end_of_text|>", "<|start_of_role|>", "<|end_of_role|>",
            "<|eot_id|>", "<|start_header_id|>",
            "user", "assistant", "system",
            "<think>", "</think>"  # assumed reasoning tags (see above)
        ]
        self.banned_ids = []
        if self.tokenizer.eos_token_id is not None:
            self.banned_ids.append(self.tokenizer.eos_token_id)
        for s in forbidden_strings:
            ids = self.tokenizer.encode(s, add_special_tokens=False)
            if ids:
                self.banned_ids.extend(ids)
        self.banned_ids = sorted(set(self.banned_ids))
        print(f"Banned {len(self.banned_ids)} structural tokens from Abstract Mode.")

    @property
    def device(self):
        return self.embed_layer.weight.device

    def _init_cache(self, batch_size, max_length):
        try:
            module = importlib.import_module(self.model_backbone.__module__)
            if hasattr(module, "HybridMambaAttentionDynamicCache"):
                CacheClass = getattr(module, "HybridMambaAttentionDynamicCache")
                sig = inspect.signature(CacheClass.__init__)
                kwargs = {}
                if 'config' in sig.parameters:
                    kwargs['config'] = self.config
                if 'batch_size' in sig.parameters:
                    kwargs['batch_size'] = batch_size
                elif 'max_batch_size' in sig.parameters:
                    kwargs['max_batch_size'] = batch_size
                if 'max_cache_len' in sig.parameters:
                    kwargs['max_cache_len'] = max_length
                elif 'max_length' in sig.parameters:
                    kwargs['max_length'] = max_length
                if 'device' in sig.parameters:
                    kwargs['device'] = self.device
                if 'dtype' in sig.parameters:
                    kwargs['dtype'] = self.embed_layer.weight.dtype
                return CacheClass(**kwargs)
        except Exception:
            pass
        from transformers import DynamicCache
        cache = DynamicCache()
        cache.has_previous_state = False
        return cache
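    # ------------------------------------------------------------------
    # Generation (forward) runs as a two-phase state machine:
    #   'A' (abstract): tokens come from the trainable continuous_head;
    #       structural tokens in banned_ids are masked to -inf, and the
    #       next input embedding is a probability-weighted mixture of the
    #       top-k continuous embeddings (a "soft token") rather than a
    #       hard embedding lookup.
    #   'T' (transition): emitted once, when the frozen lm_head predicts
    #       end_think_id or the max_thinking_steps budget runs out.
    #   'N' (natural): plain greedy decoding through the frozen lm_head
    #       until EOS.
    # ------------------------------------------------------------------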
    def forward(
        self,
        input_ids,
        max_length=512,
        temperature=0.7,
        sample=False,
        no_grad=True,
        sigma=0.0,
        max_thinking_steps=64
    ):
        if input_ids.device != self.device:
            input_ids = input_ids.to(self.device)

        if no_grad:
            with torch.no_grad():
                initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
        else:
            initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)

        in_abstract_mode = True
        abstract_step_count = 0
        generated_tokens = []
        all_logits = []
        mode_sequence = []

        past_key_values = self._init_cache(batch_size=1, max_length=max_length + input_ids.shape[0] + 16)

        current_step_input = initial_embeddings.unsqueeze(0)
        current_seq_len = initial_embeddings.shape[0]

        context = torch.no_grad() if no_grad else torch.enable_grad()
        with context:
            for step in range(max_length):
                if step == 0:
                    position_ids = torch.arange(0, current_seq_len, dtype=torch.long, device=self.device).unsqueeze(0)
                else:
                    position_ids = torch.tensor([[current_seq_len - 1]], dtype=torch.long, device=self.device)

                outputs = self.model_backbone(
                    inputs_embeds=current_step_input,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )
                past_key_values = outputs.past_key_values
                last_hidden = outputs.last_hidden_state[0, -1, :]

                # 1. Natural head (used for the stopping condition)
                logits = self.lm_head(last_hidden)
                stop_probs = F.softmax(logits.float(), dim=-1)
                natural_next_token = torch.argmax(stop_probs, dim=-1).item()

                # Force-stop condition
                force_stop = False
                if in_abstract_mode:
                    abstract_step_count += 1
                    if abstract_step_count >= max_thinking_steps:
                        natural_next_token = self.end_think_id
                        force_stop = True

                # 2. Logic flow
                if (natural_next_token == self.end_think_id or force_stop) and in_abstract_mode:
                    # Transition to natural decoding
                    in_abstract_mode = False
                    mode_sequence.append('T')
                    generated_tokens.append(self.end_think_id)
                    next_embedding = self.embed_layer(
                        torch.tensor([[self.end_think_id]], device=self.device)
                    ).squeeze(0).squeeze(0)
                elif in_abstract_mode:
                    # Abstract generation
                    mode_sequence.append('A')
                    cont_logits = self.continuous_head(last_hidden)
                    if self.banned_ids:
                        cont_logits[self.banned_ids] = float('-inf')
                    cont_logits_f32 = cont_logits.float() / (temperature if temperature else 1.0)

                    abstract_vis_token = torch.argmax(cont_logits_f32, dim=-1).item()
                    generated_tokens.append(abstract_vis_token)

                    top_k = min(256, self.vocab_size // 4)
                    top_logits, top_indices = torch.topk(cont_logits_f32, top_k, dim=-1)
                    top_probs = F.softmax(top_logits, dim=-1).to(torch.bfloat16)
                    top_embeddings = self.continuous_embed_layer(top_indices)
                    next_embedding = top_probs @ top_embeddings

                    if sigma > 0.0 and not no_grad:
                        next_embedding = next_embedding + (torch.randn_like(next_embedding) * sigma)
                else:
                    # Natural generation
                    mode_sequence.append('N')
                    generated_tokens.append(natural_next_token)
                    next_embedding = self.embed_layer(
                        torch.tensor([[natural_next_token]], device=self.device)
                    ).squeeze(0).squeeze(0)

                if no_grad:
                    all_logits.append(logits.detach().cpu())

                if natural_next_token == self.tokenizer.eos_token_id and not in_abstract_mode:
                    break

                current_step_input = next_embedding.unsqueeze(0).unsqueeze(0)
                current_seq_len += 1

        return {
            'generated_tokens': torch.tensor(generated_tokens),
            'logits': torch.stack(all_logits) if all_logits else torch.tensor([]),
            'mode_sequence': mode_sequence,
        }

    def save_to_directory(self, output_dir):
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        try:
            head_state = {k: v.cpu() for k, v in self.continuous_head.state_dict().items()}
            embed_state = {k: v.cpu() for k, v in self.continuous_embed_layer.state_dict().items()}
            torch.save(head_state, output_path / "continuous_head.pt")
            torch.save(embed_state, output_path / "continuous_embed.pt")
            config = {
                'sft_model_path': str(self.sft_model_path),
                'hidden_size': self.hidden_size,
                'vocab_size': self.vocab_size
            }
            with open(output_path / "config.json", 'w') as f:
                json.dump(config, f)
            print(f"Saved model to {output_dir}")
        except Exception as e:
            print(f"Error saving model: {e}")
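    # Checkpoint layout written by save_to_directory and read back by
    # load_from_directory below:
    #   continuous_head.pt   - continuous LM head weights (stored on CPU)
    #   continuous_embed.pt  - continuous embedding table (stored on CPU)
    #   config.json          - sft_model_path, hidden_size, vocab_size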
    @staticmethod
    def load_from_directory(output_dir, sft_model_path=None, device='cuda:0'):
        output_path = Path(output_dir)
        with open(output_path / "config.json", 'r') as f:
            config = json.load(f)
        if sft_model_path is None:
            sft_model_path = config['sft_model_path']

        model = AbstractModel(sft_model_path, device=device)
        print(f"Loading checkpoint to {model.device}...")
        head_state = torch.load(output_path / "continuous_head.pt", map_location=model.device)
        embed_state = torch.load(output_path / "continuous_embed.pt", map_location=model.device)
        model.continuous_head.load_state_dict(head_state)
        model.continuous_embed_layer.load_state_dict(embed_state)
        model.continuous_head = model.continuous_head.to(torch.bfloat16)
        model.continuous_embed_layer = model.continuous_embed_layer.to(torch.bfloat16)
        return model


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--sft-model', required=True)
    parser.add_argument('--load-model', default=None)
    parser.add_argument('--max-length', type=int, default=256)
    parser.add_argument('--temperature', type=float, default=0.7)
    args = parser.parse_args()

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if args.load_model:
        model = AbstractModel.load_from_directory(args.load_model, sft_model_path=args.sft_model, device=device)
    else:
        # No trained checkpoint supplied: fall back to a freshly initialized
        # (untrained) continuous head and embedding table.
        model = AbstractModel(args.sft_model, device=device)

    print("\n" + "=" * 70)
    print("Abstract Model - Interactive Generation (Masked & Budgeted)")
    print("=" * 70 + "\n")

    while True:
        try:
            prompt = input("You: ").strip()
            if not prompt:
                continue
            if prompt.lower() in ['q', 'quit']:
                break

            sys_prompt = "You are a reasoning assistant. Think step by step before answering."
            messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": prompt}
            ]
            formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = model.tokenizer(
                formatted, return_tensors='pt', add_special_tokens=False
            )['input_ids'].to(model.device).squeeze(0)

            print("Generating...", end="\r")
            result = model.forward(
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                sample=False,
                no_grad=True,
                sigma=0.0,
                max_thinking_steps=128
            )

            generated_ids = result['generated_tokens'].tolist()
            modes = result['mode_sequence']

            print("Assistant: ", end="")
            for token_id, mode in zip(generated_ids, modes):
                token_text = model.tokenizer.decode([token_id])
                if mode == 'A':
                    # Abstract-mode tokens are highlighted in cyan
                    print(f"\033[96m{token_text}\033[0m", end="", flush=True)
                else:
                    print(token_text, end="", flush=True)
            print("\n")
            print(f"[Stats] Abstract: {modes.count('A')} | Natural: {modes.count('N')}")
            print("-" * 70)
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"\nError: {e}")
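# Usage sketch (the script name and all paths below are placeholders):
#
#   python abstract_model.py --sft-model /path/to/sft-checkpoint \
#       --load-model /path/to/abstract-checkpoint --max-length 256
#
# Programmatic equivalent:
#
#   model = AbstractModel.load_from_directory(
#       "/path/to/abstract-checkpoint", sft_model_path="/path/to/sft-checkpoint")
#   ids = model.tokenizer("2 + 2 = ?", return_tensors="pt")["input_ids"].squeeze(0)
#   out = model.forward(ids, max_length=128, max_thinking_steps=64)
#   print(model.tokenizer.decode(out["generated_tokens"].tolist()))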