| |
|
|
| from typing import Dict, Literal, Optional |
|
|
| from ..utils import Processor |
| from .base import Template |
| from .template_meta import TemplateMeta |
|
|
| TEMPLATE_MAPPING: Dict[str, TemplateMeta] = {} |
|
|
|
|
| def register_template(template_meta: TemplateMeta, *, exist_ok: bool = False) -> None: |
| template_type = template_meta.template_type |
| if not exist_ok and template_type in TEMPLATE_MAPPING: |
| raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.') |
| TEMPLATE_MAPPING[template_type] = template_meta |
|
|
|
|
| def get_template( |
| template_type: str, |
| processor: Processor, |
| default_system: Optional[str] = None, |
| max_length: Optional[int] = None, |
| *, |
| use_chat_template: bool = True, |
| template_backend: Literal['swift', 'jinja'] = 'swift', |
| truncation_strategy: Literal['raise', 'left', 'right'] = 'raise', |
| max_pixels: Optional[int] = None, |
| agent_template: Optional[str] = None, |
| norm_bbox: Literal['norm1000', 'none', None] = None, |
| response_prefix: Optional[str] = None, |
| |
| padding_side: Literal['left', 'right'] = 'right', |
| loss_scale: str = 'default', |
| sequence_parallel_size: int = 1) -> 'Template': |
| template_meta = TEMPLATE_MAPPING[template_type] |
| template_cls = template_meta.template_cls |
| return template_cls( |
| processor, |
| template_meta, |
| default_system, |
| max_length, |
| use_chat_template=use_chat_template, |
| template_backend=template_backend, |
| truncation_strategy=truncation_strategy, |
| max_pixels=max_pixels, |
| agent_template=agent_template, |
| norm_bbox=norm_bbox, |
| response_prefix=response_prefix, |
| padding_side=padding_side, |
| loss_scale=loss_scale, |
| sequence_parallel_size=sequence_parallel_size, |
| ) |
|
|
|
|
| def get_template_meta(template_type: str) -> TemplateMeta: |
| return TEMPLATE_MAPPING[template_type] |
|
|