|
|
|
|
|
""" |
|
|
Agent Token Management System |
|
|
|
|
|
This module provides comprehensive agent token management for multi-agent training, |
|
|
including special token handling, embedding management, and integration with |
|
|
existing tokenization systems. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
from typing import Dict, List, Optional, Tuple, Any, Union |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, PreTrainedTokenizer |
|
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class AgentTokenConfig:
    """Configuration for agent token management.

    Agent tokens are rendered as ``agent_prefix + name + agent_suffix``
    (e.g. ``<|agent:SWE|>`` with the defaults below).
    """
    # Opening delimiter of every rendered agent token.
    agent_prefix: str = "<|agent:"
    # Closing delimiter of every rendered agent token.
    agent_suffix: str = "|>"
    # Optional extra special tokens; None means no extras.
    # NOTE(review): not read anywhere in this module — confirm external users.
    special_tokens: Optional[Dict[str, str]] = None
    # When True, AgentTokenManager.resize_model_embeddings will grow the
    # model's embedding matrix to match the enlarged vocabulary.
    resize_embeddings: bool = True
    # When False, save_agent_tokens() becomes a no-op returning "".
    save_tokens: bool = True
    # File name (within the chosen output dir) for the saved token mapping.
    tokens_file: str = "agent_tokens.json"
|
|
|
|
|
class AgentTokenManager:
    """
    Manages agent-specific tokens and their integration with tokenizers.

    Tracks three pieces of state:
    - ``agent_tokens``: agent name -> rendered special-token string
    - ``token_ids``: agent name -> tokenizer vocabulary id
    - ``original_vocab_size``: vocab size before agent tokens were added,
      used by :meth:`resize_model_embeddings`.
    """

    def __init__(self, config: AgentTokenConfig):
        self.config = config
        # agent name -> rendered token, e.g. "SWE" -> "<|agent:SWE|>"
        self.agent_tokens: Dict[str, str] = {}
        # agent name -> vocab id; filled once tokens are registered
        self.token_ids: Dict[str, int] = {}
        # set by add_agent_tokens_to_tokenizer() just before resizing
        self.original_vocab_size: Optional[int] = None

    def generate_agent_tokens(self, agents: List[str]) -> List[str]:
        """Render agent tokens for the given agents and record the mapping.

        Args:
            agents: Agent names to render tokens for.

        Returns:
            The rendered tokens, in the same order as ``agents``.
        """
        tokens = []
        for agent in agents:
            token = f"{self.config.agent_prefix}{agent}{self.config.agent_suffix}"
            tokens.append(token)
            self.agent_tokens[agent] = token

        logger.info(f"Generated {len(tokens)} agent tokens: {tokens}")
        return tokens

    def _record_token_ids(self, tokenizer: PreTrainedTokenizer) -> None:
        """Refresh ``token_ids`` from the tokenizer's current vocabulary."""
        vocab = tokenizer.get_vocab()
        for agent, token in self.agent_tokens.items():
            if token in vocab:
                self.token_ids[agent] = tokenizer.convert_tokens_to_ids(token)

    def add_agent_tokens_to_tokenizer(self, tokenizer: PreTrainedTokenizer, agents: List[str]) -> Tuple[PreTrainedTokenizer, List[str]]:
        """
        Add agent tokens to tokenizer and return updated tokenizer with token list.

        Tokens already present in the vocabulary are not re-added; in that
        case the tokenizer is returned unchanged but token ids are still
        recorded so :meth:`get_agent_token_id` works.
        """
        if not agents:
            logger.warning("No agents provided, skipping token addition")
            return tokenizer, []

        agent_tokens = self.generate_agent_tokens(agents)

        existing_tokens = set(tokenizer.get_vocab().keys())
        tokens_to_add = [token for token in agent_tokens if token not in existing_tokens]

        if not tokens_to_add:
            logger.info("All agent tokens already exist in tokenizer")
            # Bug fix: token_ids were previously never populated on this
            # early-return path, so get_agent_token_id() returned None even
            # though the tokens existed in the vocabulary.
            self._record_token_ids(tokenizer)
            return tokenizer, agent_tokens

        # Remember the pre-growth vocab size for resize_model_embeddings().
        self.original_vocab_size = len(tokenizer)

        logger.info(f"Adding {len(tokens_to_add)} new agent tokens to tokenizer")
        tokenizer.add_special_tokens({
            "additional_special_tokens": tokens_to_add
        })

        self._record_token_ids(tokenizer)

        logger.info(f"Added agent tokens. New vocab size: {len(tokenizer)}")
        return tokenizer, agent_tokens

    def resize_model_embeddings(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer) -> torch.nn.Module:
        """
        Resize model embeddings to accommodate new agent tokens.

        No-op when disabled via config, when no tokens were ever added
        (``original_vocab_size`` unset), or when the vocab size is unchanged.
        """
        if not self.config.resize_embeddings:
            logger.info("Embedding resize disabled, skipping")
            return model

        if self.original_vocab_size is None:
            logger.warning("Original vocab size not set, cannot resize embeddings")
            return model

        new_vocab_size = len(tokenizer)
        if new_vocab_size == self.original_vocab_size:
            logger.info("Vocab size unchanged, no embedding resize needed")
            return model

        logger.info(f"Resizing model embeddings from {self.original_vocab_size} to {new_vocab_size}")

        # NOTE(review): assumes a HF PreTrainedModel-style API; a plain
        # torch.nn.Module has no resize_token_embeddings — confirm callers.
        model.resize_token_embeddings(new_vocab_size)

        # Re-initialize only the newly appended embedding rows with a small
        # normal distribution (std=0.02 is the usual transformer init scale).
        if hasattr(model, 'get_input_embeddings'):
            embeddings = model.get_input_embeddings()
            if hasattr(embeddings, 'weight'):
                with torch.no_grad():
                    new_embeddings = embeddings.weight[self.original_vocab_size:]
                    torch.nn.init.normal_(new_embeddings, mean=0.0, std=0.02)

        logger.info("Model embeddings resized successfully")
        return model

    def format_agent_prompt(self, agent: str, text: str) -> str:
        """Prefix *text* with the agent's token; unknown agents return *text* unchanged."""
        if agent not in self.agent_tokens:
            logger.warning(f"Agent '{agent}' not found in token mappings")
            return text

        agent_token = self.agent_tokens[agent]
        return f"{agent_token}\n{text}"

    def extract_agent_from_text(self, text: str) -> Optional[str]:
        """Return the agent whose token prefixes *text*, or None if no match."""
        for agent, token in self.agent_tokens.items():
            if text.startswith(token):
                return agent
        return None

    def get_agent_token_id(self, agent: str) -> Optional[int]:
        """Return the vocab id for the agent's token, or None if not registered."""
        return self.token_ids.get(agent)

    def save_agent_tokens(self, output_dir: str) -> str:
        """Save agent tokens as JSON under *output_dir*.

        Returns:
            The written file path, or "" when saving is disabled in config.
        """
        if not self.config.save_tokens:
            return ""

        os.makedirs(output_dir, exist_ok=True)
        tokens_file = os.path.join(output_dir, self.config.tokens_file)

        tokens_data = {
            "agent_tokens": self.agent_tokens,
            "token_ids": self.token_ids,
            "config": {
                "agent_prefix": self.config.agent_prefix,
                "agent_suffix": self.config.agent_suffix,
                "original_vocab_size": self.original_vocab_size
            }
        }

        with open(tokens_file, 'w') as f:
            json.dump(tokens_data, f, indent=2)

        logger.info(f"Saved agent tokens to {tokens_file}")
        return tokens_file

    def load_agent_tokens(self, tokens_file: str) -> bool:
        """Load agent tokens from a JSON file written by :meth:`save_agent_tokens`.

        Returns:
            True on success, False when the file is missing or unreadable.
        """
        if not os.path.isfile(tokens_file):
            logger.warning(f"Agent tokens file not found: {tokens_file}")
            return False

        try:
            with open(tokens_file, 'r') as f:
                tokens_data = json.load(f)

            self.agent_tokens = tokens_data.get("agent_tokens", {})
            self.token_ids = tokens_data.get("token_ids", {})

            config_data = tokens_data.get("config", {})
            self.original_vocab_size = config_data.get("original_vocab_size")

            logger.info(f"Loaded {len(self.agent_tokens)} agent tokens from {tokens_file}")
            return True

        except Exception as e:
            # Best-effort load: malformed files are reported, not raised.
            logger.error(f"Failed to load agent tokens: {e}")
            return False

    def get_agent_statistics(self) -> Dict[str, Any]:
        """Return a summary dict of the current agent-token state."""
        return {
            "total_agents": len(self.agent_tokens),
            "agents": list(self.agent_tokens.keys()),
            "token_ids": self.token_ids,
            "original_vocab_size": self.original_vocab_size,
            "config": {
                "agent_prefix": self.config.agent_prefix,
                "agent_suffix": self.config.agent_suffix
            }
        }
|
|
|
|
|
class AgentTokenizer:
    """
    Enhanced tokenizer wrapper that integrates agent token management.

    Unknown attributes are delegated to the wrapped tokenizer, so instances
    can largely be used as drop-in replacements for it.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, agent_manager: AgentTokenManager):
        self.tokenizer = tokenizer
        self.agent_manager = agent_manager

    def tokenize_agent_text(self, agent: str, text: str, **kwargs) -> Dict[str, Any]:
        """Tokenize *text* with the agent token prefix; kwargs go to the tokenizer."""
        formatted_text = self.agent_manager.format_agent_prompt(agent, text)
        return self.tokenizer(formatted_text, **kwargs)

    def decode_agent_tokens(self, token_ids: Union[List[int], torch.Tensor], **kwargs) -> str:
        """Decode token IDs back to text via the wrapped tokenizer."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def get_agent_attention_mask(self, input_ids: torch.Tensor, agent: str) -> torch.Tensor:
        """Get attention mask with special handling for agent tokens.

        NOTE(review): the mask starts as all ones and agent-token positions
        are then set to 1 again, so the result currently always equals
        ``torch.ones_like(input_ids)``. Kept as the hook point for a future
        per-agent masking policy.
        """
        attention_mask = torch.ones_like(input_ids)

        agent_token_id = self.agent_manager.get_agent_token_id(agent)
        if agent_token_id is not None:
            agent_positions = (input_ids == agent_token_id)
            attention_mask[agent_positions] = 1

        return attention_mask

    def __getattr__(self, name):
        """Delegate unknown attributes to underlying tokenizer."""
        # Bug fix: guard against infinite recursion when 'tokenizer' itself
        # is missing (e.g. during unpickling or a partially-run __init__) —
        # __getattr__ is only invoked for attributes not found normally, so
        # a plain self.tokenizer lookup here would re-enter this method.
        if name == "tokenizer":
            raise AttributeError(name)
        return getattr(self.tokenizer, name)
|
|
|
|
|
class AgentTokenValidator:
    """Validator for agent token configurations."""

    @staticmethod
    def validate_agent_tokens(agents: List[str], config: AgentTokenConfig) -> Dict[str, Any]:
        """Validate agent token configuration.

        Returns a report dict with "valid", "errors", "warnings" and the
        generated "tokens" mapping.
        """
        report: Dict[str, Any] = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "tokens": {}
        }

        # Nothing to validate without agents — note it and bail out early.
        if not agents:
            report["warnings"].append("No agents provided")
            return report

        # Repeated agent names would make the token mapping ambiguous.
        if len(set(agents)) != len(agents):
            report["errors"].append("Duplicate agents found")
            report["valid"] = False

        rendered = AgentTokenManager(config).generate_agent_tokens(agents)

        # Distinct agents must also yield distinct rendered tokens.
        if len(set(rendered)) != len(rendered):
            report["errors"].append("Duplicate tokens generated")
            report["valid"] = False

        # Unusually long tokens are suspicious but not fatal.
        for agent, token in zip(agents, rendered):
            if len(token) > 50:
                report["warnings"].append(f"Long token for agent '{agent}': {token}")

        report["tokens"] = dict(zip(agents, rendered))
        return report

    @staticmethod
    def validate_tokenizer_compatibility(tokenizer: PreTrainedTokenizer, agents: List[str], config: AgentTokenConfig) -> Dict[str, Any]:
        """Validate tokenizer compatibility with agent tokens.

        Partitions agents into those whose tokens already exist in the
        tokenizer's vocab and those that would need to be added.
        """
        report: Dict[str, Any] = {
            "compatible": True,
            "errors": [],
            "warnings": [],
            "existing_tokens": [],
            "new_tokens": []
        }

        if not agents:
            return report

        rendered = AgentTokenManager(config).generate_agent_tokens(agents)
        vocab = tokenizer.get_vocab()

        # Bucket each agent by whether its token is already known.
        for agent, token in zip(agents, rendered):
            key = "existing_tokens" if token in vocab else "new_tokens"
            report[key].append(agent)

        # A token that is in the vocab but not registered as special may be
        # split by the pre-tokenizer — worth a warning.
        for token in rendered:
            if token not in vocab:
                continue
            if not hasattr(tokenizer, 'special_tokens_map'):
                continue
            if token not in tokenizer.special_tokens_map.values():
                report["warnings"].append(f"Token '{token}' exists in vocab but not as special token")

        return report
|
|
|
|
|
|
|
|
class MoEAgentTokenIntegration:
    """
    Integration layer between agent tokens and MoE framework.

    Maintains an agent -> expert-specialization mapping and formats/parses
    prompts carrying both an agent token and an ``<|expert:...|>`` marker.
    """

    def __init__(self, agent_manager: AgentTokenManager):
        self.agent_manager = agent_manager
        # agent name -> MoE expert specialization name
        self.agent_to_expert_mapping: Dict[str, str] = {}

    def map_agent_to_expert(self, agent: str, expert: str):
        """Map agent to MoE expert specialization."""
        self.agent_to_expert_mapping[agent] = expert
        logger.info(f"Mapped agent '{agent}' to expert '{expert}'")

    def get_expert_for_agent(self, agent: str) -> Optional[str]:
        """Get expert specialization for agent, or None if unmapped."""
        return self.agent_to_expert_mapping.get(agent)

    def format_moe_prompt(self, agent: str, text: str, expert: Optional[str] = None) -> str:
        """Format prompt for MoE framework with agent and expert context.

        When both the agent token and *expert* are available, the result is
        the agent token, a newline, the expert marker, a newline, then the
        original text.
        """
        formatted_text = self.agent_manager.format_agent_prompt(agent, text)

        if expert:
            expert_context = f"\n<|expert:{expert}|>\n"
            agent_token = self.agent_manager.agent_tokens.get(agent)
            # Bug fix: the old code replaced the first "\n" anywhere in the
            # string, which spliced the expert marker into the middle of the
            # user's text whenever the agent was unknown (no token prefix
            # applied). Insert the expert context only directly after a
            # verified agent-token prefix.
            if agent_token is not None and formatted_text.startswith(f"{agent_token}\n"):
                formatted_text = f"{agent_token}{expert_context}{formatted_text[len(agent_token) + 1:]}"

        return formatted_text

    def extract_agent_and_expert(self, text: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract both agent and expert from formatted text (None when absent)."""
        agent = self.agent_manager.extract_agent_from_text(text)

        expert = None
        marker = "<|expert:"
        if marker in text and "|>" in text:
            start = text.find(marker) + len(marker)
            end = text.find("|>", start)
            if end > start:
                expert = text[start:end]

        return agent, expert
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test the agent token pipeline end to end.
    logging.basicConfig(level=logging.INFO)

    demo_config = AgentTokenConfig(
        agent_prefix="<|agent:",
        agent_suffix="|>",
        resize_embeddings=True,
    )
    demo_agents = ["SWE", "SQE", "DevOps", "Architect", "Security"]
    token_manager = AgentTokenManager(demo_config)

    # Render the agent tokens and show them.
    rendered_tokens = token_manager.generate_agent_tokens(demo_agents)
    print(f"Generated tokens: {rendered_tokens}")

    # Register the tokens with a real tokenizer (AutoTokenizer is already
    # imported at module scope).
    base_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
    updated_tokenizer, _ = token_manager.add_agent_tokens_to_tokenizer(base_tokenizer, demo_agents)
    print(f"Updated tokenizer vocab size: {len(updated_tokenizer)}")
    print(f"Agent token IDs: {token_manager.token_ids}")

    # Round-trip a prompt through formatting and agent extraction.
    sample_question = "How do I implement a binary search?"
    prompt = token_manager.format_agent_prompt("SWE", sample_question)
    print(f"Formatted prompt: {prompt}")

    recovered_agent = token_manager.extract_agent_from_text(prompt)
    print(f"Extracted agent: {recovered_agent}")
|
|
|