| |
| from dataclasses import asdict, dataclass, field |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import json |
| from PIL import Image |
|
|
| from swift.utils import get_logger |
| from ..utils import Messages, Tool, messages_to_history |
|
|
# Module-level logger obtained from swift's logging utilities.
logger = get_logger()
|
|
|
|
@dataclass
class InferRequest:
    """A single inference request.

    messages: Input in messages format.
        Examples: [{
            "role": "user",  # or system/assistant/tool
            "content": [  # str or List[Dict[str, Any]]
                {
                    "type": "image",  # or audio/video
                    "image": "<url/path/base64/PIL.Image>",
                },
                {"type": "text", "text": "Please describe the picture."},
            ],
        }]
        The above content is equivalent to:
            [{"role": "user", "content": "<image>Please describe the picture."}]
        and additionally passing in images: ["<url/path/base64/PIL.Image>"].
    images/audios/videos: Media referenced by the messages. A single item (a string,
        or a PIL image for ``images``) is accepted and wrapped into a list.
    tools: Tools to be organized into the system prompt according to the format of
        the agent_template, for example 'react_en'.
    objects: Used for grounding tasks in a general format.
    """
    messages: Messages

    images: List[Union[str, Image.Image]] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)

    tools: Optional[List[Tool]] = None
    objects: Dict[str, List[Any]] = field(default_factory=dict)

    def __post_init__(self):
        # Accept a single media item in place of a list. PIL images are wrapped too,
        # matching the element type declared for ``images``.
        for key in ['images', 'audios', 'videos']:
            val = getattr(self, key)
            if isinstance(val, (str, Image.Image)):
                setattr(self, key, [val])
        assert isinstance(self.messages, list), f'messages: {self.messages}'

    @staticmethod
    def remove_response(messages) -> Optional[str]:
        """Pop a trailing assistant message from ``messages`` in place.

        Returns the popped message's content, or None when ``messages`` is empty
        or its last message is not from the assistant.
        """
        if messages and messages[-1]['role'] == 'assistant':
            return messages.pop()['content']
        return None

    @staticmethod
    def _to_printable(obj, key: Optional[str] = None):
        """Return a copy of ``obj`` that is suitable for logging.

        Strings of 1000+ characters that are not under a 'content' or 'text' key
        (presumably base64-encoded media) are abbreviated to their first 50
        characters. Lists, tuples and dicts are processed recursively; list and
        tuple results are returned as lists.
        """
        if isinstance(obj, str) and key not in {'content', 'text'} and len(obj) >= 1000:
            return f'<<<base64:{obj[:50]}..>>>'
        if isinstance(obj, (list, tuple)):
            # Sequence elements carry no key, so long strings inside them are always abbreviated.
            return [InferRequest._to_printable(item) for item in obj]
        if isinstance(obj, dict):
            return {k: InferRequest._to_printable(v, key=k) for k, v in obj.items()}
        return obj

    def to_printable(self):
        """Return this request as a dict (via ``asdict``) with long media strings abbreviated."""
        return InferRequest._to_printable(asdict(self))
|
|
|
|
@dataclass
class RolloutInferRequest(InferRequest):
    """InferRequest variant whose ``images`` field holds strings only.

    Narrowing ``images`` to strings (image URLs or base64-encoded images) keeps the
    request compatible with POST requests.
    """
    images: List[str] = field(default_factory=list)
|
|
|
|
@dataclass
class TemplateInputs(InferRequest):
    """InferRequest extended with fields used for training.

    rejected_response: Optional rejected response text (presumably for
        preference-style training; confirm with callers).
    label: Optional boolean label attached to the sample.
    objects (inherited from InferRequest): Used for grounding tasks in a general format.
    """
    rejected_response: Optional[str] = None
    label: Optional[bool] = None
|
|
|
|
@dataclass
class StdTemplateInputs:
    """Normalized template inputs.

    Built by ``from_dict``, which splits a leading system message into ``system``,
    normalizes tool messages, and moves media items embedded in message contents
    into ``images``/``audios``/``videos``, leaving ``<image>``/``<audio>``/``<video>``
    placeholders in the text.
    """
    # Role/content dicts; after ``from_dict`` every content is a string.
    messages: List[Dict[str, str]]

    system: Optional[str] = None
    tools: Optional[List[Tool]] = None

    rejected_response: Optional[str] = None
    label: Optional[int] = None

    images: List[Union[str, Image.Image]] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)
    objects: Dict[str, List[Any]] = field(default_factory=dict)

    def __post_init__(self):
        # Running cursors, not used inside this class; presumably advanced by the
        # template as it consumes media / grounding placeholders.
        self.image_idx = 0
        self.audio_idx = 0
        self.video_idx = 0
        self.ref_idx = 0
        self.bbox_idx = 0
        # Accept a single media item in place of a list.
        if self.images and not isinstance(self.images, (list, tuple)):
            self.images = [self.images]
        if self.videos and not isinstance(self.videos, (list, tuple)):
            self.videos = [self.videos]
        if self.audios and not isinstance(self.audios, (list, tuple)):
            self.audios = [self.audios]

    def to_history(self):
        """Convert ``messages`` with ``messages_to_history``; None if there are no messages."""
        if not self.messages:
            return None
        return messages_to_history(self.messages)

    @property
    def is_multimodal(self):
        """True if any images, audios, videos or objects are attached."""
        return bool(self.images or self.audios or self.videos or self.objects)

    @staticmethod
    def _as_media_list(value: Any) -> list:
        """Normalize a top-level media value (None, a single item, or a sequence) to a new list."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        # A single item: a str (URL/path/base64) or, for images, e.g. a PIL image.
        return [value]

    @classmethod
    def from_dict(cls, inputs: Dict[str, Any]) -> 'StdTemplateInputs':
        """Build a StdTemplateInputs from a raw inputs dict.

        ``inputs['messages']`` is required. 'tools', 'objects', 'rejected_response',
        'label', 'images', 'audios' and 'videos' are optional.

        Processing:
        - A leading 'system' message is removed and its content becomes ``system``.
        - A 'tool_response' role is renamed to 'tool'. Non-string contents of 'tool' and
          'tool_call' messages are JSON-encoded.
        - Media embedded in message contents are extracted (see
          ``remove_messages_media``).
        - Each media kind may come from the messages or from the top-level key of the
          same name, but not from both.

        The caller's ``inputs`` (its message list and message dicts) is not modified.

        Raises:
            AssertionError: if a media kind is provided both inline in the messages
                and via the corresponding top-level key.
        """
        kwargs = {key: inputs[key] for key in ['rejected_response', 'label'] if key in inputs}
        # Shallow-copy the list and every message dict. The steps below pop the system
        # message and reassign 'role'/'content', and those edits must not leak into the
        # caller's data. Media objects themselves are not copied.
        messages = [dict(message) for message in inputs['messages']]
        tools = inputs.get('tools')
        objects = inputs.get('objects') or {}

        system = None
        if messages and messages[0]['role'] == 'system':
            system = messages.pop(0)['content']

        for message in messages:
            if message['role'] == 'tool_response':
                message['role'] = 'tool'
            if message['role'] in {'tool_call', 'tool'} and not isinstance(message['content'], str):
                message['content'] = json.dumps(message['content'], ensure_ascii=False)

        media_kwargs = StdTemplateInputs.remove_messages_media(messages)
        for k, mm_data in list(media_kwargs.items()):
            inputs_mm_data = cls._as_media_list(inputs.get(k))
            if mm_data:
                assert not inputs_mm_data, f'self.{k}: {inputs_mm_data}'
            else:
                media_kwargs[k] = inputs_mm_data

        return cls(messages=messages, system=system, tools=tools, objects=objects, **kwargs, **media_kwargs)

    @staticmethod
    def remove_messages_media(messages: Messages) -> Dict[str, Any]:
        """Move media items out of list-style message contents, in place.

        Each non-string ``message['content']`` (a list of typed items) is replaced by a
        string:
        - Text items contribute their text.
        - Every other item contributes a ``<{type}>`` placeholder. A ``_url`` suffix on
          the type is dropped, so 'image_url' yields '<image>'.
        - A dict value contributes its 'url' entry.

        Non-empty media values are collected and returned as
        ``{'images': [...], 'audios': [...], 'videos': [...]}``. An item with an empty
        value still leaves its placeholder, so the media can instead be supplied via
        the top-level keys handled in ``from_dict``. A non-empty value whose type is
        not image/audio/video raises KeyError.
        """
        res = {'images': [], 'audios': [], 'videos': []}
        for message in messages:
            content = message['content']
            if isinstance(content, str):
                continue

            parts = []
            for item in content:
                key: str = item['type']
                value = item.get(key)
                if key == 'text':
                    parts.append(value)
                    continue

                if key.endswith('_url'):
                    key = key[:-len('_url')]
                parts.append(f'<{key}>')
                if isinstance(value, dict):
                    value = value['url']
                if value:
                    res[f'{key}s'].append(value)
            message['content'] = ''.join(parts)
        return res
|
|