| from typing import Dict, List, Optional, Union |
|
|
| from lagent.agents.aggregator.default_aggregator import DefaultAggregator |
| from lagent.memory.base_memory import Memory |
| from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode |
|
|
|
|
class InternLMToolAggregator(DefaultAggregator):
    """Flatten instructions, few-shot examples and memory into chat messages.

    The resulting list is ordered as: system instruction (if any), tool
    instruction produced by the parser (if any), every few-shot example, and
    finally the messages stored in memory.

    Args:
        environment_role (str): Role given to memory messages whose sender is
            neither the agent itself nor one of ``user_names``; few-shot
            messages carrying this role get the begin/end markers applied.
            Defaults to ``'environment'``.
        environment_begin (str): Text placed in front of the content of
            environment messages. Defaults to ``''``.
        environment_end (str): Text placed after the content of environment
            messages. Defaults to ``''``.
        user_names (Optional[List[str]]): Sender names that are rendered with
            the ``'user'`` role. Defaults to ``['user']``.
        few_shot (Optional[List[List[dict]]]): Example conversations; each
            example is a list of message dicts with at least ``role`` and
            ``content`` keys (``tool`` messages also need ``name``).
            Defaults to no examples.
    """

    def __init__(self,
                 environment_role='environment',
                 environment_begin='',
                 environment_end='',
                 user_names: Optional[List[str]] = None,
                 few_shot: Optional[List[List[dict]]] = None):
        self.environment_role = environment_role
        self.environment_begin = environment_begin
        self.environment_end = environment_end
        self.user_names = user_names or ['user']
        self.few_shot = few_shot or []

    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: Union[ToolParser, MixedToolParser],
                  system_instruction: Optional[str] = None
                  ) -> List[Dict[str, str]]:
        """Build the message list that is sent to the model.

        Args:
            messages (Memory): Memory object; its ``get_memory()`` result is
                iterated and each item is read through ``sender``,
                ``content`` and ``formatted``.
            name (str): Sender name of the agent itself. Messages from this
                sender are rendered with the ``'assistant'`` role.
            parser (Union[ToolParser, MixedToolParser]): Parser used to
                produce the tool instruction and to format assistant
                responses.
            system_instruction (Optional[str]): System instruction passed to
                ``aggregate_system_intruction``. Skipped when empty or
                ``None``.

        Returns:
            List[Dict[str, str]]: The aggregated chat messages.

        Raises:
            KeyError: If a few-shot message has a role that is not handled
                (this includes a ``tool`` message that does not directly
                follow a ``thought`` or ``language`` message).
        """
        history = messages.get_memory()
        _message = []
        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))
        _message.extend(self._tool_instruction_messages(parser))
        for shot in self.few_shot:
            _message.extend(self._few_shot_messages(shot, parser))
        _message.extend(self._memory_messages(history, name, parser))
        return _message

    def _tool_instruction_messages(self, parser):
        """Return the parser's tool instruction as an iterable of messages."""
        tool_instruction = parser.format_instruction()
        if not tool_instruction:
            return []
        if isinstance(tool_instruction, str):
            # A plain string is wrapped into a system message, tagged with
            # the parser's tool type when one is set.
            tool_instruction = dict(role='system', content=tool_instruction)
            if parser.tool_type:
                tool_instruction['name'] = parser.tool_type
        if isinstance(tool_instruction, dict):
            tool_instruction = [tool_instruction]
        return tool_instruction

    def _few_shot_messages(self, shot, parser):
        """Convert one few-shot example into chat messages.

        The example dicts are never modified and never handed out directly:
        every emitted message is a new dict (shallow copy), so repeated calls
        and downstream edits of the result cannot alter ``self.few_shot``.
        """
        converted = []
        i = 0
        while i < len(shot):
            msg = shot[i]
            role = msg['role']
            if role in ('assistant', 'user', 'system'):
                converted.append(dict(msg))
            elif role == self.environment_role:
                # Add the environment markers only when they are missing.
                content = msg['content']
                if not content.startswith(self.environment_begin):
                    content = self.environment_begin + content
                if not content.endswith(self.environment_end):
                    content += self.environment_end
                wrapped = dict(msg)
                wrapped['content'] = content
                converted.append(wrapped)
            elif role in ('thought', 'language'):
                has_tool_call = (
                    i < len(shot) - 1 and shot[i + 1]['role'] == 'tool')
                if has_tool_call:
                    # Merge the thought with the tool call that follows it
                    # into a single assistant turn and skip the tool message.
                    tool_msg = shot[i + 1]
                    response = dict(
                        tool_type=tool_msg['name'],
                        thought=msg['content'],
                        action=tool_msg['content'],
                        status=None)
                    i += 1
                else:
                    response = dict(
                        tool_type=None,
                        thought=msg['content'],
                        action=None,
                        status=None)
                converted.append(
                    dict(
                        role='assistant',
                        content=parser.format_response(response)))
            else:
                raise KeyError(f'Unkown role: {msg["role"]}')
            i += 1
        return converted

    def _memory_messages(self, history, name, parser):
        """Convert the stored memory messages into chat messages.

        The ``tool_type`` of the latest assistant message with a dict
        ``formatted`` payload is attached as ``name`` to the environment
        messages that follow; assistant messages without such a payload
        leave it untouched.
        """
        converted = []
        tool_type = None
        for message in history:
            if message.sender == name:
                if isinstance(message.formatted, dict):
                    parsed = message.formatted
                    # Responses that failed to parse are dropped entirely.
                    if parsed['status'] == ToolStatusCode.PARSING_ERROR:
                        continue
                    converted.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(parsed)))
                    tool_type = parsed['tool_type']
                else:
                    converted.append(
                        dict(role='assistant', content=str(message.content)))
            elif message.sender in self.user_names:
                converted.append(dict(role='user', content=message.content))
            else:
                msg = dict(
                    role=self.environment_role,
                    content=self.environment_begin + str(message.content) +
                    self.environment_end)
                if tool_type:
                    msg['name'] = tool_type
                converted.append(msg)
        return converted
|
|