File size: 15,289 Bytes
c8b77b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
#!/usr/bin/env python3
"""
Agent Token Management System

This module provides comprehensive agent token management for multi-agent training,
including special token handling, embedding management, and integration with
existing tokenization systems.
"""

import os
import json
import logging
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass
from pathlib import Path

import torch
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

logger = logging.getLogger(__name__)

@dataclass
class AgentTokenConfig:
    """Configuration for agent token management.

    Controls how agent tokens are spelled, whether model embeddings are
    resized after adding them, and how/where the mapping is persisted.
    """
    # Literal text placed before the agent name, e.g. "<|agent:SWE|>".
    agent_prefix: str = "<|agent:"
    # Literal text placed after the agent name.
    agent_suffix: str = "|>"
    # Optional extra special tokens. NOTE(review): not read anywhere in this
    # module — confirm intended use before relying on it.
    special_tokens: Optional[Dict[str, str]] = None
    # Whether AgentTokenManager.resize_model_embeddings may grow the model's
    # embedding matrix after new tokens are added.
    resize_embeddings: bool = True
    # Whether save_agent_tokens actually writes the mapping to disk.
    save_tokens: bool = True
    # Filename (relative to the output dir) used by save_agent_tokens.
    tokens_file: str = "agent_tokens.json"

class AgentTokenManager:
    """
    Manages agent-specific tokens and their integration with tokenizers.

    Responsibilities:
    - generating "<|agent:NAME|>"-style tokens from agent names
    - registering those tokens as special tokens on a tokenizer
    - resizing model embeddings after the vocabulary grows
    - persisting and restoring the agent -> token mapping
    """

    def __init__(self, config: AgentTokenConfig):
        self.config = config
        # agent name -> token string, e.g. "SWE" -> "<|agent:SWE|>"
        self.agent_tokens: Dict[str, str] = {}
        # agent name -> tokenizer vocab id (filled by add_agent_tokens_to_tokenizer)
        self.token_ids: Dict[str, int] = {}
        # vocab size before agent tokens were added; needed for embedding resize
        self.original_vocab_size: Optional[int] = None

    def generate_agent_tokens(self, agents: List[str]) -> List[str]:
        """Generate agent tokens for the given agents and record the mapping."""
        tokens = []
        for agent in agents:
            token = f"{self.config.agent_prefix}{agent}{self.config.agent_suffix}"
            tokens.append(token)
            self.agent_tokens[agent] = token

        logger.info(f"Generated {len(tokens)} agent tokens: {tokens}")
        return tokens

    def _refresh_token_ids(self, tokenizer: PreTrainedTokenizer) -> None:
        """Rebuild the agent -> token-id mapping from the tokenizer's current vocab."""
        vocab = tokenizer.get_vocab()
        for agent, token in self.agent_tokens.items():
            if token in vocab:
                self.token_ids[agent] = tokenizer.convert_tokens_to_ids(token)

    def add_agent_tokens_to_tokenizer(self, tokenizer: PreTrainedTokenizer, agents: List[str]) -> Tuple[PreTrainedTokenizer, List[str]]:
        """
        Add agent tokens to the tokenizer and return it with the token list.

        Always records the pre-addition vocab size and the agent -> token-id
        mapping, even when every token already exists in the vocabulary.
        """
        if not agents:
            logger.warning("No agents provided, skipping token addition")
            return tokenizer, []

        # Generate agent tokens
        agent_tokens = self.generate_agent_tokens(agents)

        # Record the pre-addition vocab size unconditionally so that
        # resize_model_embeddings can always compare against it.
        self.original_vocab_size = len(tokenizer)

        # Check which tokens need to be added
        existing_tokens = set(tokenizer.get_vocab().keys())
        tokens_to_add = [token for token in agent_tokens if token not in existing_tokens]

        if not tokens_to_add:
            logger.info("All agent tokens already exist in tokenizer")
            # BUGFIX: previously this early-return path skipped the id mapping,
            # leaving get_agent_token_id() returning None for known agents.
            self._refresh_token_ids(tokenizer)
            return tokenizer, agent_tokens

        # Add special tokens
        logger.info(f"Adding {len(tokens_to_add)} new agent tokens to tokenizer")
        tokenizer.add_special_tokens({
            "additional_special_tokens": tokens_to_add
        })

        # Update token IDs mapping
        self._refresh_token_ids(tokenizer)

        logger.info(f"Added agent tokens. New vocab size: {len(tokenizer)}")
        return tokenizer, agent_tokens

    def resize_model_embeddings(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer) -> torch.nn.Module:
        """
        Resize model embeddings to accommodate new agent tokens.

        No-op when resizing is disabled, the original vocab size is unknown,
        or the vocab size did not change.
        """
        if not self.config.resize_embeddings:
            logger.info("Embedding resize disabled, skipping")
            return model

        if self.original_vocab_size is None:
            logger.warning("Original vocab size not set, cannot resize embeddings")
            return model

        new_vocab_size = len(tokenizer)
        if new_vocab_size == self.original_vocab_size:
            logger.info("Vocab size unchanged, no embedding resize needed")
            return model

        logger.info(f"Resizing model embeddings from {self.original_vocab_size} to {new_vocab_size}")

        # Resize embeddings (transformers models grow/shrink the embedding
        # matrix in place and tie the output head where applicable).
        model.resize_token_embeddings(new_vocab_size)

        # Re-initialize only the freshly appended rows with small random values.
        if hasattr(model, 'get_input_embeddings'):
            embeddings = model.get_input_embeddings()
            if hasattr(embeddings, 'weight'):
                with torch.no_grad():
                    new_embeddings = embeddings.weight[self.original_vocab_size:]
                    torch.nn.init.normal_(new_embeddings, mean=0.0, std=0.02)

        logger.info("Model embeddings resized successfully")
        return model

    def format_agent_prompt(self, agent: str, text: str) -> str:
        """Prefix *text* with the agent's token; return *text* unchanged if unknown."""
        if agent not in self.agent_tokens:
            logger.warning(f"Agent '{agent}' not found in token mappings")
            return text

        agent_token = self.agent_tokens[agent]
        return f"{agent_token}\n{text}"

    def extract_agent_from_text(self, text: str) -> Optional[str]:
        """Return the agent whose token prefixes *text*, or None."""
        for agent, token in self.agent_tokens.items():
            if text.startswith(token):
                return agent
        return None

    def get_agent_token_id(self, agent: str) -> Optional[int]:
        """Get the vocab id for the agent's token, or None if not registered."""
        return self.token_ids.get(agent)

    def save_agent_tokens(self, output_dir: str) -> str:
        """Save agent tokens to a JSON file; return its path ('' if disabled)."""
        if not self.config.save_tokens:
            return ""

        os.makedirs(output_dir, exist_ok=True)
        tokens_file = os.path.join(output_dir, self.config.tokens_file)

        tokens_data = {
            "agent_tokens": self.agent_tokens,
            "token_ids": self.token_ids,
            "config": {
                "agent_prefix": self.config.agent_prefix,
                "agent_suffix": self.config.agent_suffix,
                "original_vocab_size": self.original_vocab_size
            }
        }

        with open(tokens_file, 'w') as f:
            json.dump(tokens_data, f, indent=2)

        logger.info(f"Saved agent tokens to {tokens_file}")
        return tokens_file

    def load_agent_tokens(self, tokens_file: str) -> bool:
        """Load agent tokens from a JSON file; return True on success."""
        if not os.path.isfile(tokens_file):
            logger.warning(f"Agent tokens file not found: {tokens_file}")
            return False

        try:
            with open(tokens_file, 'r') as f:
                tokens_data = json.load(f)
        # Narrowed from a bare `except Exception`: only I/O and parse errors
        # are expected here; anything else should surface as a real bug.
        except (OSError, json.JSONDecodeError) as e:
            logger.error(f"Failed to load agent tokens: {e}")
            return False

        self.agent_tokens = tokens_data.get("agent_tokens", {})
        self.token_ids = tokens_data.get("token_ids", {})

        config_data = tokens_data.get("config", {})
        self.original_vocab_size = config_data.get("original_vocab_size")

        logger.info(f"Loaded {len(self.agent_tokens)} agent tokens from {tokens_file}")
        return True

    def get_agent_statistics(self) -> Dict[str, Any]:
        """Get a summary dict describing the current agent-token state."""
        return {
            "total_agents": len(self.agent_tokens),
            "agents": list(self.agent_tokens.keys()),
            "token_ids": self.token_ids,
            "original_vocab_size": self.original_vocab_size,
            "config": {
                "agent_prefix": self.config.agent_prefix,
                "agent_suffix": self.config.agent_suffix
            }
        }

class AgentTokenizer:
    """
    Tokenizer wrapper that layers agent-token handling on top of a standard
    tokenizer. Unknown attributes fall through to the wrapped tokenizer, so
    instances can be used as drop-in replacements.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, agent_manager: AgentTokenManager):
        self.tokenizer = tokenizer
        self.agent_manager = agent_manager

    def tokenize_agent_text(self, agent: str, text: str, **kwargs) -> Dict[str, Any]:
        """Prepend the agent token to *text*, then run the wrapped tokenizer."""
        prompt = self.agent_manager.format_agent_prompt(agent, text)
        return self.tokenizer(prompt, **kwargs)

    def decode_agent_tokens(self, token_ids: Union[List[int], torch.Tensor], **kwargs) -> str:
        """Decode token IDs back into text via the wrapped tokenizer."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def get_agent_attention_mask(self, input_ids: torch.Tensor, agent: str) -> torch.Tensor:
        """
        Build an attention mask for *input_ids*, explicitly marking every
        occurrence of the agent's token as attended-to.
        """
        mask = torch.ones_like(input_ids)
        token_id = self.agent_manager.get_agent_token_id(agent)
        if token_id is None:
            return mask
        # Force attention on the agent-token positions (already 1, kept
        # explicit to preserve intent if the base mask ever changes).
        return mask.masked_fill(input_ids == token_id, 1)

    def __getattr__(self, name):
        """Fall back to the wrapped tokenizer for unknown attributes."""
        return getattr(self.tokenizer, name)

class AgentTokenValidator:
    """Validator for agent token configurations."""

    @staticmethod
    def validate_agent_tokens(agents: List[str], config: AgentTokenConfig) -> Dict[str, Any]:
        """
        Validate an agent list against a token config.

        Returns a dict with keys: "valid", "errors", "warnings", "tokens"
        (agent -> generated token).
        """
        validation_result = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "tokens": {}
        }

        if not agents:
            validation_result["warnings"].append("No agents provided")
            return validation_result

        # Check for duplicate agents
        if len(agents) != len(set(agents)):
            validation_result["errors"].append("Duplicate agents found")
            validation_result["valid"] = False

        # Empty or whitespace-only names would yield degenerate tokens
        # like "<|agent:|>" — reject them outright.
        for agent in agents:
            if not agent or not agent.strip():
                validation_result["errors"].append("Empty agent name found")
                validation_result["valid"] = False
                break

        # Generate and validate tokens
        manager = AgentTokenManager(config)
        tokens = manager.generate_agent_tokens(agents)

        # Check for token conflicts
        token_set = set(tokens)
        if len(token_set) != len(tokens):
            validation_result["errors"].append("Duplicate tokens generated")
            validation_result["valid"] = False

        # Check token length
        for agent, token in zip(agents, tokens):
            if len(token) > 50:  # Reasonable limit
                validation_result["warnings"].append(f"Long token for agent '{agent}': {token}")

        validation_result["tokens"] = dict(zip(agents, tokens))

        return validation_result

    @staticmethod
    def validate_tokenizer_compatibility(tokenizer: PreTrainedTokenizer, agents: List[str], config: AgentTokenConfig) -> Dict[str, Any]:
        """
        Validate tokenizer compatibility with agent tokens.

        Returns a dict with keys: "compatible", "errors", "warnings",
        "existing_tokens" (agents whose token is already in the vocab),
        "new_tokens" (agents whose token would need to be added).
        """
        validation_result = {
            "compatible": True,
            "errors": [],
            "warnings": [],
            "existing_tokens": [],
            "new_tokens": []
        }

        if not agents:
            return validation_result

        # Generate tokens
        manager = AgentTokenManager(config)
        tokens = manager.generate_agent_tokens(agents)

        # Check existing vocabulary
        vocab = tokenizer.get_vocab()
        for agent, token in zip(agents, tokens):
            if token in vocab:
                validation_result["existing_tokens"].append(agent)
            else:
                validation_result["new_tokens"].append(agent)

        # BUGFIX: special_tokens_map values may be single strings OR lists
        # (e.g. "additional_special_tokens"), so a plain `in .values()` test
        # missed tokens registered inside a list. Flatten first.
        special_token_values: List[str] = []
        if hasattr(tokenizer, 'special_tokens_map'):
            for value in tokenizer.special_tokens_map.values():
                if isinstance(value, (list, tuple)):
                    special_token_values.extend(value)
                else:
                    special_token_values.append(value)

        # Warn about tokens present in the vocab but not registered as special.
        for token in tokens:
            if token in vocab and token not in special_token_values:
                validation_result["warnings"].append(f"Token '{token}' exists in vocab but not as special token")

        return validation_result

# Integration with existing MoE framework
class MoEAgentTokenIntegration:
    """
    Integration layer between agent tokens and the MoE framework.

    Maintains an agent -> expert mapping and formats/parses prompts that
    carry both an agent token and an optional "<|expert:NAME|>" tag.
    """

    def __init__(self, agent_manager: AgentTokenManager):
        self.agent_manager = agent_manager
        # agent name -> MoE expert specialization name
        self.agent_to_expert_mapping: Dict[str, str] = {}

    def map_agent_to_expert(self, agent: str, expert: str):
        """Map agent to MoE expert specialization."""
        self.agent_to_expert_mapping[agent] = expert
        logger.info(f"Mapped agent '{agent}' to expert '{expert}'")

    def get_expert_for_agent(self, agent: str) -> Optional[str]:
        """Get expert specialization for agent, or None if unmapped."""
        return self.agent_to_expert_mapping.get(agent)

    def format_moe_prompt(self, agent: str, text: str, expert: Optional[str] = None) -> str:
        """
        Format a prompt with agent and (optionally) expert context:
        "<|agent:A|>\\n<|expert:E|>\\n{text}".

        BUGFIX: the old implementation replaced the first "\\n" in the
        formatted text; when *agent* was unknown, format_agent_prompt
        returned the raw text and the expert tag was injected into the
        middle of the user's own content. Now the expert tag is only
        inserted when the agent token is actually present.
        """
        agent_token = self.agent_manager.agent_tokens.get(agent)
        formatted_text = self.agent_manager.format_agent_prompt(agent, text)

        if expert and agent_token is not None:
            # Insert the expert tag between the agent token and the text.
            formatted_text = f"{agent_token}\n<|expert:{expert}|>\n{text}"

        return formatted_text

    def extract_agent_and_expert(self, text: str) -> Tuple[Optional[str], Optional[str]]:
        """Extract both agent and expert from formatted text (None if absent)."""
        agent = self.agent_manager.extract_agent_from_text(text)

        # Extract expert if present
        expert = None
        if "<|expert:" in text and "|>" in text:
            start = text.find("<|expert:") + 9  # len("<|expert:")
            end = text.find("|>", start)
            if end > start:
                expert = text[start:end]

        return agent, expert

# Example usage and testing
if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # Example configuration
    config = AgentTokenConfig(
        agent_prefix="<|agent:",
        agent_suffix="|>",
        resize_embeddings=True
    )

    # Example agents
    agents = ["SWE", "SQE", "DevOps", "Architect", "Security"]

    # Create agent manager
    manager = AgentTokenManager(config)

    # Generate tokens
    tokens = manager.generate_agent_tokens(agents)
    print(f"Generated tokens: {tokens}")

    # Example tokenizer; AutoTokenizer is already imported at module level,
    # so the redundant local `from transformers import AutoTokenizer` was
    # removed. NOTE: this downloads the model's tokenizer on first run.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct")

    # Add tokens to tokenizer
    updated_tokenizer, agent_tokens = manager.add_agent_tokens_to_tokenizer(tokenizer, agents)

    print(f"Updated tokenizer vocab size: {len(updated_tokenizer)}")
    print(f"Agent token IDs: {manager.token_ids}")

    # Test formatting
    test_text = "How do I implement a binary search?"
    formatted = manager.format_agent_prompt("SWE", test_text)
    print(f"Formatted prompt: {formatted}")

    # Test extraction
    extracted_agent = manager.extract_agent_from_text(formatted)
    print(f"Extracted agent: {extracted_agent}")