#!/usr/bin/env python3
"""
Abstract Model - Robust Inference with Forbidden Token Masking (Fixed Dimensions)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import importlib
import inspect
from pathlib import Path
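# AbstractModel reuses a pretrained SFT causal LM as-is (backbone, lm_head, input embeddings,
# loaded in eval mode) and adds two new modules, continuous_head and continuous_embed_layer,
# which drive the "abstract" thinking phase between <think> and </think>.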
class AbstractModel(nn.Module):
def __init__(self, sft_model_path, device=None):
super().__init__()
self.sft_model_path = sft_model_path
if device is None:
self._target_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
else:
self._target_device = device
print(f"Initializing AbstractModel on target device: {self._target_device}")
self.tokenizer = AutoTokenizer.from_pretrained(sft_model_path, trust_remote_code=True)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
print(f"Loading SFT model from {sft_model_path}...")
sft_model = AutoModelForCausalLM.from_pretrained(
sft_model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
attn_implementation="sdpa",
)
sft_model = sft_model.to(self._target_device)
sft_model.eval()
self.model_backbone = sft_model.model
self.lm_head = sft_model.lm_head
self.embed_layer = sft_model.get_input_embeddings()
self.config = sft_model.config
self.hidden_size = sft_model.config.hidden_size
self.vocab_size = sft_model.config.vocab_size
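        # Continuous head + embedding table used only during abstract ("thinking") steps;
        # they mirror the shapes of lm_head and the input embedding table.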
self.continuous_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
self.continuous_embed_layer = nn.Embedding(self.vocab_size, self.hidden_size)
self.continuous_head = self.continuous_head.to(self._target_device).to(torch.bfloat16)
self.continuous_embed_layer = self.continuous_embed_layer.to(self._target_device).to(torch.bfloat16)
self.think_id = self.tokenizer.encode("<think>", add_special_tokens=False)[0]
self.end_think_id = self.tokenizer.encode("</think>", add_special_tokens=False)[0]
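        # Structural / chat-template tokens (plus EOS) are banned from the abstract head so
        # latent steps can never emit role markers or terminate the sequence prematurely.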
forbidden_strings = [
"<|end_of_text|>", "<|start_of_role|>", "<|end_of_role|>",
"<|eot_id|>", "<|start_header_id|>", "user", "assistant", "system",
"<tool_call>", "<tool_response>"
]
self.banned_ids = []
if self.tokenizer.eos_token_id is not None:
self.banned_ids.append(self.tokenizer.eos_token_id)
for s in forbidden_strings:
ids = self.tokenizer.encode(s, add_special_tokens=False)
if ids:
self.banned_ids.extend(ids)
self.banned_ids = sorted(list(set(self.banned_ids)))
print(f"Banned {len(self.banned_ids)} structural tokens from Abstract Mode.")
@property
def device(self):
return self.embed_layer.weight.device
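    # Build a KV/state cache compatible with the backbone. Hybrid Mamba + attention models
    # expose HybridMambaAttentionDynamicCache in their modeling module; its constructor
    # arguments are matched via signature inspection since names differ between versions.
    # Anything else falls back to a plain DynamicCache.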
def _init_cache(self, batch_size, max_length):
try:
module = importlib.import_module(self.model_backbone.__module__)
if hasattr(module, "HybridMambaAttentionDynamicCache"):
CacheClass = getattr(module, "HybridMambaAttentionDynamicCache")
sig = inspect.signature(CacheClass.__init__)
kwargs = {}
if 'config' in sig.parameters: kwargs['config'] = self.config
if 'batch_size' in sig.parameters: kwargs['batch_size'] = batch_size
elif 'max_batch_size' in sig.parameters: kwargs['max_batch_size'] = batch_size
if 'max_cache_len' in sig.parameters: kwargs['max_cache_len'] = max_length
elif 'max_length' in sig.parameters: kwargs['max_length'] = max_length
if 'device' in sig.parameters: kwargs['device'] = self.device
if 'dtype' in sig.parameters: kwargs['dtype'] = self.embed_layer.weight.dtype
return CacheClass(**kwargs)
except Exception: pass
from transformers import DynamicCache
cache = DynamicCache()
cache.has_previous_state = False
return cache
def forward(
self,
input_ids,
max_length=512,
temperature=0.7,
sample=False,
no_grad=True,
sigma=0.0,
max_thinking_steps=64
):
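        """
        Autoregressive generation in two phases.

        The model starts in abstract mode: each step runs the backbone once, the natural
        lm_head is consulted only to detect </think> (or the max_thinking_steps budget), and
        the next input is a probability-weighted mixture of top-k continuous_embed_layer
        vectors rather than a discrete token. Once </think> is emitted, the loop switches to
        ordinary greedy decoding with embed_layer / lm_head until EOS or max_length.
        `sample` is accepted for API compatibility but decoding is greedy in both modes.

        Returns a dict with 'generated_tokens', 'logits' (populated only when no_grad=True)
        and 'mode_sequence' ('A' abstract, 'T' transition, 'N' natural).
        """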
if input_ids.device != self.device:
input_ids = input_ids.to(self.device)
if no_grad:
with torch.no_grad():
initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
else:
initial_embeddings = self.embed_layer(input_ids.unsqueeze(0)).squeeze(0)
in_abstract_mode = True
abstract_step_count = 0
generated_tokens = []
all_logits = []
mode_sequence = []
past_key_values = self._init_cache(batch_size=1, max_length=max_length + input_ids.shape[0] + 16)
current_step_input = initial_embeddings.unsqueeze(0)
current_seq_len = initial_embeddings.shape[0]
context = torch.no_grad() if no_grad else torch.enable_grad()
with context:
for step in range(max_length):
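                # Step 0 prefills the whole prompt; later steps feed a single embedding
                # and only the position of the newest token.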
if step == 0:
position_ids = torch.arange(0, current_seq_len, dtype=torch.long, device=self.device).unsqueeze(0)
else:
position_ids = torch.tensor([[current_seq_len - 1]], dtype=torch.long, device=self.device)
outputs = self.model_backbone(
inputs_embeds=current_step_input,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True
)
past_key_values = outputs.past_key_values
last_hidden = outputs.last_hidden_state[0, -1, :]
# 1. Natural Head (Used for stopping condition)
logits = self.lm_head(last_hidden)
stop_probs = F.softmax(logits.float(), dim=-1)
natural_next_token = torch.argmax(stop_probs, dim=-1).item()
# Force Stop Condition
force_stop = False
if in_abstract_mode:
abstract_step_count += 1
if abstract_step_count >= max_thinking_steps:
natural_next_token = self.end_think_id
force_stop = True
# 2. Logic Flow
if (natural_next_token == self.end_think_id or force_stop) and in_abstract_mode:
# Transition to Natural
in_abstract_mode = False
mode_sequence.append('T')
generated_tokens.append(self.end_think_id)
next_embedding = self.embed_layer(torch.tensor([[self.end_think_id]], device=self.device)).squeeze(0).squeeze(0)
elif in_abstract_mode:
# Abstract Generation
mode_sequence.append('A')
cont_logits = self.continuous_head(last_hidden)
if self.banned_ids:
cont_logits[self.banned_ids] = float('-inf')
cont_logits_f32 = cont_logits.float() / (temperature if temperature else 1.0)
abstract_vis_token = torch.argmax(cont_logits_f32, dim=-1).item()
generated_tokens.append(abstract_vis_token)
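                    # The argmax id above is recorded only for decoding/visualization; the
                    # actual next input is a soft mixture of the top-k continuous embeddings,
                    # weighted by their renormalized softmax probabilities.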
top_k = min(256, self.vocab_size // 4)
top_logits, top_indices = torch.topk(cont_logits_f32, top_k, dim=-1)
top_probs = F.softmax(top_logits, dim=-1).to(torch.bfloat16)
top_embeddings = self.continuous_embed_layer(top_indices)
next_embedding = top_probs @ top_embeddings
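                    # Optional Gaussian exploration noise on the continuous embedding
                    # (applied only when gradients are enabled, i.e. during training).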
if sigma > 0.0 and not no_grad:
next_embedding = next_embedding + (torch.randn_like(next_embedding) * sigma)
else:
# Natural Generation
mode_sequence.append('N')
generated_tokens.append(natural_next_token)
next_embedding = self.embed_layer(torch.tensor([[natural_next_token]], device=self.device)).squeeze(0).squeeze(0)
if no_grad: all_logits.append(logits.detach().cpu())
if natural_next_token == self.tokenizer.eos_token_id and not in_abstract_mode:
break
current_step_input = next_embedding.unsqueeze(0).unsqueeze(0)
current_seq_len += 1
return {
'generated_tokens': torch.tensor(generated_tokens),
'logits': torch.stack(all_logits) if all_logits else torch.tensor([]),
'mode_sequence': mode_sequence,
}
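    # Only the two new modules are persisted; the backbone is reconstructed from
    # sft_model_path (recorded in config.json) when loading.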
def save_to_directory(self, output_dir):
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
try:
head_state = {k: v.cpu() for k, v in self.continuous_head.state_dict().items()}
embed_state = {k: v.cpu() for k, v in self.continuous_embed_layer.state_dict().items()}
torch.save(head_state, output_path / "continuous_head.pt")
torch.save(embed_state, output_path / "continuous_embed.pt")
config = {'sft_model_path': str(self.sft_model_path), 'hidden_size': self.hidden_size, 'vocab_size': self.vocab_size}
with open(output_path / "config.json", 'w') as f: json.dump(config, f)
print(f"Saved model to {output_dir}")
except Exception as e: print(f"Error saving model: {e}")
@staticmethod
def load_from_directory(output_dir, sft_model_path=None, device='cuda:0'):
output_path = Path(output_dir)
with open(output_path / "config.json", 'r') as f: config = json.load(f)
if sft_model_path is None: sft_model_path = config['sft_model_path']
model = AbstractModel(sft_model_path, device=device)
print(f"Loading checkpoint to {model.device}...")
head_state = torch.load(output_path / "continuous_head.pt", map_location=model.device)
embed_state = torch.load(output_path / "continuous_embed.pt", map_location=model.device)
model.continuous_head.load_state_dict(head_state)
model.continuous_embed_layer.load_state_dict(embed_state)
model.continuous_head = model.continuous_head.to(torch.bfloat16)
model.continuous_embed_layer = model.continuous_embed_layer.to(torch.bfloat16)
return model
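# Interactive CLI demo: formats the prompt with the tokenizer's chat template and prints
# abstract-mode tokens in cyan, natural-mode tokens in plain text.
# Example invocation (paths are placeholders, not shipped checkpoints):
#   python abstract_model.py --sft-model ./granite-sft --load-model ./abstract-checkpoint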
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--sft-model', required=True)
parser.add_argument('--load-model', default=None)
parser.add_argument('--max-length', type=int, default=256)
parser.add_argument('--temperature', type=float, default=0.7)
args = parser.parse_args()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if args.load_model:
        model = AbstractModel.load_from_directory(args.load_model, sft_model_path=args.sft_model, device=device)
    else:
        # No checkpoint given: fall back to a freshly initialized abstract head on top of the SFT model.
        model = AbstractModel(args.sft_model, device=device)
print("\n" + "=" * 70)
print(f"Abstract Model - Interactive Generation (Masked & Budgeted)")
print("=" * 70 + "\n")
while True:
try:
prompt = input("You: ").strip()
if not prompt: continue
if prompt.lower() in ['q', 'quit']: break
sys_prompt = "You are a reasoning assistant. Think step by step before answering."
messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}]
formatted = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = model.tokenizer(formatted, return_tensors='pt', add_special_tokens=False)['input_ids'].to(model.device).squeeze(0)
print("Generating...", end="\r")
result = model.forward(
input_ids,
max_length=args.max_length,
temperature=args.temperature,
sample=False,
no_grad=True,
sigma=0.0,
max_thinking_steps=128
)
generated_ids = result['generated_tokens'].tolist()
modes = result['mode_sequence']
print("Assistant: ", end="")
for token_id, mode in zip(generated_ids, modes):
token_text = model.tokenizer.decode([token_id])
if mode == 'A':
print(f"\033[96m{token_text}\033[0m", end="", flush=True)
else:
print(token_text, end="", flush=True)
print("\n")
print(f"[Stats] Abstract: {modes.count('A')} | Natural: {modes.count('N')}")
print("-" * 70)
except KeyboardInterrupt: break
except Exception as e: print(f"\nError: {e}")