Linksome committed · verified
Commit aa048fe · Parent: 43a4147

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. LlamaFactory/.github/ISSUE_TEMPLATE/1-bug-report.yml +61 -0
  2. LlamaFactory/.github/ISSUE_TEMPLATE/2-feature-request.yml +41 -0
  3. LlamaFactory/.github/ISSUE_TEMPLATE/config.yml +8 -0
  4. LlamaFactory/.github/workflows/docker.yml +116 -0
  5. LlamaFactory/.github/workflows/publish.yml +37 -0
  6. LlamaFactory/src/api.py +33 -0
  7. LlamaFactory/src/llamafactory/__init__.py +31 -0
  8. LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-311.pyc +0 -0
  9. LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-312.pyc +0 -0
  10. LlamaFactory/src/llamafactory/__pycache__/cli.cpython-311.pyc +0 -0
  11. LlamaFactory/src/llamafactory/__pycache__/cli.cpython-312.pyc +0 -0
  12. LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-311.pyc +0 -0
  13. LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-312.pyc +0 -0
  14. LlamaFactory/src/llamafactory/api/__init__.py +0 -0
  15. LlamaFactory/src/llamafactory/api/__pycache__/common.cpython-311.pyc +0 -0
  16. LlamaFactory/src/llamafactory/api/__pycache__/protocol.cpython-311.pyc +0 -0
  17. LlamaFactory/src/llamafactory/api/app.py +133 -0
  18. LlamaFactory/src/llamafactory/api/chat.py +291 -0
  19. LlamaFactory/src/llamafactory/api/common.py +96 -0
  20. LlamaFactory/src/llamafactory/api/protocol.py +156 -0
  21. LlamaFactory/src/llamafactory/chat/__init__.py +19 -0
  22. LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-311.pyc +0 -0
  23. LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-312.pyc +0 -0
  24. LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-311.pyc +0 -0
  25. LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc +0 -0
  26. LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-311.pyc +0 -0
  27. LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc +0 -0
  28. LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc +0 -0
  29. LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-312.pyc +0 -0
  30. LlamaFactory/src/llamafactory/chat/base_engine.py +98 -0
  31. LlamaFactory/src/llamafactory/chat/chat_model.py +210 -0
  32. LlamaFactory/src/llamafactory/chat/hf_engine.py +412 -0
  33. LlamaFactory/src/llamafactory/chat/kt_engine.py +284 -0
  34. LlamaFactory/src/llamafactory/chat/sglang_engine.py +289 -0
  35. LlamaFactory/src/llamafactory/chat/vllm_engine.py +271 -0
  36. LlamaFactory/src/llamafactory/cli.py +31 -0
  37. LlamaFactory/src/llamafactory/data/.ipynb_checkpoints/template-checkpoint.py +2175 -0
  38. LlamaFactory/src/llamafactory/data/__init__.py +37 -0
  39. LlamaFactory/src/llamafactory/data/__pycache__/__init__.cpython-311.pyc +0 -0
  40. LlamaFactory/src/llamafactory/data/__pycache__/__init__.cpython-312.pyc +0 -0
  41. LlamaFactory/src/llamafactory/data/__pycache__/collator.cpython-311.pyc +0 -0
  42. LlamaFactory/src/llamafactory/data/__pycache__/collator.cpython-312.pyc +0 -0
  43. LlamaFactory/src/llamafactory/data/__pycache__/converter.cpython-311.pyc +0 -0
  44. LlamaFactory/src/llamafactory/data/__pycache__/converter.cpython-312.pyc +0 -0
  45. LlamaFactory/src/llamafactory/data/__pycache__/data_utils.cpython-311.pyc +0 -0
  46. LlamaFactory/src/llamafactory/data/__pycache__/data_utils.cpython-312.pyc +0 -0
  47. LlamaFactory/src/llamafactory/data/__pycache__/formatter.cpython-311.pyc +0 -0
  48. LlamaFactory/src/llamafactory/data/__pycache__/formatter.cpython-312.pyc +0 -0
  49. LlamaFactory/src/llamafactory/data/__pycache__/loader.cpython-311.pyc +0 -0
  50. LlamaFactory/src/llamafactory/data/__pycache__/loader.cpython-312.pyc +0 -0
LlamaFactory/.github/ISSUE_TEMPLATE/1-bug-report.yml ADDED
@@ -0,0 +1,61 @@
+name: "\U0001F41B Bug / help"
+description: Create a report to help us improve the LLaMA Factory
+labels: ["bug", "pending"]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Issues included in **[FAQs](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** or those with **insufficient** information may be closed without a response.
+        已经包含在 **[常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614)** 内或提供信息**不完整**的 issues 可能不会被回复。
+
+  - type: markdown
+    attributes:
+      value: |
+        Please do not create issues that are not related to framework bugs under this category, use **[Discussions](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)** instead.
+        请勿在此分类下创建和框架 bug 无关的 issues,训练问题求助请使用 **[讨论区](https://github.com/hiyouga/LLaMA-Factory/discussions/categories/q-a)**。
+
+  - type: checkboxes
+    id: reminder
+    attributes:
+      label: Reminder
+      description: |
+        Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).
+        请确保您已经认真阅读了上述规则并且搜索过现有的 issues(包括常见问题)。
+
+      options:
+        - label: I have read the above rules and searched the existing issues.
+          required: true
+
+  - type: textarea
+    id: system-info
+    validations:
+      required: true
+    attributes:
+      label: System Info
+      description: |
+        Please share your system info with us. You can run the command **llamafactory-cli env** and copy-paste its output below.
+        请提供您的系统信息。您可以在命令行运行 **llamafactory-cli env** 并将其输出复制到该文本框中。
+
+      placeholder: llamafactory version, platform, python version, ...
+
+  - type: textarea
+    id: reproduction
+    validations:
+      required: true
+    attributes:
+      label: Reproduction
+      description: |
+        Please provide entry arguments, error messages and stack traces that reproduces the problem.
+        请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
+
+      value: |
+        ```text
+        Put your message here.
+        ```
+
+  - type: textarea
+    id: others
+    validations:
+      required: false
+    attributes:
+      label: Others
LlamaFactory/.github/ISSUE_TEMPLATE/2-feature-request.yml ADDED
@@ -0,0 +1,41 @@
+name: "\U0001F680 Feature request"
+description: Submit a request for a new feature
+labels: ["enhancement", "pending"]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Please do not create issues that are not related to new features under this category.
+        请勿在此分类下创建和新特性无关的 issues。
+
+  - type: checkboxes
+    id: reminder
+    attributes:
+      label: Reminder
+      description: |
+        Please ensure you have read the above rules carefully and searched the existing issues.
+        请确保您已经认真阅读了上述规则并且搜索过现有的 issues。
+
+      options:
+        - label: I have read the above rules and searched the existing issues.
+          required: true
+
+  - type: textarea
+    id: description
+    validations:
+      required: true
+    attributes:
+      label: Description
+      description: |
+        A clear and concise description of the feature proposal.
+        请详细描述您希望加入的新功能特性。
+
+  - type: textarea
+    id: contribution
+    validations:
+      required: false
+    attributes:
+      label: Pull Request
+      description: |
+        Have you already created the relevant PR and submitted the code?
+        您是否已经创建了相关 PR 并提交了代码?
LlamaFactory/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,8 @@
+blank_issues_enabled: false
+contact_links:
+  - name: 📚 FAQs | 常见问题
+    url: https://github.com/hiyouga/LLaMA-Factory/issues/4614
+    about: Reading in advance is recommended | 建议提前阅读
+  - name: Discussions | 讨论区
+    url: https://github.com/hiyouga/LLaMA-Factory/discussions
+    about: Please ask fine-tuning questions here | 请在这里讨论训练问题
LlamaFactory/.github/workflows/docker.yml ADDED
@@ -0,0 +1,116 @@
+name: docker
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - "main"
+    paths:
+      - "**/*.py"
+      - "pyproject.toml"
+      - "docker/**"
+      - ".github/workflows/*.yml"
+  pull_request:
+    branches:
+      - "main"
+    paths:
+      - "**/*.py"
+      - "pyproject.toml"
+      - "docker/**"
+      - ".github/workflows/*.yml"
+  release:
+    types:
+      - published
+
+jobs:
+  build:
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - device: "cuda"
+          - device: "npu-a2"
+          - device: "npu-a3"
+
+    runs-on: ubuntu-latest
+
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.device }}
+      cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+    environment:
+      name: docker
+      url: https://hub.docker.com/r/hiyouga/llamafactory
+
+    steps:
+      - name: Free up disk space
+        uses: jlumbroso/free-disk-space@v1.3.1
+        with:
+          tool-cache: true
+          docker-images: false
+
+      - name: Checkout
+        uses: actions/checkout@v6
+
+      - name: Get llamafactory version
+        id: version
+        run: |
+          if [ "${{ github.event_name }}" = "release" ]; then
+            echo "tag=$(grep -oP 'VERSION = "\K[^"]+' src/llamafactory/extras/env.py)" >> "$GITHUB_OUTPUT"
+          else
+            echo "tag=latest" >> "$GITHUB_OUTPUT"
+          fi
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        if: ${{ github.event_name != 'pull_request' }}
+        uses: docker/login-action@v3
+        with:
+          username: ${{ vars.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Login to Quay
+        if: ${{ github.event_name != 'pull_request' && startsWith(matrix.device, 'npu') }}
+        uses: docker/login-action@v3
+        with:
+          registry: quay.io
+          username: ${{ vars.QUAY_ASCEND_USERNAME }}
+          password: ${{ secrets.QUAY_ASCEND_TOKEN }}
+
+      - name: Build and push Docker image (CUDA)
+        if: ${{ matrix.device == 'cuda' }}
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          file: ./docker/docker-cuda/Dockerfile
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: |
+            docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}
+
+      - name: Build and push Docker image (NPU-A2)
+        if: ${{ matrix.device == 'npu-a2' }}
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          file: ./docker/docker-npu/Dockerfile
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: |
+            docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
+            quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a2
+
+      - name: Build and push Docker image (NPU-A3)
+        if: ${{ matrix.device == 'npu-a3' }}
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          platforms: linux/amd64,linux/arm64
+          file: ./docker/docker-npu/Dockerfile
+          build-args: |
+            BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
+          push: ${{ github.event_name != 'pull_request' }}
+          tags: |
+            docker.io/hiyouga/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
+            quay.io/ascend/llamafactory:${{ steps.version.outputs.tag }}-npu-a3
LlamaFactory/.github/workflows/publish.yml ADDED
@@ -0,0 +1,37 @@
+name: publish
+
+on:
+  workflow_dispatch:
+  release:
+    types:
+      - published
+
+jobs:
+  publish:
+    name: Upload release to PyPI
+
+    runs-on: ubuntu-latest
+
+    environment:
+      name: release
+      url: https://pypi.org/p/llamafactory
+
+    permissions:
+      id-token: write
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v6
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
+        with:
+          python-version: "3.11"
+          github-token: ${{ github.token }}
+
+      - name: Build package
+        run: |
+          make build
+
+      - name: Publish package
+        uses: pypa/gh-action-pypi-publish@release/v1
LlamaFactory/src/api.py ADDED
@@ -0,0 +1,33 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import uvicorn
+
+from llamafactory.api.app import create_app
+from llamafactory.chat import ChatModel
+
+
+def main():
+    chat_model = ChatModel()
+    app = create_app(chat_model)
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
+    uvicorn.run(app, host=api_host, port=api_port)
+
+
+if __name__ == "__main__":
+    main()
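`main()` binds the server from two environment variables, falling back to `0.0.0.0:8000` when they are unset. The defaulting logic can be isolated and checked on its own (a sketch; `resolve_bind` is a hypothetical helper, not part of the commit):

```python
def resolve_bind(env: dict[str, str]) -> tuple[str, int]:
    """Replicate main()'s defaulting: API_HOST/API_PORT with 0.0.0.0:8000 fallbacks."""
    return env.get("API_HOST", "0.0.0.0"), int(env.get("API_PORT", "8000"))


print(resolve_bind({}))                    # ('0.0.0.0', 8000)
print(resolve_bind({"API_PORT": "9000"}))  # ('0.0.0.0', 9000)
```

Note that a malformed `API_PORT` (e.g. `"none"`) would raise `ValueError` at startup, since the value is passed straight to `int()`.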
LlamaFactory/src/llamafactory/__init__.py ADDED
@@ -0,0 +1,31 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Efficient fine-tuning of large language models.
+
+Level:
+    api, webui > chat, eval, train > data, model > hparams > extras
+
+Disable version checking: DISABLE_VERSION_CHECK=1
+Enable VRAM recording: RECORD_VRAM=1
+Force using torchrun: FORCE_TORCHRUN=1
+Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
+Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
+"""
+
+from .extras.env import VERSION
+
+
+__version__ = VERSION
LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (639 Bytes).
 
LlamaFactory/src/llamafactory/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (613 Bytes).
 
LlamaFactory/src/llamafactory/__pycache__/cli.cpython-311.pyc ADDED
Binary file (735 Bytes).
 
LlamaFactory/src/llamafactory/__pycache__/cli.cpython-312.pyc ADDED
Binary file (582 Bytes).
 
LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-311.pyc ADDED
Binary file (7.09 kB).
 
LlamaFactory/src/llamafactory/__pycache__/launcher.cpython-312.pyc ADDED
Binary file (6.3 kB).
 
LlamaFactory/src/llamafactory/api/__init__.py ADDED
File without changes
LlamaFactory/src/llamafactory/api/__pycache__/common.cpython-311.pyc ADDED
Binary file (4.77 kB).
 
LlamaFactory/src/llamafactory/api/__pycache__/protocol.cpython-311.pyc ADDED
Binary file (9.29 kB).
 
LlamaFactory/src/llamafactory/api/app.py ADDED
@@ -0,0 +1,133 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from contextlib import asynccontextmanager
+from functools import partial
+from typing import Annotated
+
+from ..chat import ChatModel
+from ..extras.constants import EngineName
+from ..extras.misc import torch_gc
+from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
+from .chat import (
+    create_chat_completion_response,
+    create_score_evaluation_response,
+    create_stream_chat_completion_response,
+)
+from .protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ModelCard,
+    ModelList,
+    ScoreEvaluationRequest,
+    ScoreEvaluationResponse,
+)
+
+
+if is_fastapi_available():
+    from fastapi import Depends, FastAPI, HTTPException, status
+    from fastapi.middleware.cors import CORSMiddleware
+    from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
+
+
+if is_starlette_available():
+    from sse_starlette import EventSourceResponse
+
+
+if is_uvicorn_available():
+    import uvicorn
+
+
+async def sweeper() -> None:
+    while True:
+        torch_gc()
+        await asyncio.sleep(300)
+
+
+@asynccontextmanager
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
+    if chat_model.engine.name == EngineName.HF:
+        asyncio.create_task(sweeper())
+
+    yield
+    torch_gc()
+
+
+def create_app(chat_model: "ChatModel") -> "FastAPI":
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
+    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    api_key = os.getenv("API_KEY")
+    security = HTTPBearer(auto_error=False)
+
+    async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
+        if api_key and (auth is None or auth.credentials != api_key):
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+
+    @app.get(
+        "/v1/models",
+        response_model=ModelList,
+        status_code=status.HTTP_200_OK,
+        dependencies=[Depends(verify_api_key)],
+    )
+    async def list_models():
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
+        return ModelList(data=[model_card])
+
+    @app.post(
+        "/v1/chat/completions",
+        response_model=ChatCompletionResponse,
+        status_code=status.HTTP_200_OK,
+        dependencies=[Depends(verify_api_key)],
+    )
+    async def create_chat_completion(request: ChatCompletionRequest):
+        if not chat_model.engine.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+        if request.stream:
+            generate = create_stream_chat_completion_response(request, chat_model)
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
+        else:
+            return await create_chat_completion_response(request, chat_model)
+
+    @app.post(
+        "/v1/score/evaluation",
+        response_model=ScoreEvaluationResponse,
+        status_code=status.HTTP_200_OK,
+        dependencies=[Depends(verify_api_key)],
+    )
+    async def create_score_evaluation(request: ScoreEvaluationRequest):
+        if chat_model.engine.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+        return await create_score_evaluation_response(request, chat_model)
+
+    return app
+
+
+def run_api() -> None:
+    chat_model = ChatModel()
+    app = create_app(chat_model)
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
+    uvicorn.run(app, host=api_host, port=api_port)
LlamaFactory/src/llamafactory/api/chat.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import io
17
+ import json
18
+ import os
19
+ import re
20
+ import uuid
21
+ from collections.abc import AsyncGenerator
22
+ from typing import TYPE_CHECKING, Optional
23
+
24
+ from ..data import Role as DataRole
25
+ from ..extras import logging
26
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
27
+ from ..extras.misc import is_env_enabled
28
+ from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
29
+ from .common import check_lfi_path, check_ssrf_url, dictify, jsonify
30
+ from .protocol import (
31
+ ChatCompletionMessage,
32
+ ChatCompletionResponse,
33
+ ChatCompletionResponseChoice,
34
+ ChatCompletionResponseUsage,
35
+ ChatCompletionStreamResponse,
36
+ ChatCompletionStreamResponseChoice,
37
+ Finish,
38
+ Function,
39
+ FunctionCall,
40
+ Role,
41
+ ScoreEvaluationResponse,
42
+ )
43
+
44
+
45
+ if is_fastapi_available():
46
+ from fastapi import HTTPException, status
47
+
48
+
49
+ if is_pillow_available():
50
+ from PIL import Image
51
+
52
+
53
+ if is_requests_available():
54
+ import requests
55
+
56
+
57
+ if TYPE_CHECKING:
58
+ from ..chat import ChatModel
59
+ from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
60
+ from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
61
+
62
+
63
+ logger = logging.get_logger(__name__)
64
+ ROLE_MAPPING = {
65
+ Role.USER: DataRole.USER.value,
66
+ Role.ASSISTANT: DataRole.ASSISTANT.value,
67
+ Role.SYSTEM: DataRole.SYSTEM.value,
68
+ Role.FUNCTION: DataRole.FUNCTION.value,
69
+ Role.TOOL: DataRole.OBSERVATION.value,
70
+ }
71
+
72
+
73
+ def _process_request(
74
+ request: "ChatCompletionRequest",
75
+ ) -> tuple[
76
+ list[dict[str, str]],
77
+ Optional[str],
78
+ Optional[str],
79
+ Optional[list["ImageInput"]],
80
+ Optional[list["VideoInput"]],
81
+ Optional[list["AudioInput"]],
82
+ ]:
83
+ if is_env_enabled("API_VERBOSE", "1"):
84
+ logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
85
+
86
+ if len(request.messages) == 0:
87
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
88
+
89
+ if request.messages[0].role == Role.SYSTEM:
90
+ content = request.messages.pop(0).content
91
+ system = content[0].text if isinstance(content, list) else content
92
+ else:
93
+ system = None
94
+
95
+ if len(request.messages) % 2 == 0:
96
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
97
+
98
+ input_messages = []
99
+ images, videos, audios = [], [], []
100
+ for i, message in enumerate(request.messages):
101
+ if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
102
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
103
+ elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
104
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
105
+
106
+ if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
107
+ tool_calls = [
108
+ {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
109
+ for tool_call in message.tool_calls
110
+ ]
111
+ content = json.dumps(tool_calls, ensure_ascii=False)
112
+ input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
113
+ elif isinstance(message.content, list):
114
+ text_content = ""
115
+ for input_item in message.content:
116
+ if input_item.type == "text":
117
+ text_content += input_item.text
118
+ elif input_item.type == "image_url":
119
+ text_content += IMAGE_PLACEHOLDER
120
+ image_url = input_item.image_url.url
121
+ if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
122
+ image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
123
+ elif os.path.isfile(image_url): # local file
124
+ check_lfi_path(image_url)
125
+ image_stream = open(image_url, "rb")
126
+ else: # web uri
127
+ check_ssrf_url(image_url)
128
+ image_stream = requests.get(image_url, stream=True).raw
129
+
130
+ images.append(Image.open(image_stream).convert("RGB"))
131
+ elif input_item.type == "video_url":
132
+ text_content += VIDEO_PLACEHOLDER
133
+ video_url = input_item.video_url.url
134
+ if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url): # base64 video
135
+ video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
136
+ elif os.path.isfile(video_url): # local file
137
+ check_lfi_path(video_url)
138
+ video_stream = video_url
139
+ else: # web uri
140
+ check_ssrf_url(video_url)
141
+ video_stream = requests.get(video_url, stream=True).raw
142
+
143
+ videos.append(video_stream)
144
+ elif input_item.type == "audio_url":
145
+ text_content += AUDIO_PLACEHOLDER
146
+ audio_url = input_item.audio_url.url
147
+ if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url): # base64 audio
148
+ audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
149
+ elif os.path.isfile(audio_url): # local file
150
+ check_lfi_path(audio_url)
151
+ audio_stream = audio_url
152
+ else: # web uri
153
+ check_ssrf_url(audio_url)
154
+ audio_stream = requests.get(audio_url, stream=True).raw
155
+
156
+ audios.append(audio_stream)
157
+ else:
158
+ raise HTTPException(
159
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
160
+ )
161
+
162
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
163
+ else:
164
+ input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
165
+
166
+ tool_list = request.tools
167
+ if isinstance(tool_list, list) and len(tool_list):
168
+ try:
169
+ tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
170
+ except json.JSONDecodeError:
171
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
172
+ else:
173
+ tools = None
174
+
175
+ return input_messages, system, tools, images or None, videos or None, audios or None
176
+
177
+
178
+ def _create_stream_chat_completion_chunk(
179
+ completion_id: str,
180
+ model: str,
181
+ delta: "ChatCompletionMessage",
182
+ index: Optional[int] = 0,
183
+ finish_reason: Optional["Finish"] = None,
184
+ ) -> str:
185
+ choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
186
+ chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
187
+ return jsonify(chunk)
188
+
189
+
190
+ async def create_chat_completion_response(
191
+ request: "ChatCompletionRequest", chat_model: "ChatModel"
192
+ ) -> "ChatCompletionResponse":
193
+ completion_id = f"chatcmpl-{uuid.uuid4().hex}"
194
+ input_messages, system, tools, images, videos, audios = _process_request(request)
195
+ responses = await chat_model.achat(
196
+ input_messages,
197
+ system,
198
+ tools,
199
+ images,
200
+ videos,
201
+ audios,
202
+ do_sample=request.do_sample,
203
+ temperature=request.temperature,
204
+ top_p=request.top_p,
205
+ max_new_tokens=request.max_tokens,
206
+ num_return_sequences=request.n,
207
+ repetition_penalty=request.presence_penalty,
208
+ stop=request.stop,
209
+ )
210
+
211
+ prompt_length, response_length = 0, 0
212
+ choices = []
213
+ for i, response in enumerate(responses):
214
+ if tools:
215
+ result = chat_model.engine.template.extract_tool(response.response_text)
216
+ else:
217
+ result = response.response_text
218
+
219
+ if isinstance(result, list):
220
+ tool_calls = []
221
+ for tool in result:
222
+ function = Function(name=tool.name, arguments=tool.arguments)
223
+ tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
224
+
225
+ response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
226
+ finish_reason = Finish.TOOL
227
+        else:
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+            finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
+        prompt_length = response.prompt_length
+        response_length += response.response_length
+
+    usage = ChatCompletionResponseUsage(
+        prompt_tokens=prompt_length,
+        completion_tokens=response_length,
+        total_tokens=prompt_length + response_length,
+    )
+
+    return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
+
+
+async def create_stream_chat_completion_response(
+    request: "ChatCompletionRequest", chat_model: "ChatModel"
+) -> AsyncGenerator[str, None]:
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images, videos, audios = _process_request(request)
+    if tools:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
+    if request.n > 1:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+
+    yield _create_stream_chat_completion_chunk(
+        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
+    )
+    async for new_token in chat_model.astream_chat(
+        input_messages,
+        system,
+        tools,
+        images,
+        videos,
+        audios,
+        do_sample=request.do_sample,
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_new_tokens=request.max_tokens,
+        repetition_penalty=request.presence_penalty,
+        stop=request.stop,
+    ):
+        if len(new_token) != 0:
+            yield _create_stream_chat_completion_chunk(
+                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
+            )
+
+    yield _create_stream_chat_completion_chunk(
+        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+    )
+    yield "[DONE]"
+
+
+async def create_score_evaluation_response(
+    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
+) -> "ScoreEvaluationResponse":
+    score_id = f"scoreval-{uuid.uuid4().hex}"
+    if len(request.messages) == 0:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+    scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
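The streaming generator above yields JSON chunk strings followed by a literal `[DONE]` sentinel; the serving layer is expected to wrap each one in Server-Sent-Events framing before it reaches the client. A minimal stdlib sketch of that framing (the `sse_wrap` helper and the sample chunk contents are illustrative, not part of the diff):

```python
import json

def sse_wrap(payload: str) -> str:
    # Each yielded chunk (already serialized JSON, or the literal "[DONE]")
    # becomes one SSE event: "data: <payload>" terminated by a blank line.
    return f"data: {payload}\n\n"

# Hypothetical stream: two chunks followed by the end-of-stream sentinel.
chunks = [
    json.dumps({"id": "chatcmpl-0", "choices": [{"delta": {"content": "Hi"}}]}),
    json.dumps({"id": "chatcmpl-0", "choices": [{"delta": {}, "finish_reason": "stop"}]}),
    "[DONE]",
]
events = [sse_wrap(c) for c in chunks]
```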
LlamaFactory/src/llamafactory/api/common.py ADDED
@@ -0,0 +1,100 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ipaddress
+import json
+import os
+import socket
+from typing import TYPE_CHECKING, Any
+from urllib.parse import urlparse
+
+from ..extras.misc import is_env_enabled
+from ..extras.packages import is_fastapi_available
+
+
+if is_fastapi_available():
+    from fastapi import HTTPException, status
+
+
+if TYPE_CHECKING:
+    from pydantic import BaseModel
+
+
+SAFE_MEDIA_PATH = os.environ.get("SAFE_MEDIA_PATH", os.path.join(os.path.dirname(__file__), "safe_media"))
+ALLOW_LOCAL_FILES = is_env_enabled("ALLOW_LOCAL_FILES", "1")
+
+
+def dictify(data: "BaseModel") -> dict[str, Any]:
+    try:  # pydantic v2
+        return data.model_dump(exclude_unset=True)
+    except AttributeError:  # pydantic v1
+        return data.dict(exclude_unset=True)
+
+
+def jsonify(data: "BaseModel") -> str:
+    try:  # pydantic v2
+        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
+    except AttributeError:  # pydantic v1
+        return data.json(exclude_unset=True, ensure_ascii=False)
+
+
+def check_lfi_path(path: str) -> None:
+    """Check that a local file path stays inside the safe media directory. Raises HTTPException if unsafe."""
+    if not ALLOW_LOCAL_FILES:
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Local file access is disabled.")
+
+    try:
+        os.makedirs(SAFE_MEDIA_PATH, exist_ok=True)
+        real_path = os.path.realpath(path)
+        safe_path = os.path.realpath(SAFE_MEDIA_PATH)
+
+        if not real_path.startswith(safe_path):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN, detail="File access is restricted to the safe media directory."
+            )
+    except HTTPException:  # do not mask the 403 raised above as a 400
+        raise
+    except Exception:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or inaccessible file path.")
+
+
+def check_ssrf_url(url: str) -> None:
+    """Check that a URL does not resolve to a private or reserved address (SSRF). Raises HTTPException if unsafe."""
+    try:
+        parsed_url = urlparse(url)
+        if parsed_url.scheme not in ["http", "https"]:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only HTTP/HTTPS URLs are allowed.")
+
+        hostname = parsed_url.hostname
+        if not hostname:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid URL hostname.")
+
+        ip_info = socket.getaddrinfo(hostname, parsed_url.port)
+        ip_address_str = ip_info[0][4][0]
+        ip = ipaddress.ip_address(ip_address_str)
+
+        if not ip.is_global:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Access to private or reserved IP addresses is not allowed.",
+            )
+
+    except HTTPException:  # do not mask the checks above as a generic 400
+        raise
+    except socket.gaierror:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not resolve hostname: {parsed_url.hostname}"
+        )
+    except Exception as e:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid URL: {e}")
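The SSRF gate in `check_ssrf_url` hinges on `ipaddress.ip_address(...).is_global`, which is false for loopback, private, link-local, and reserved ranges. A quick stdlib illustration of the classification it relies on (the `is_blocked` helper name is illustrative):

```python
import ipaddress

def is_blocked(ip_str: str) -> bool:
    # Mirrors the gate in check_ssrf_url: anything that is not a global
    # (publicly routable) address is rejected.
    return not ipaddress.ip_address(ip_str).is_global

# Loopback, RFC 1918 private, and link-local addresses are all non-global.
blocked = [ip for ip in ["127.0.0.1", "10.0.0.1", "169.254.1.1", "192.168.1.5"] if is_blocked(ip)]
allowed = [ip for ip in ["93.184.216.34", "2606:4700::1111"] if not is_blocked(ip)]
```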
LlamaFactory/src/llamafactory/api/protocol.py ADDED
@@ -0,0 +1,156 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from enum import Enum, unique
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field
+
+
+@unique
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"
+
+
+@unique
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
+    TOOL = "tool_calls"
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: Literal["owner"] = "owner"
+
+
+class ModelList(BaseModel):
+    object: Literal["list"] = "list"
+    data: list[ModelCard] = []
+
+
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+    type: Literal["function", "code_interpreter"] = "function"
+    function: FunctionDefinition | None = None
+
+
+class FunctionCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: Function
+
+
+class URL(BaseModel):
+    url: str
+    detail: Literal["auto", "low", "high"] = "auto"
+
+
+class MultimodalInputItem(BaseModel):
+    type: Literal["text", "image_url", "video_url", "audio_url"]
+    text: str | None = None
+    image_url: URL | None = None
+    video_url: URL | None = None
+    audio_url: URL | None = None
+
+
+class ChatMessage(BaseModel):
+    role: Role
+    content: str | list[MultimodalInputItem] | None = None
+    tool_calls: list[FunctionCall] | None = None
+
+
+class ChatCompletionMessage(BaseModel):
+    role: Role | None = None
+    content: str | None = None
+    tool_calls: list[FunctionCall] | None = None
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: list[ChatMessage]
+    tools: list[FunctionAvailable] | None = None
+    do_sample: bool | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    n: int = 1
+    presence_penalty: float | None = None
+    max_tokens: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatCompletionMessage
+    finish_reason: Finish
+
+
+class ChatCompletionStreamResponseChoice(BaseModel):
+    index: int
+    delta: ChatCompletionMessage
+    finish_reason: Finish | None = None
+
+
+class ChatCompletionResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[ChatCompletionResponseChoice]
+    usage: ChatCompletionResponseUsage
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[ChatCompletionStreamResponseChoice]
+
+
+class ScoreEvaluationRequest(BaseModel):
+    model: str
+    messages: list[str]
+    max_length: int | None = None
+
+
+class ScoreEvaluationResponse(BaseModel):
+    id: str
+    object: Literal["score.evaluation"] = "score.evaluation"
+    model: str
+    scores: list[float]
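`Role` and `Finish` above inherit from both `str` and `Enum`, so their members compare equal to bare strings and serialize as plain string values in JSON, which is what lets these models emit OpenAI-compatible payloads. A stdlib sketch of that behavior (using a condensed two-member `Role` for illustration):

```python
import json
from enum import Enum, unique

@unique
class Role(str, Enum):  # same str-Enum pattern as the Role enum in protocol.py
    USER = "user"
    ASSISTANT = "assistant"

# str-backed members behave as strings for both comparison and JSON encoding.
message = {"role": Role.ASSISTANT, "content": "hello"}
encoded = json.dumps(message)
```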
LlamaFactory/src/llamafactory/chat/__init__.py ADDED
@@ -0,0 +1,19 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_engine import BaseEngine
+from .chat_model import ChatModel
+
+
+__all__ = ["BaseEngine", "ChatModel"]
LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (325 Bytes).
 
LlamaFactory/src/llamafactory/chat/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (289 Bytes).
 
LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-311.pyc ADDED
Binary file (4.24 kB).
 
LlamaFactory/src/llamafactory/chat/__pycache__/base_engine.cpython-312.pyc ADDED
Binary file (3.55 kB).
 
LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-311.pyc ADDED
Binary file (10.4 kB).
 
LlamaFactory/src/llamafactory/chat/__pycache__/chat_model.cpython-312.pyc ADDED
Binary file (9.28 kB).
 
LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-311.pyc ADDED
Binary file (20.9 kB).
 
LlamaFactory/src/llamafactory/chat/__pycache__/hf_engine.cpython-312.pyc ADDED
Binary file (18.3 kB).
 
LlamaFactory/src/llamafactory/chat/base_engine.py ADDED
@@ -0,0 +1,98 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from vllm import AsyncLLMEngine
+
+    from ..data import Template
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..extras.constants import EngineName
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+@dataclass
+class Response:
+    response_text: str
+    response_length: int
+    prompt_length: int
+    finish_reason: Literal["stop", "length"]
+
+
+class BaseEngine(ABC):
+    r"""Base class for inference engine of chat models.
+
+    Must implement the async methods chat(), stream_chat() and get_scores().
+    """
+
+    name: "EngineName"
+    model: Union["PreTrainedModel", "AsyncLLMEngine"]
+    tokenizer: "PreTrainedTokenizer"
+    can_generate: bool
+    template: "Template"
+    generating_args: dict[str, Any]
+
+    @abstractmethod
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        r"""Initialize an inference engine."""
+        ...
+
+    @abstractmethod
+    async def chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
+        ...
+
+    @abstractmethod
+    async def stream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]:
+        r"""Get the response token-by-token of the chat model."""
+        ...
+
+    @abstractmethod
+    async def get_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
+        ...
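Since `BaseEngine` is a plain ABC, a backend only has to supply the async methods above. A toy, stdlib-only engine illustrating the contract (`MiniEngine` and `EchoEngine` are hypothetical stand-ins, not part of LlamaFactory):

```python
import asyncio
from abc import ABC, abstractmethod

class MiniEngine(ABC):
    # Condensed stand-in for BaseEngine: a single abstract async method.
    @abstractmethod
    async def chat(self, messages: list[dict[str, str]]) -> str: ...

class EchoEngine(MiniEngine):
    async def chat(self, messages: list[dict[str, str]]) -> str:
        # Echo the last user turn, standing in for real generation.
        return f"echo: {messages[-1]['content']}"

reply = asyncio.run(EchoEngine().chat([{"role": "user", "content": "ping"}]))
```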
LlamaFactory/src/llamafactory/chat/chat_model.py ADDED
@@ -0,0 +1,210 @@
+# Copyright 2025 THUDM and the LlamaFactory team.
+#
+# This code is inspired by the THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator, Generator
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Optional
+
+from ..extras.constants import EngineName
+from ..extras.misc import torch_gc
+from ..hparams import get_infer_args
+
+
+if TYPE_CHECKING:
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from .base_engine import BaseEngine, Response
+
+
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
+
+
+class ChatModel:
+    r"""General class for chat models. Backed by huggingface or vllm engines.
+
+    Supports both sync and async methods.
+    Sync methods: chat(), stream_chat() and get_scores().
+    Async methods: achat(), astream_chat() and aget_scores().
+    """
+
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
+        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+
+        if model_args.infer_backend == EngineName.HF:
+            from .hf_engine import HuggingfaceEngine
+
+            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == EngineName.VLLM:
+            try:
+                from .vllm_engine import VllmEngine
+
+                self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "vLLM is not installed, you may need to run `pip install vllm`\n"
+                    "or try the HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        elif model_args.infer_backend == EngineName.SGLANG:
+            try:
+                from .sglang_engine import SGLangEngine
+
+                self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "SGLang is not installed, you may need to run `pip install sglang[all]`\n"
+                    "or try the HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        elif model_args.infer_backend == EngineName.KT:
+            try:
+                from .kt_engine import KTransformersEngine
+
+                self.engine: BaseEngine = KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
+            except ImportError as e:
+                raise ImportError(
+                    "KTransformers is not installed, you may need to run `pip install ktransformers`\n"
+                    "or try the HuggingFace backend: --infer_backend huggingface"
+                ) from e
+        else:
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
+
+        self._loop = asyncio.new_event_loop()
+        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+        self._thread.start()
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
+        task = asyncio.run_coroutine_threadsafe(
+            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
+        )
+        return task.result()
+
+    async def achat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        r"""Asynchronously get a list of responses of the chat model."""
+        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
+
+    def stream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> Generator[str, None, None]:
+        r"""Get the response token-by-token of the chat model."""
+        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
+        while True:
+            try:
+                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                yield task.result()
+            except StopAsyncIteration:
+                break
+
+    async def astream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]:
+        r"""Asynchronously get the response token-by-token of the chat model."""
+        async for new_token in self.engine.stream_chat(
+            messages, system, tools, images, videos, audios, **input_kwargs
+        ):
+            yield new_token
+
+    def get_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
+        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+        return task.result()
+
+    async def aget_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        r"""Asynchronously get a list of scores of the reward model."""
+        return await self.engine.get_scores(batch_input, **input_kwargs)
+
+
+def run_chat() -> None:
+    if os.name != "nt":
+        try:
+            import readline  # noqa: F401
+        except ImportError:
+            print("Install `readline` for a better experience.")
+
+    chat_model = ChatModel()
+    messages = []
+    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+    while True:
+        try:
+            query = input("\nUser: ")
+        except UnicodeDecodeError:
+            print("Detected a decoding error in the input, please set the terminal encoding to utf-8.")
+            continue
+        except Exception:
+            raise
+
+        if query.strip() == "exit":
+            break
+
+        if query.strip() == "clear":
+            messages = []
+            torch_gc()
+            print("History has been removed.")
+            continue
+
+        messages.append({"role": "user", "content": query})
+        print("Assistant: ", end="", flush=True)
+
+        response = ""
+        for new_text in chat_model.stream_chat(messages):
+            print(new_text, end="", flush=True)
+            response += new_text
+        print()
+        messages.append({"role": "assistant", "content": response})
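ChatModel's sync wrappers all follow one pattern: a dedicated event loop runs forever on a daemon thread, and `asyncio.run_coroutine_threadsafe` bridges sync callers onto it. A stripped-down, stdlib-only sketch of that pattern (`acompute`/`compute` are illustrative names):

```python
import asyncio
from threading import Thread

# Start a private event loop on a daemon thread, as ChatModel.__init__ does
# via _start_background_loop.
loop = asyncio.new_event_loop()
Thread(target=lambda: (asyncio.set_event_loop(loop), loop.run_forever()), daemon=True).start()

async def acompute(x: int) -> int:
    await asyncio.sleep(0)  # stands in for real async inference work
    return x * 2

def compute(x: int) -> int:
    # Sync facade: submit the coroutine to the background loop and block
    # on the concurrent.futures.Future it returns.
    return asyncio.run_coroutine_threadsafe(acompute(x), loop).result()

result = compute(21)
```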
LlamaFactory/src/llamafactory/chat/hf_engine.py ADDED
@@ -0,0 +1,412 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator, Callable
+from threading import Thread
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch
+from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+from ..model import load_model, load_tokenizer
+from .base_engine import BaseEngine, Response
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from trl import PreTrainedModelWrapper
+
+    from ..data import Template
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class HuggingfaceEngine(BaseEngine):
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.name = EngineName.HF
+        self.can_generate = finetuning_args.stage == "sft"
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
+        self.tokenizer.padding_side = "left" if self.can_generate else "right"
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.model = load_model(
+            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+        )  # must after fixing tokenizer to resize vocab
+        self.generating_args = generating_args.to_dict()
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            logger.warning_rank0_once("There is no current event loop, creating a new one.")
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+
+    @staticmethod
+    def _process_args(
+        model: "PreTrainedModel",
+        tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
+        template: "Template",
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> tuple[dict[str, Any], int]:
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+        messages = template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+        prompt_ids, _ = template.mm_plugin.process_token_ids(
+            prompt_ids,
+            None,
+            mm_input_dict["images"],
+            mm_input_dict["videos"],
+            mm_input_dict["audios"],
+            tokenizer,
+            processor,
+        )
+        prompt_length = len(prompt_ids)
+        inputs = torch.tensor([prompt_ids], device=model.device)
+        attention_mask = torch.ones_like(inputs, dtype=torch.long)
+
+        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+        if stop is not None:
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
+
+        generating_args = generating_args.copy()
+        generating_args.update(
+            dict(
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature if temperature is not None else generating_args["temperature"],
+                top_p=top_p if top_p is not None else generating_args["top_p"],
+                top_k=top_k if top_k is not None else generating_args["top_k"],
+                num_return_sequences=num_return_sequences,
+                repetition_penalty=repetition_penalty
+                if repetition_penalty is not None
+                else generating_args["repetition_penalty"],
+                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+                skip_special_tokens=skip_special_tokens
+                if skip_special_tokens is not None
+                else generating_args["skip_special_tokens"],
+                eos_token_id=template.get_stop_token_ids(tokenizer),
+                pad_token_id=tokenizer.pad_token_id,
+            )
+        )
+
+        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
+            generating_args["do_sample"] = True
+            generating_args["temperature"] = generating_args["temperature"] or 1.0
+
+        if not generating_args["temperature"]:
+            generating_args["do_sample"] = False
+
+        if not generating_args["do_sample"]:
+            generating_args.pop("temperature", None)
+            generating_args.pop("top_p", None)
+
+        if max_length:
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length
+
+        if max_new_tokens:
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens
+
+        gen_kwargs = dict(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            generation_config=GenerationConfig(**generating_args),
+        )
+
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
+        for key, value in mm_inputs.items():
+            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif (
+                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
+            ):  # for minicpmv inputs
+                value = torch.stack([torch.stack(v) for v in value])
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
+            if torch.is_floating_point(value):  # cast data dtype for paligemma
+                value = value.to(model.dtype)
+
+            if key == "second_per_grid_ts":  # qwen2.5vl special case
+                gen_kwargs[key] = value.tolist()
+            else:
+                gen_kwargs[key] = value.to(model.device)
+
+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            gen_kwargs["tokenizer"] = tokenizer
+            if "audio_feature_lens" in mm_inputs:
+                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
+
+            gen_kwargs.pop("image_sizes", None)
+
+        return gen_kwargs, prompt_length
+
+    @staticmethod
+    @torch.inference_mode()
+    def _chat(
+        model: "PreTrainedModel",
+        tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
+        template: "Template",
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list["Response"]:
+        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            audios,
+            input_kwargs,
+        )
+        generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # post-process the minicpm_o output
+
+        response_ids = generate_output[:, prompt_length:]
245
+ response = tokenizer.batch_decode(
246
+ response_ids,
247
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
248
+ clean_up_tokenization_spaces=True,
249
+ )
250
+ results = []
251
+ for i in range(len(response)):
252
+ eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
253
+ response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
254
+ results.append(
255
+ Response(
256
+ response_text=response[i],
257
+ response_length=response_length,
258
+ prompt_length=prompt_length,
259
+ finish_reason="stop" if len(eos_index) else "length",
260
+ )
261
+ )
262
+
263
+ return results
264
+
265
+ @staticmethod
266
+ @torch.inference_mode()
267
+ def _stream_chat(
268
+ model: "PreTrainedModel",
269
+ tokenizer: "PreTrainedTokenizer",
270
+ processor: Optional["ProcessorMixin"],
271
+ template: "Template",
272
+ generating_args: dict[str, Any],
273
+ messages: list[dict[str, str]],
274
+ system: Optional[str] = None,
275
+ tools: Optional[str] = None,
276
+ images: Optional[list["ImageInput"]] = None,
277
+ videos: Optional[list["VideoInput"]] = None,
278
+ audios: Optional[list["AudioInput"]] = None,
279
+ input_kwargs: Optional[dict[str, Any]] = {},
280
+ ) -> Callable[[], str]:
281
+ gen_kwargs, _ = HuggingfaceEngine._process_args(
282
+ model,
283
+ tokenizer,
284
+ processor,
285
+ template,
286
+ generating_args,
287
+ messages,
288
+ system,
289
+ tools,
290
+ images,
291
+ videos,
292
+ audios,
293
+ input_kwargs,
294
+ )
295
+ streamer = TextIteratorStreamer(
296
+ tokenizer,
297
+ skip_prompt=True,
298
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
299
+ )
300
+ gen_kwargs["streamer"] = streamer
301
+ thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
302
+ thread.start()
303
+
304
+ def stream():
305
+ try:
306
+ return streamer.__next__()
307
+ except StopIteration:
308
+ raise StopAsyncIteration()
309
+
310
+ return stream
311
+
312
+ @staticmethod
313
+ @torch.inference_mode()
314
+ def _get_scores(
315
+ model: "PreTrainedModelWrapper",
316
+ tokenizer: "PreTrainedTokenizer",
317
+ batch_input: list[str],
318
+ input_kwargs: Optional[dict[str, Any]] = {},
319
+ ) -> list[float]:
320
+ max_length: Optional[int] = input_kwargs.pop("max_length", None)
321
+ device = getattr(model.pretrained_model, "device", "cuda")
322
+ inputs: dict[str, torch.Tensor] = tokenizer(
323
+ batch_input,
324
+ padding=True,
325
+ truncation=True,
326
+ max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
327
+ return_tensors="pt",
328
+ add_special_tokens=False,
329
+ ).to(device)
330
+ values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
331
+ scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
332
+ return scores
333
+
334
+ @override
335
+ async def chat(
336
+ self,
337
+ messages: list[dict[str, str]],
338
+ system: Optional[str] = None,
339
+ tools: Optional[str] = None,
340
+ images: Optional[list["ImageInput"]] = None,
341
+ videos: Optional[list["VideoInput"]] = None,
342
+ audios: Optional[list["AudioInput"]] = None,
343
+ **input_kwargs,
344
+ ) -> list["Response"]:
345
+ if not self.can_generate:
346
+ raise ValueError("The current model does not support `chat`.")
347
+
348
+ input_args = (
349
+ self.model,
350
+ self.tokenizer,
351
+ self.processor,
352
+ self.template,
353
+ self.generating_args,
354
+ messages,
355
+ system,
356
+ tools,
357
+ images,
358
+ videos,
359
+ audios,
360
+ input_kwargs,
361
+ )
362
+ async with self.semaphore:
363
+ return await asyncio.to_thread(self._chat, *input_args)
364
+
365
+ @override
366
+ async def stream_chat(
367
+ self,
368
+ messages: list[dict[str, str]],
369
+ system: Optional[str] = None,
370
+ tools: Optional[str] = None,
371
+ images: Optional[list["ImageInput"]] = None,
372
+ videos: Optional[list["VideoInput"]] = None,
373
+ audios: Optional[list["AudioInput"]] = None,
374
+ **input_kwargs,
375
+ ) -> AsyncGenerator[str, None]:
376
+ if not self.can_generate:
377
+ raise ValueError("The current model does not support `stream_chat`.")
378
+
379
+ input_args = (
380
+ self.model,
381
+ self.tokenizer,
382
+ self.processor,
383
+ self.template,
384
+ self.generating_args,
385
+ messages,
386
+ system,
387
+ tools,
388
+ images,
389
+ videos,
390
+ audios,
391
+ input_kwargs,
392
+ )
393
+ async with self.semaphore:
394
+ stream = self._stream_chat(*input_args)
395
+ while True:
396
+ try:
397
+ yield await asyncio.to_thread(stream)
398
+ except StopAsyncIteration:
399
+ break
400
+
401
+ @override
402
+ async def get_scores(
403
+ self,
404
+ batch_input: list[str],
405
+ **input_kwargs,
406
+ ) -> list[float]:
407
+ if self.can_generate:
408
+ raise ValueError("Cannot get scores using an auto-regressive model.")
409
+
410
+ input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
411
+ async with self.semaphore:
412
+ return await asyncio.to_thread(self._get_scores, *input_args)
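The `_stream_chat`/`stream_chat` pair above bridges a blocking token iterator (fed by a background `generate` thread) into an async generator by calling the blocking `stream()` closure through `asyncio.to_thread` and translating `StopIteration` into `StopAsyncIteration`. A minimal self-contained sketch of that pattern, with the model/streamer replaced by a plain list iterator (names here are illustrative, not part of the library):

```python
import asyncio


def make_stream(items):
    """Stand-in for _stream_chat: returns a blocking next-chunk callable."""
    it = iter(items)

    def stream():
        try:
            return next(it)
        except StopIteration:
            # translate the sync sentinel into the async one, as the engine does
            raise StopAsyncIteration()

    return stream


async def consume(stream):
    """Stand-in for stream_chat: drain the blocking callable without blocking the loop."""
    chunks = []
    while True:
        try:
            chunks.append(await asyncio.to_thread(stream))
        except StopAsyncIteration:
            break
    return chunks


print(asyncio.run(consume(make_stream(["Hel", "lo", "!"]))))  # → ['Hel', 'lo', '!']
```

Each `await asyncio.to_thread(stream)` runs one blocking `next()` in the default executor, so the event loop stays responsive between chunks.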
LlamaFactory/src/llamafactory/chat/kt_engine.py ADDED
@@ -0,0 +1,284 @@
+ # Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import asyncio
+ import os
+ import platform
+ from collections.abc import AsyncGenerator
+ from threading import Thread
+ from typing import TYPE_CHECKING, Any, Optional
+ 
+ import torch
+ from typing_extensions import override
+ 
+ from ..data import get_template_and_fix_tokenizer
+ from ..extras import logging
+ from ..extras.constants import EngineName
+ from ..model import load_model, load_tokenizer
+ from .base_engine import BaseEngine, Response
+ 
+ 
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+     from trl import PreTrainedModelWrapper
+ 
+     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+ 
+ from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+ from ktransformers.server.config.config import Config
+ from ktransformers.util.utils import (
+     get_compute_capability,
+     prefill_and_generate_capture,
+ )
+ from ktransformers.util.vendors import GPUVendor, device_manager
+ 
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ class KTransformersEngine(BaseEngine):
+     def __init__(
+         self,
+         model_args: "ModelArguments",
+         data_args: "DataArguments",
+         finetuning_args: "FinetuningArguments",
+         generating_args: "GeneratingArguments",
+     ) -> None:
+         self.name = EngineName.KT
+         self.can_generate = finetuning_args.stage == "sft"
+ 
+         tok_mod = load_tokenizer(model_args)
+         self.tokenizer = tok_mod["tokenizer"]
+         self.tokenizer.padding_side = "left" if self.can_generate else "right"
+         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+ 
+         self.model = load_model(
+             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+         )
+ 
+         self.generating_args = generating_args.to_dict()
+         self.max_new_tokens = model_args.kt_maxlen
+         self.use_cuda_graph = model_args.kt_use_cuda_graph
+         self.mode = model_args.kt_mode
+         self.force_think = model_args.kt_force_think
+         self.chunk_size = model_args.chunk_size
+ 
+         try:
+             asyncio.get_event_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+ 
+         self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+ 
+     @staticmethod
+     @torch.inference_mode()
+     def _get_scores(
+         model: "PreTrainedModelWrapper",
+         tokenizer: "PreTrainedTokenizer",
+         batch_input: list[str],
+         input_kwargs: Optional[dict[str, Any]] = {},
+     ) -> list[float]:
+         max_length: Optional[int] = input_kwargs.pop("max_length", None)
+         device = getattr(model.pretrained_model, "device", "cuda")
+         inputs = tokenizer(
+             batch_input,
+             padding=True,
+             truncation=True,
+             max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
+             return_tensors="pt",
+             add_special_tokens=False,
+         ).to(device)
+         values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
+         scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+         return scores
+ 
+     async def _generate(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         paired = messages + [{"role": "assistant", "content": ""}]
+         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired, system, tools)
+         prompt_len = len(prompt_ids)
+ 
+         max_length: Optional[int] = input_kwargs.pop("max_length", None)
+         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+ 
+         if "max_new_tokens" in self.generating_args:
+             max_tokens = int(self.generating_args["max_new_tokens"])
+         elif "max_length" in self.generating_args:
+             gl = int(self.generating_args["max_length"])
+             max_tokens = gl - prompt_len if gl > prompt_len else 1
+         else:
+             max_tokens = self.max_new_tokens or 256
+ 
+         if max_length is not None:
+             max_tokens = max(max_length - prompt_len, 1)
+         if max_new_tokens is not None:
+             max_tokens = int(max_new_tokens)
+         max_tokens = max(1, int(max_tokens))
+ 
+         if self.mode == "long_context":
+             max_len_cfg = Config().long_context_config["max_seq_len"]
+             need = prompt_len + max_tokens
+             assert max_len_cfg > need, f"please set max_seq_len > {need} in ~/.ktransformers/config.yaml"
+ 
+         device = next(self.model.parameters()).device
+         input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+         if self.force_think:
+             think = torch.tensor(
+                 [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
+             )
+             input_tensor = torch.cat([input_tensor, think], dim=1)
+ 
+         use_flashinfer = (
+             platform.system() != "Windows"
+             and getattr(self.model.config, "architectures", [""])[0]
+             in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
+             and flashinfer_enabled
+             and get_compute_capability() >= 8
+             and device_manager.gpu_vendor == GPUVendor.NVIDIA
+         )
+ 
+         def make_gen():
+             if use_flashinfer:
+                 return prefill_and_generate_capture(
+                     self.model,
+                     self.tokenizer,
+                     input_tensor,
+                     max_tokens,
+                     self.use_cuda_graph,
+                     mode=self.mode,
+                     force_think=self.force_think,
+                     chunk_size=self.chunk_size,
+                     use_flashinfer_mla=True,
+                     num_heads=self.model.config.num_attention_heads,
+                     head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
+                     head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
+                     q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+                     + getattr(self.model.config, "qk_nope_head_dim", 0),
+                     echo_stream=False,
+                 )
+             else:
+                 return prefill_and_generate_capture(
+                     self.model,
+                     self.tokenizer,
+                     input_tensor,
+                     max_tokens,
+                     self.use_cuda_graph,
+                     mode=self.mode,
+                     force_think=self.force_think,
+                     chunk_size=self.chunk_size,
+                     echo_stream=False,
+                 )
+ 
+         loop = asyncio.get_running_loop()
+         q: asyncio.Queue[Optional[str]] = asyncio.Queue()
+ 
+         def producer():
+             try:
+                 gen = make_gen()
+                 if hasattr(gen, "__aiter__"):
+ 
+                     async def drain_async():
+                         async for t in gen:
+                             loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
+ 
+                     asyncio.run(drain_async())
+                 elif hasattr(gen, "__iter__"):
+                     for t in gen:
+                         loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
+                 else:
+                     loop.call_soon_threadsafe(q.put_nowait, gen if isinstance(gen, str) else str(gen))
+             finally:
+                 loop.call_soon_threadsafe(q.put_nowait, None)
+ 
+         Thread(target=producer, daemon=True).start()
+ 
+         while True:
+             item = await q.get()
+             if item is None:
+                 break
+             yield item
+ 
+     @override
+     async def chat(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> list["Response"]:
+         if not self.can_generate:
+             raise ValueError("The current model does not support `chat`.")
+         async with self.semaphore:
+             produced = ""
+             final_text = ""
+             async for t in self._generate(messages, system, tools, **input_kwargs):
+                 delta = t
+                 produced = produced + delta
+                 if delta:
+                     final_text += delta
+ 
+             prompt_ids, _ = self.template.encode_oneturn(
+                 self.tokenizer, messages + [{"role": "assistant", "content": ""}], system, tools
+             )
+             return [
+                 Response(
+                     response_text=final_text,
+                     response_length=len(self.tokenizer.encode(final_text, add_special_tokens=False)),
+                     prompt_length=len(prompt_ids),
+                     finish_reason="stop",
+                 )
+             ]
+ 
+     @override
+     async def stream_chat(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         if not self.can_generate:
+             raise ValueError("The current model does not support `stream_chat`.")
+         async with self.semaphore:
+             produced = ""
+             async for t in self._generate(messages, system, tools, **input_kwargs):
+                 delta = t[len(produced) :] if t.startswith(produced) else t
+                 produced = t
+                 if delta:
+                     yield delta
+ 
+     @override
+     async def get_scores(
+         self,
+         batch_input: list[str],
+         **input_kwargs,
+     ) -> list[float]:
+         if self.can_generate:
+             raise ValueError("Cannot get scores using an auto-regressive model.")
+         args = (self.model, self.tokenizer, batch_input, input_kwargs)
+         async with self.semaphore:
+             return await asyncio.to_thread(self._get_scores, *args)
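`KTransformersEngine._generate` above uses the inverse bridge of the HuggingFace engine: a producer thread pushes chunks into an `asyncio.Queue` via `loop.call_soon_threadsafe`, with `None` as an end-of-stream sentinel the async side waits on. A minimal sketch of that producer/queue handshake, with `make_gen()` replaced by a plain iterable (the function names here are illustrative only):

```python
import asyncio
from threading import Thread


async def drain(gen_items):
    """Consume a blocking generator from a worker thread via an asyncio.Queue."""
    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue()

    def producer():
        try:
            for t in gen_items:
                # thread-safe hand-off into the event loop's queue
                loop.call_soon_threadsafe(q.put_nowait, t)
        finally:
            # None is the end-of-stream sentinel, sent even on error
            loop.call_soon_threadsafe(q.put_nowait, None)

    Thread(target=producer, daemon=True).start()

    out = []
    while True:
        item = await q.get()
        if item is None:
            break
        out.append(item)
    return out


print(asyncio.run(drain(["one ", "two ", "three"])))  # → ['one ', 'two ', 'three']
```

Putting the sentinel in a `finally` block guarantees the consumer loop terminates even if the producer raises mid-stream.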
LlamaFactory/src/llamafactory/chat/sglang_engine.py ADDED
@@ -0,0 +1,289 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import asyncio
+ import atexit
+ import json
+ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+ from typing import TYPE_CHECKING, Any, Optional, Union
+ 
+ import requests
+ from typing_extensions import override
+ 
+ from ..data import get_template_and_fix_tokenizer
+ from ..extras import logging
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+ from ..extras.misc import get_device_count, torch_gc
+ from ..extras.packages import is_sglang_available
+ from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+ from ..model import load_config, load_tokenizer
+ from ..model.model_utils.quantization import QuantizationMethod
+ from .base_engine import BaseEngine, Response
+ 
+ 
+ if is_sglang_available():
+     from sglang.utils import launch_server_cmd, terminate_process, wait_for_server  # type: ignore
+ 
+ 
+ if TYPE_CHECKING:
+     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+ 
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ class SGLangEngine(BaseEngine):
+     """Inference engine for SGLang models.
+ 
+     This class wraps the SGLang engine to provide a consistent interface for text generation
+     that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
+     better interaction and performance. The engine launches a server process and communicates
+     with it via HTTP requests.
+ 
+     For more details on the SGLang HTTP server approach, see:
+     https://docs.sglang.ai/backend/send_request.html
+     """
+ 
+     def __init__(
+         self,
+         model_args: "ModelArguments",
+         data_args: "DataArguments",
+         finetuning_args: "FinetuningArguments",
+         generating_args: "GeneratingArguments",
+     ) -> None:
+         self.name = EngineName.SGLANG
+         self.model_args = model_args
+         config = load_config(model_args)  # may download model from ms hub
+         if getattr(config, "quantization_config", None):  # gptq models should use float16
+             quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
+             quant_method = quantization_config.get("quant_method", "")
+             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+                 model_args.infer_dtype = "float16"
+ 
+         self.can_generate = finetuning_args.stage == "sft"
+         tokenizer_module = load_tokenizer(model_args)
+         self.tokenizer = tokenizer_module["tokenizer"]
+         self.processor = tokenizer_module["processor"]
+         self.tokenizer.padding_side = "left"
+         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
+         self.generating_args = generating_args.to_dict()
+         if model_args.adapter_name_or_path is not None:
+             self.lora_request = True
+         else:
+             self.lora_request = False
+ 
+         launch_cmd = [
+             "python3 -m sglang.launch_server",
+             f"--model-path {model_args.model_name_or_path}",
+             f"--dtype {model_args.infer_dtype}",
+             f"--context-length {model_args.sglang_maxlen}",
+             f"--mem-fraction-static {model_args.sglang_mem_fraction}",
+             f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
+             f"--download-dir {model_args.cache_dir}",
+             "--log-level error",
+         ]
+         if self.lora_request:
+             launch_cmd.extend(
+                 [
+                     "--max-loras-per-batch 1",
+                     f"--lora-backend {model_args.sglang_lora_backend}",
+                     f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                     "--disable-radix-cache",
+                 ]
+             )
+         launch_cmd = " ".join(launch_cmd)
+         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
+         try:
+             torch_gc()
+             self.server_process, port = launch_server_cmd(launch_cmd)
+             self.base_url = f"http://localhost:{port}"
+             atexit.register(self._cleanup_server)
+ 
+             logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
+             wait_for_server(self.base_url, timeout=300)
+             logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
+             try:
+                 response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
+                 if response.status_code == 200:
+                     model_info = response.json()
+                     logger.info(f"SGLang server model info: {model_info}")
+             except Exception as e:
+                 logger.debug(f"Note: could not get model info: {str(e)}")
+ 
+         except Exception as e:
+             logger.error(f"Failed to start SGLang server: {str(e)}")
+             self._cleanup_server()  # make sure to clean up any started process
+             raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
+ 
+     def _cleanup_server(self):
+         r"""Clean up the server process when the engine is destroyed."""
+         if hasattr(self, "server_process") and self.server_process:
+             try:
+                 logger.info("Terminating SGLang server process")
+                 terminate_process(self.server_process)
+                 logger.info("SGLang server process terminated")
+             except Exception as e:
+                 logger.warning(f"Error terminating SGLang server: {str(e)}")
+ 
+     async def _generate(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> AsyncIterator[dict[str, Any]]:
+         if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+             messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+ 
+         if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+             messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+ 
+         if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+             messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+ 
+         messages = self.template.mm_plugin.process_messages(
+             messages, images or [], videos or [], audios or [], self.processor
+         )
+         paired_messages = messages + [{"role": "assistant", "content": ""}]
+         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
+         prompt_length = len(prompt_ids)
+ 
+         temperature: Optional[float] = input_kwargs.pop("temperature", None)
+         top_p: Optional[float] = input_kwargs.pop("top_p", None)
+         top_k: Optional[float] = input_kwargs.pop("top_k", None)
+         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+         max_length: Optional[int] = input_kwargs.pop("max_length", None)
+         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+         stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+ 
+         if num_return_sequences != 1:
+             raise NotImplementedError("SGLang only supports n=1.")
+ 
+         if "max_new_tokens" in self.generating_args:
+             max_tokens = self.generating_args["max_new_tokens"]
+         elif "max_length" in self.generating_args:
+             if self.generating_args["max_length"] > prompt_length:
+                 max_tokens = self.generating_args["max_length"] - prompt_length
+             else:
+                 max_tokens = 1
+ 
+         if max_length:
+             max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+ 
+         if max_new_tokens:
+             max_tokens = max_new_tokens
+ 
+         sampling_params = {
+             "temperature": temperature if temperature is not None else self.generating_args["temperature"],
+             "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
+             "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
+             "stop": stop,
+             "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
+             "max_new_tokens": max_tokens,
+             "repetition_penalty": (
+                 repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+             )
+             or 1.0,  # repetition_penalty must > 0
+             "skip_special_tokens": skip_special_tokens
+             if skip_special_tokens is not None
+             else self.generating_args["skip_special_tokens"],
+         }
+ 
+         def stream_request():
+             json_data = {
+                 "input_ids": prompt_ids,
+                 "sampling_params": sampling_params,
+                 "stream": True,
+             }
+             if self.lora_request:
+                 json_data["lora_request"] = ["lora0"]
+             response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
+             if response.status_code != 200:
+                 raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
+ 
+             for chunk in response.iter_lines(decode_unicode=False):
+                 chunk = str(chunk.decode("utf-8"))
+                 if chunk == "data: [DONE]":
+                     break
+ 
+                 if chunk and chunk.startswith("data:"):
+                     yield json.loads(chunk[5:].strip("\n"))
+ 
+         return await asyncio.to_thread(stream_request)
+ 
+     @override
+     async def chat(
+         self,
+         messages: Sequence[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[Sequence["ImageInput"]] = None,
+         videos: Optional[Sequence["VideoInput"]] = None,
+         audios: Optional[Sequence["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> list["Response"]:
+         final_output = None
+         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+         for request_output in generator:
+             final_output = request_output
+ 
+         results = [
+             Response(
+                 response_text=final_output["text"],
+                 response_length=final_output["meta_info"]["completion_tokens"],
+                 prompt_length=final_output["meta_info"]["prompt_tokens"],
+                 finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
+             )
+         ]
+         return results
+ 
+     @override
+     async def stream_chat(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         generated_text = ""
+         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+         for result in generator:
+             delta_text = result["text"][len(generated_text) :]
+             generated_text = result["text"]
+             yield delta_text
+ 
+     @override
+     async def get_scores(
+         self,
+         batch_input: list[str],
+         **input_kwargs,
+     ) -> list[float]:
+         raise NotImplementedError("SGLang engine does not support `get_scores`.")
+ 
+     def __del__(self):
+         r"""Ensure server is cleaned up when object is deleted."""
+         self._cleanup_server()
+         try:
+             atexit.unregister(self._cleanup_server)
+         except Exception:
+             pass
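The `stream_request` closure in `SGLangEngine._generate` above parses server-sent-event style lines: it stops on `data: [DONE]` and JSON-decodes each `data:` payload. A small self-contained sketch of just that parsing step, using hard-coded example lines rather than a live SGLang server (the payload shapes are illustrative):

```python
import json


def parse_sse(lines):
    """Mimic the chunk handling in stream_request: stop at [DONE], decode data: payloads."""
    for chunk in lines:
        if chunk == "data: [DONE]":
            break
        if chunk and chunk.startswith("data:"):
            # drop the "data:" prefix; json.loads tolerates the leading space
            yield json.loads(chunk[5:].strip("\n"))


raw = [
    'data: {"text": "He", "meta_info": {"completion_tokens": 1}}',
    'data: {"text": "Hello", "meta_info": {"completion_tokens": 2}}',
    "data: [DONE]",
]
events = list(parse_sse(raw))
print([e["text"] for e in events])  # → ['He', 'Hello']
```

Note that each event carries the full text so far, which is why `stream_chat` slices off `generated_text` to compute the per-chunk delta.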
LlamaFactory/src/llamafactory/chat/vllm_engine.py ADDED
@@ -0,0 +1,271 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import uuid
+ from collections.abc import AsyncGenerator, AsyncIterator
+ from typing import TYPE_CHECKING, Any, Optional, Union
+
+ from packaging import version
+ from typing_extensions import override
+
+ from ..data import get_template_and_fix_tokenizer
+ from ..extras import logging
+ from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+ from ..extras.misc import get_device_count
+ from ..extras.packages import is_vllm_available
+ from ..model import load_config, load_tokenizer
+ from ..model.model_utils.quantization import QuantizationMethod
+ from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
+ from .base_engine import BaseEngine, Response
+
+
+ if is_vllm_available():
+     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+     from vllm.lora.request import LoRARequest
+
+
+ if TYPE_CHECKING:
+     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VllmEngine(BaseEngine):
+     def __init__(
+         self,
+         model_args: "ModelArguments",
+         data_args: "DataArguments",
+         finetuning_args: "FinetuningArguments",
+         generating_args: "GeneratingArguments",
+     ) -> None:
+         self.name = EngineName.VLLM
+         self.model_args = model_args
+         config = load_config(model_args)  # may download model from ms hub
+         if getattr(config, "quantization_config", None):  # gptq models should use float16
+             quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
+             quant_method = quantization_config.get("quant_method", "")
+             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+                 model_args.infer_dtype = "float16"
+
+         self.can_generate = finetuning_args.stage == "sft"
+         tokenizer_module = load_tokenizer(model_args)
+         self.tokenizer = tokenizer_module["tokenizer"]
+         self.processor = tokenizer_module["processor"]
+         self.tokenizer.padding_side = "left"
+         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+         self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
+         self.generating_args = generating_args.to_dict()
+
+         engine_args = {
+             "model": model_args.model_name_or_path,
+             "trust_remote_code": model_args.trust_remote_code,
+             "download_dir": model_args.cache_dir,
+             "dtype": model_args.infer_dtype,
+             "max_model_len": model_args.vllm_maxlen,
+             "tensor_parallel_size": get_device_count() or 1,
+             "gpu_memory_utilization": model_args.vllm_gpu_util,
+             "disable_log_stats": True,
+             "enforce_eager": model_args.vllm_enforce_eager,
+             "enable_lora": model_args.adapter_name_or_path is not None,
+             "max_lora_rank": model_args.vllm_max_lora_rank,
+         }
+
+         import vllm
+
+         if version.parse(vllm.__version__) <= version.parse("0.10.0"):
+             engine_args["disable_log_requests"] = True
+         else:
+             engine_args["enable_log_requests"] = False
+
+         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+             engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
+
+         if isinstance(model_args.vllm_config, dict):
+             engine_args.update(model_args.vllm_config)
+
+         if getattr(config, "is_yi_vl_derived_model", None):
+             import vllm.model_executor.models.llava
+
+             logger.info_rank0("Detected Yi-VL model, applying projector patch.")
+             vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+
+         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
+         if model_args.adapter_name_or_path is not None:
+             self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+         else:
+             self.lora_request = None
+
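The version check above exists because vLLM renamed its request-logging flag around release 0.10.0. The branching can be sketched without the `packaging` dependency using a simplified tuple comparison (this parser ignores pre-release suffixes, so it is a stand-in for `packaging.version.parse`, not a replacement):

```python
def choose_log_arg(vllm_version: str) -> dict:
    """Pick the request-logging engine argument based on the vLLM version string."""
    # parse "major.minor.patch" into a comparable tuple (simplified stand-in
    # for packaging.version.parse; pre-release suffixes are not handled)
    parsed = tuple(int(part) for part in vllm_version.split(".")[:3])
    if parsed <= (0, 10, 0):
        return {"disable_log_requests": True}   # flag name up to 0.10.0
    return {"enable_log_requests": False}       # renamed in later releases


print(choose_log_arg("0.10.1"))  # {'enable_log_requests': False}
```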
+     async def _generate(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> AsyncIterator["RequestOutput"]:
+         request_id = f"chatcmpl-{uuid.uuid4().hex}"
+         if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+             messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+         if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+             messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+         if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+             messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+         messages = self.template.mm_plugin.process_messages(
+             messages, images or [], videos or [], audios or [], self.processor
+         )
+         paired_messages = messages + [{"role": "assistant", "content": ""}]
+         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
+         prompt_length = len(prompt_ids)
+
+         temperature: Optional[float] = input_kwargs.pop("temperature", None)
+         top_p: Optional[float] = input_kwargs.pop("top_p", None)
+         top_k: Optional[float] = input_kwargs.pop("top_k", None)
+         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+         max_length: Optional[int] = input_kwargs.pop("max_length", None)
+         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+         stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+         if length_penalty is not None:
+             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
+         if "max_new_tokens" in self.generating_args:
+             max_tokens = self.generating_args["max_new_tokens"]
+         elif "max_length" in self.generating_args:
+             if self.generating_args["max_length"] > prompt_length:
+                 max_tokens = self.generating_args["max_length"] - prompt_length
+             else:
+                 max_tokens = 1
+
+         if max_length:
+             max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+
+         if max_new_tokens:
+             max_tokens = max_new_tokens
+
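The token-budget resolution above layers per-request overrides on top of configured defaults: a configured `max_new_tokens` or `max_length` seeds the value, then a request-level `max_length` rebudgets against the prompt, and a request-level `max_new_tokens` wins last. A self-contained sketch of that fallback chain (the `max_tokens = 1` seed is a guard added here for the case where neither default is configured):

```python
def resolve_max_tokens(prompt_length, generating_args, max_length=None, max_new_tokens=None):
    """Mirror the fallback chain: configured defaults, then per-request overrides."""
    max_tokens = 1  # guard in case neither default key is configured
    if "max_new_tokens" in generating_args:
        max_tokens = generating_args["max_new_tokens"]
    elif "max_length" in generating_args:
        max_tokens = max(generating_args["max_length"] - prompt_length, 1)
    if max_length:  # per-request total-length budget
        max_tokens = max_length - prompt_length if max_length > prompt_length else 1
    if max_new_tokens:  # per-request new-token budget takes final precedence
        max_tokens = max_new_tokens
    return max_tokens


print(resolve_max_tokens(10, {"max_length": 100}))  # 90
```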
+         sampling_params = SamplingParams(
+             n=num_return_sequences,
+             repetition_penalty=(
+                 repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+             )
+             or 1.0,  # repetition_penalty must be > 0
+             temperature=temperature if temperature is not None else self.generating_args["temperature"],
+             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must be > 0
+             top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must be > 0
+             stop=stop,
+             stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
+             max_tokens=max_tokens,
+             skip_special_tokens=skip_special_tokens
+             if skip_special_tokens is not None
+             else self.generating_args["skip_special_tokens"],
+         )
+
+         if images is not None:  # add image features
+             multi_modal_data = {
+                 "image": self.template.mm_plugin._regularize_images(
+                     images,
+                     image_max_pixels=self.model_args.image_max_pixels,
+                     image_min_pixels=self.model_args.image_min_pixels,
+                 )["images"]
+             }
+         elif videos is not None:
+             multi_modal_data = {
+                 "video": self.template.mm_plugin._regularize_videos(
+                     videos,
+                     image_max_pixels=self.model_args.video_max_pixels,
+                     image_min_pixels=self.model_args.video_min_pixels,
+                     video_fps=self.model_args.video_fps,
+                     video_maxlen=self.model_args.video_maxlen,
+                 )["videos"]
+             }
+         elif audios is not None:
+             audio_data = self.template.mm_plugin._regularize_audios(
+                 audios,
+                 sampling_rate=self.model_args.audio_sampling_rate,
+             )
+             multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+         else:
+             multi_modal_data = None
+
+         result_generator = self.model.generate(
+             {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+             sampling_params=sampling_params,
+             request_id=request_id,
+             lora_request=self.lora_request,
+         )
+         return result_generator
+
+     @override
+     async def chat(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> list["Response"]:
+         final_output = None
+         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+         async for request_output in generator:
+             final_output = request_output
+
+         results = []
+         for output in final_output.outputs:
+             results.append(
+                 Response(
+                     response_text=output.text,
+                     response_length=len(output.token_ids),
+                     prompt_length=len(final_output.prompt_token_ids),
+                     finish_reason=output.finish_reason,
+                 )
+             )
+
+         return results
+
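The non-streaming `chat` path above drains the async generator and keeps only the last request output, since vLLM yields cumulative results. A minimal self-contained sketch of that drain pattern (`collect_final` and `fake_engine` are illustrative names, not part of the codebase):

```python
import asyncio


async def collect_final(agen):
    """Drain an async generator of cumulative results, keeping only the last one."""
    final = None
    async for item in agen:
        final = item
    return final


async def fake_engine():
    # stand-in for an engine that yields growing partial results
    for partial in ("He", "Hell", "Hello"):
        yield partial


print(asyncio.run(collect_final(fake_engine())))  # prints "Hello"
```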
+     @override
+     async def stream_chat(
+         self,
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+         images: Optional[list["ImageInput"]] = None,
+         videos: Optional[list["VideoInput"]] = None,
+         audios: Optional[list["AudioInput"]] = None,
+         **input_kwargs,
+     ) -> AsyncGenerator[str, None]:
+         generated_text = ""
+         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+         async for result in generator:
+             delta_text = result.outputs[0].text[len(generated_text) :]
+             generated_text = result.outputs[0].text
+             yield delta_text
+
+     @override
+     async def get_scores(
+         self,
+         batch_input: list[str],
+         **input_kwargs,
+     ) -> list[float]:
+         raise NotImplementedError("vLLM engine does not support `get_scores`.")
LlamaFactory/src/llamafactory/cli.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ def main():
+     from .extras.misc import is_env_enabled
+
+     if is_env_enabled("USE_V1"):
+         from .v1 import launcher
+     else:
+         from . import launcher
+
+     launcher.launch()
+
+
+ if __name__ == "__main__":
+     from multiprocessing import freeze_support
+
+     freeze_support()
+     main()
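The CLI entry point dispatches to the v1 launcher when the `USE_V1` environment flag is enabled. The `is_env_enabled` helper lives in `.extras.misc` and is not shown in this diff, so the sketch below is one plausible reading of its semantics (the set of accepted truthy strings is an assumption):

```python
import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    """Hypothetical flag check: common truthy strings enable the feature."""
    return os.getenv(env_var, default).lower() in ("true", "yes", "y", "1")


os.environ["USE_V1"] = "1"
print(is_env_enabled("USE_V1"))  # True
```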
LlamaFactory/src/llamafactory/data/.ipynb_checkpoints/template-checkpoint.py ADDED
@@ -0,0 +1,2175 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional, Union
+
+ from typing_extensions import override
+
+ from ..extras import logging
+ from .data_utils import Role
+ from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+ from .mm_plugin import get_mm_plugin
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+     from ..hparams import DataArguments
+     from .formatter import SLOTS, Formatter
+     from .mm_plugin import BasePlugin
+     from .tool_utils import FunctionCall
+
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class Template:
+     format_user: "Formatter"
+     format_assistant: "Formatter"
+     format_system: "Formatter"
+     format_function: "Formatter"
+     format_observation: "Formatter"
+     format_tools: "Formatter"
+     format_prefix: "Formatter"
+     default_system: str
+     stop_words: list[str]
+     thought_words: tuple[str, str]
+     tool_call_words: tuple[str, str]
+     efficient_eos: bool
+     replace_eos: bool
+     replace_jinja_template: bool
+     enable_thinking: Optional[bool]
+     mm_plugin: "BasePlugin"
+
+     def encode_oneturn(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+     ) -> tuple[list[int], list[int]]:
+         r"""Return a single pair of token ids representing prompt and response respectively."""
+         encoded_messages = self._encode(tokenizer, messages, system, tools)
+         prompt_ids = []
+         for encoded_ids in encoded_messages[:-1]:
+             prompt_ids += encoded_ids
+
+         response_ids = encoded_messages[-1]
+         return prompt_ids, response_ids
+
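`encode_oneturn` concatenates every encoded turn except the last into the prompt and keeps the final turn as the response. The split can be sketched as a small standalone function (`split_prompt_response` is an illustrative name, not part of the codebase):

```python
def split_prompt_response(encoded_messages):
    """Concatenate all but the last encoded turn into the prompt; last turn is the response."""
    prompt_ids = []
    for encoded_ids in encoded_messages[:-1]:
        prompt_ids += encoded_ids
    return prompt_ids, encoded_messages[-1]


print(split_prompt_response([[1, 2], [3], [4, 5]]))  # ([1, 2, 3], [4, 5])
```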
+     def encode_multiturn(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+     ) -> list[tuple[list[int], list[int]]]:
+         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
+         encoded_messages = self._encode(tokenizer, messages, system, tools)
+         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
+         r"""Extract tool message."""
+         return self.format_tools.extract(content)
+
+     def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+         r"""Return stop token ids."""
+         stop_token_ids = {tokenizer.eos_token_id}
+         for token in self.stop_words:
+             stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
+
+         return list(stop_token_ids)
+
+     def add_thought(self, content: str = "") -> str:
+         r"""Add an empty thought to the assistant message."""
+         return f"{self.thought_words[0]}{self.thought_words[1]}" + content
+
+     def remove_thought(self, content: str) -> str:
+         r"""Remove the thought from the assistant message."""
+         pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+         return re.sub(pattern, "", content).lstrip("\n")
+
+     def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+         r"""Get the token ids of thought words."""
+         return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
+     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
+         r"""Convert elements to token ids."""
+         token_ids = []
+         for elem in elements:
+             if isinstance(elem, str):
+                 if len(elem) != 0:
+                     token_ids += tokenizer.encode(elem, add_special_tokens=False)
+             elif isinstance(elem, dict):
+                 token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+             elif isinstance(elem, set):
+                 if "bos_token" in elem and tokenizer.bos_token_id is not None:
+                     token_ids += [tokenizer.bos_token_id]
+                 elif "eos_token" in elem and tokenizer.eos_token_id is not None:
+                     token_ids += [tokenizer.eos_token_id]
+             else:
+                 raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
+
+         return token_ids
+
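`remove_thought` strips a delimited reasoning span with a non-greedy regex between the template's thought words. A self-contained sketch of the same mechanism (the `<think>`/`</think>` defaults here are illustrative; the actual delimiters come from each template's `thought_words`):

```python
import re


def remove_thought(content, thought_words=("<think>", "</think>")):
    """Strip a delimited reasoning span; the default delimiters are illustrative."""
    pattern = re.compile(
        f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL
    )
    # drop the span, then trim the leading newlines it usually leaves behind
    return re.sub(pattern, "", content).lstrip("\n")


print(remove_thought("<think>step by step</think>\nfinal answer"))  # final answer
```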
+     def _encode(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: list[dict[str, str]],
+         system: Optional[str],
+         tools: Optional[str],
+     ) -> list[list[int]]:
+         r"""Encode formatted inputs to pairs of token ids.
+
+         Turn 0: prefix + system + query        resp
+         Turn t: query                          resp.
+         """
+         system = system or self.default_system
+         encoded_messages = []
+         for i, message in enumerate(messages):
+             elements = []
+
+             if i == 0:
+                 elements += self.format_prefix.apply()
+                 if system or tools:
+                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                     elements += self.format_system.apply(content=(system + tool_text))
+
+             if message["role"] == Role.USER:
+                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+             elif message["role"] == Role.ASSISTANT:
+                 elements += self.format_assistant.apply(content=message["content"])
+             elif message["role"] == Role.OBSERVATION:
+                 elements += self.format_observation.apply(content=message["content"])
+             elif message["role"] == Role.FUNCTION:
+                 elements += self.format_function.apply(
+                     content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words
+                 )
+             else:
+                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+         return encoded_messages
+
+     @staticmethod
+     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+         r"""Add or replace the eos token of the tokenizer."""
+         if tokenizer.eos_token == eos_token:
+             return
+
+         is_added = tokenizer.eos_token_id is None
+         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
+
+         if is_added:
+             logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
+         else:
+             logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
+
+         if num_added_tokens > 0:
+             logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+     def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
+         r"""Add eos token and pad token to the tokenizer."""
+         stop_words = self.stop_words
+         if self.replace_eos:
+             if not stop_words:
+                 raise ValueError("Stop words are required to replace the EOS token.")
+
+             self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
+             stop_words = stop_words[1:]
+
+         if tokenizer.eos_token_id is None:
+             self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+
+         if tokenizer.pad_token_id is None:
+             tokenizer.pad_token = tokenizer.eos_token
+             logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
+
+         if stop_words:
+             try:
+                 num_added_tokens = tokenizer.add_special_tokens(
+                     dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
+                 )
+             except TypeError:
+                 num_added_tokens = tokenizer.add_special_tokens(dict(additional_special_tokens=stop_words))
+             logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
+             if num_added_tokens > 0:
+                 logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
+
+     @staticmethod
+     def _jinja_escape(content: str) -> str:
+         r"""Escape single quotes in content."""
+         return content.replace("'", r"\'")
+
+     @staticmethod
+     def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
+         r"""Convert slots to a jinja template."""
+         slot_items = []
+         for slot in slots:
+             if isinstance(slot, str):
+                 slot_pieces = slot.split("{{content}}")
+                 if slot_pieces[0]:
+                     slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
+                 if len(slot_pieces) > 1:
+                     slot_items.append(placeholder)
+                     if slot_pieces[1]:
+                         slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
+             elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                 if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                     slot_items.append("'" + tokenizer.bos_token + "'")
+                 elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                     slot_items.append("'" + tokenizer.eos_token + "'")
+             elif isinstance(slot, dict):
+                 raise ValueError("Dict is not supported.")
+
+         return " + ".join(slot_items)
+
+     def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+         r"""Return the jinja template."""
+         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+         system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
+         user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+         assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+         jinja_template = ""
+         if prefix:
+             jinja_template += "{{ " + prefix + " }}"
+
+         if self.default_system:
+             jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+         jinja_template += (
+             "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+             "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+             "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
+             "{% for message in loop_messages %}"
+             "{% set content = message['content'] %}"
+             "{% if message['role'] == 'user' %}"
+             "{{ " + user + " }}"
+             "{% elif message['role'] == 'assistant' %}"
+             "{{ " + assistant + " }}"
+             "{% endif %}"
+             "{% endfor %}"
+         )
+         return jinja_template
+
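`_convert_slots_to_jinja` splits each string slot on the literal `{{content}}` marker and emits a single-quoted jinja concatenation around the placeholder. The string branch can be sketched standalone (`convert_slot_to_jinja` is an illustrative name; the real method also handles `set` and `dict` slots):

```python
def convert_slot_to_jinja(slot: str, placeholder: str = "content") -> str:
    """Split a slot on the {{content}} marker and emit a jinja concatenation."""
    pieces = slot.split("{{content}}")
    items = []
    if pieces[0]:
        items.append("'" + pieces[0].replace("'", r"\'") + "'")  # escaped prefix literal
    if len(pieces) > 1:
        items.append(placeholder)  # the jinja variable itself, unquoted
        if pieces[1]:
            items.append("'" + pieces[1].replace("'", r"\'") + "'")  # escaped suffix literal
    return " + ".join(items)


print(convert_slot_to_jinja("[INST] {{content}} [/INST]"))
# '[INST] ' + content + ' [/INST]'
```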
+     def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
+         r"""Replace the jinja template in the tokenizer."""
+         if tokenizer.chat_template is None or self.replace_jinja_template:
+             try:
+                 tokenizer.chat_template = self._get_jinja_template(tokenizer)
+             except ValueError as e:
+                 logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
+
+     @staticmethod
+     def _convert_slots_to_ollama(
+         slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
+     ) -> str:
+         r"""Convert slots to an ollama template."""
+         slot_items = []
+         for slot in slots:
+             if isinstance(slot, str):
+                 slot_pieces = slot.split("{{content}}")
+                 if slot_pieces[0]:
+                     slot_items.append(slot_pieces[0])
+                 if len(slot_pieces) > 1:
+                     slot_items.append("{{ " + placeholder + " }}")
+                     if slot_pieces[1]:
+                         slot_items.append(slot_pieces[1])
+             elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                 if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                     slot_items.append(tokenizer.bos_token)
+                 elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                     slot_items.append(tokenizer.eos_token)
+             elif isinstance(slot, dict):
+                 raise ValueError("Dict is not supported.")
+
+         return "".join(slot_items)
+
+     def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+         r"""Return the ollama template."""
+         prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
+         system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
+         user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
+         assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
+         return (
+             f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
+             f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
+             f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
+         )
+
+     def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
+         r"""Return the ollama modelfile.
+
+         TODO: support function calling.
+         """
+         modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
+         modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+
+         if self.default_system:
+             modelfile += f'SYSTEM """{self.default_system}"""\n\n'
+
+         for stop_token_id in self.get_stop_token_ids(tokenizer):
+             modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
+
+         modelfile += "PARAMETER num_ctx 4096\n"
+         return modelfile
+
+
+ @dataclass
+ class Llama2Template(Template):
+     r"""A template that fuses the system message into the first user message."""
+
+     @override
+     def _encode(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: list[dict[str, str]],
+         system: str,
+         tools: str,
+     ) -> list[list[int]]:
+         system = system or self.default_system
+         encoded_messages = []
+         for i, message in enumerate(messages):
+             elements = []
+
+             system_text = ""
+             if i == 0:
+                 elements += self.format_prefix.apply()
+                 if system or tools:
+                     tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
+                     system_text = self.format_system.apply(content=(system + tool_text))[0]
+
+             if message["role"] == Role.USER:
+                 elements += self.format_user.apply(content=system_text + message["content"])
+             elif message["role"] == Role.ASSISTANT:
+                 elements += self.format_assistant.apply(content=message["content"])
+             elif message["role"] == Role.OBSERVATION:
+                 elements += self.format_observation.apply(content=message["content"])
+             elif message["role"] == Role.FUNCTION:
+                 elements += self.format_function.apply(content=message["content"])
+             else:
+                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
+
+             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
+
+         return encoded_messages
+
+     def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
+         system_message = self._convert_slots_to_jinja(
+             self.format_system.apply(), tokenizer, placeholder="system_message"
+         )
+         user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+         assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+         jinja_template = ""
+         if prefix:
+             jinja_template += "{{ " + prefix + " }}"
+
+         if self.default_system:
+             jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
+
+         jinja_template += (
+             "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
+             "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
+             "{% for message in loop_messages %}"
+             "{% if loop.index0 == 0 and system_message is defined %}"
+             "{% set content = " + system_message + " + message['content'] %}"
+             "{% else %}{% set content = message['content'] %}{% endif %}"
+             "{% if message['role'] == 'user' %}"
+             "{{ " + user_message + " }}"
+             "{% elif message['role'] == 'assistant' %}"
+             "{{ " + assistant_message + " }}"
+             "{% endif %}"
+             "{% endfor %}"
+         )
+         return jinja_template
+
403
+
404
+ @dataclass
+ class ReasoningTemplate(Template):
+     r"""A template that adds thoughts to assistant messages."""
+
+     @override
+     def encode_oneturn(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+     ) -> tuple[list[int], list[int]]:
+         messages = deepcopy(messages)
+         for i in range(1, len(messages) - 2, 2):
+             messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+         if self.enable_thinking is False:  # remove all cot
+             messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+         prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+         if (
+             self.thought_words[0].strip() not in messages[-1]["content"]
+             and self.thought_words[1].strip() not in messages[-1]["content"]
+         ):  # add empty cot
+             if not self.enable_thinking:  # do not compute loss
+                 prompt_ids += self.get_thought_word_ids(tokenizer)
+             else:  # do compute loss
+                 response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+         return prompt_ids, response_ids
+
+     @override
+     def encode_multiturn(
+         self,
+         tokenizer: "PreTrainedTokenizer",
+         messages: list[dict[str, str]],
+         system: Optional[str] = None,
+         tools: Optional[str] = None,
+     ) -> list[tuple[list[int], list[int]]]:
+         messages = deepcopy(messages)
+         if self.enable_thinking is False:  # remove all cot
+             for i in range(1, len(messages), 2):
+                 messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+         encoded_messages = self._encode(tokenizer, messages, system, tools)
+         for i in range(0, len(messages), 2):
+             if (
+                 self.thought_words[0].strip() not in messages[i + 1]["content"]
+                 and self.thought_words[1].strip() not in messages[i + 1]["content"]
+             ):  # add empty cot
+                 if not self.enable_thinking:  # do not compute loss
+                     encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                 else:  # do compute loss
+                     encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
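The empty-CoT branch above can be sketched in isolation: when the latest assistant message carries no thought tags, an empty chain-of-thought marker is added either to the prompt (so no loss is computed on it) or to the front of the response (so loss is computed). A minimal string-level mimic, not the real token-level implementation; `add_empty_cot` is a hypothetical helper and the message strings are toy inputs:

```python
def add_empty_cot(prompt: str, response: str,
                  thought_words=("<think>\n", "\n</think>\n\n"),
                  enable_thinking: bool = True) -> tuple[str, str]:
    """Mimic ReasoningTemplate's empty-CoT insertion at the string level."""
    empty_cot = thought_words[0] + thought_words[1]  # "<think>\n\n</think>\n\n"
    # Same membership test as the code above: look for the stripped tag words.
    has_thought = (thought_words[0].strip() in response
                   or thought_words[1].strip() in response)
    if has_thought:
        return prompt, response  # response already contains a thought
    if not enable_thinking:
        # Thought belongs to the prompt; no loss is computed on it.
        return prompt + empty_cot, response
    # Thought is part of the trainable response.
    return prompt, empty_cot + response
```

The same branch structure appears twice above (once for one-turn, once per turn in multi-turn encoding); only the target of the insertion differs.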
+
+
+ TEMPLATES: dict[str, "Template"] = {}
+
+
+ def register_template(
+     name: str,
+     format_user: Optional["Formatter"] = None,
+     format_assistant: Optional["Formatter"] = None,
+     format_system: Optional["Formatter"] = None,
+     format_function: Optional["Formatter"] = None,
+     format_observation: Optional["Formatter"] = None,
+     format_tools: Optional["Formatter"] = None,
+     format_prefix: Optional["Formatter"] = None,
+     default_system: str = "",
+     stop_words: Optional[list[str]] = None,
+     thought_words: Optional[tuple[str, str]] = None,
+     tool_call_words: Optional[tuple[str, str]] = None,
+     efficient_eos: bool = False,
+     replace_eos: bool = False,
+     replace_jinja_template: bool = False,
+     enable_thinking: Optional[bool] = True,
+     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
+     template_class: type["Template"] = Template,
+ ) -> None:
+     r"""Register a chat template.
+
+     To add the following chat template:
+     ```
+     <s><user>user prompt here
+     <model>model response here</s>
+     <user>user prompt here
+     <model>model response here</s>
+     ```
+
+     The corresponding code should be:
+     ```
+     register_template(
+         name="custom",
+         format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
+         format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
+         format_prefix=EmptyFormatter(slots=["<s>"]),
+     )
+     ```
+     """
+     if name in TEMPLATES:
+         raise ValueError(f"Template {name} already exists.")
+
+     default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
+     default_user_formatter = StringFormatter(slots=["{{content}}"])
+     default_assistant_formatter = StringFormatter(slots=default_slots)
+     if format_assistant is not None:
+         default_function_formatter = FunctionFormatter(slots=format_assistant.slots, tool_format="default")
+     else:
+         default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
+
+     default_tool_formatter = ToolFormatter(tool_format="default")
+     default_prefix_formatter = EmptyFormatter()
+     TEMPLATES[name] = template_class(
+         format_user=format_user or default_user_formatter,
+         format_assistant=format_assistant or default_assistant_formatter,
+         format_system=format_system or default_user_formatter,
+         format_function=format_function or default_function_formatter,
+         format_observation=format_observation or format_user or default_user_formatter,
+         format_tools=format_tools or default_tool_formatter,
+         format_prefix=format_prefix or default_prefix_formatter,
+         default_system=default_system,
+         stop_words=stop_words or [],
+         thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
+         tool_call_words=tool_call_words or ("<tool_call>", "</tool_call>"),
+         efficient_eos=efficient_eos,
+         replace_eos=replace_eos,
+         replace_jinja_template=replace_jinja_template,
+         enable_thinking=enable_thinking,
+         mm_plugin=mm_plugin,
+     )
+
+
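The `custom` template from the `register_template` docstring can be exercised without the real formatter classes; a toy string-level renderer (all names below are hypothetical stand-ins, not the library's API) shows what those slots produce for a two-turn conversation:

```python
# Toy stand-ins for the StringFormatter slots in the docstring example.
USER_SLOT = "<user>{{content}}\n<model>"
ASSISTANT_SLOT = "{{content}}</s>\n"
PREFIX = "<s>"


def render(messages: list[dict[str, str]]) -> str:
    """Render a user/assistant message list the way the 'custom' template would."""
    out = PREFIX  # format_prefix is applied once, before the first message
    for msg in messages:
        slot = USER_SLOT if msg["role"] == "user" else ASSISTANT_SLOT
        out += slot.replace("{{content}}", msg["content"])
    return out
```

For `[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]` this yields `<s><user>hi\n<model>hello</s>\n`, matching the layout shown at the top of the docstring.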
+ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
+     r"""Extract a chat template from the tokenizer."""
+
+     def find_diff(short_str: str, long_str: str) -> str:
+         i, j = 0, 0
+         diff = ""
+         while i < len(short_str) and j < len(long_str):
+             if short_str[i] == long_str[j]:
+                 i += 1
+                 j += 1
+             else:
+                 diff += long_str[j]
+                 j += 1
+
+         return diff
+
+     prefix = tokenizer.decode(tokenizer.encode(""))
+
+     messages = [{"role": "system", "content": "{{content}}"}]
+     system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
+
+     messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
+     user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+     user_slot_empty_system = user_slot_empty_system[len(prefix) :]
+
+     messages = [{"role": "user", "content": "{{content}}"}]
+     user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+     user_slot = user_slot[len(prefix) :]
+
+     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
+     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
+     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+     template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
+     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
+
+     if len(user_slot) > len(user_slot_empty_system):
+         default_system = find_diff(user_slot_empty_system, user_slot)
+         sole_system = system_slot.replace("{{content}}", default_system, 1)
+         user_slot = user_slot[len(sole_system) :]
+     else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
+         default_system = ""
+
+     return template_class(
+         format_user=StringFormatter(slots=[user_slot]),
+         format_assistant=StringFormatter(slots=[assistant_slot]),
+         format_system=StringFormatter(slots=[system_slot]),
+         format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
+         format_observation=StringFormatter(slots=[user_slot]),
+         format_tools=ToolFormatter(tool_format="default"),
+         format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
+         default_system=default_system,
+         stop_words=[],
+         thought_words=("<think>\n", "\n</think>\n\n"),
+         tool_call_words=("<tool_call>", "</tool_call>"),
+         efficient_eos=False,
+         replace_eos=False,
+         replace_jinja_template=False,
+         enable_thinking=True,
+         mm_plugin=get_mm_plugin(name="base"),
+     )
+
+
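The nested `find_diff` in `parse_template` recovers the default system prompt as the characters of the longer rendering that greedy matching cannot align with the shorter one. A standalone copy of that helper, exercised on toy prompt strings rather than a real tokenizer's output:

```python
def find_diff(short_str: str, long_str: str) -> str:
    """Collect the characters of long_str that are not greedily matched in short_str."""
    i, j = 0, 0
    diff = ""
    while i < len(short_str) and j < len(long_str):
        if short_str[i] == long_str[j]:
            i += 1  # characters agree: advance both cursors
            j += 1
        else:
            diff += long_str[j]  # extra character only in long_str
            j += 1
    return diff
```

With `short_str` being the rendering for an empty system prompt and `long_str` the rendering that includes the tokenizer's default system prompt, the collected difference is exactly that default prompt (assuming, as `parse_template` does, that the two renderings differ only by the inserted system block).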
+ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+     r"""Get the chat template and fix the tokenizer."""
+     if data_args.template is None:
+         if isinstance(tokenizer.chat_template, str):
+             logger.warning_rank0("`template` was not specified, trying to parse the chat template from the tokenizer.")
+             template = parse_template(tokenizer)
+         else:
+             logger.warning_rank0("`template` was not specified, using the `empty` template.")
+             template = TEMPLATES["empty"]  # placeholder
+     else:
+         if data_args.template not in TEMPLATES:
+             raise ValueError(f"Template {data_args.template} does not exist.")
+
+         template = TEMPLATES[data_args.template]
+
+     if data_args.train_on_prompt and template.efficient_eos:
+         raise ValueError("Current template does not support `train_on_prompt`.")
+
+     if data_args.tool_format is not None:
+         logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
+         default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
+         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
+         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
+
+     if data_args.default_system is not None:
+         logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+         template.default_system = data_args.default_system
+
+     if isinstance(template, ReasoningTemplate):
+         logger.warning_rank0(
+             "You are using a reasoning template; "
+             "please add the `_nothink` suffix if the model is not a reasoning model, "
+             "e.g., qwen3_vl_nothink."
+         )
+         template.enable_thinking = data_args.enable_thinking
+
+     template.fix_special_tokens(tokenizer)
+     template.fix_jinja_template(tokenizer)
+     return template
+
+
+ register_template(
+     name="alpaca",
+     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
+     default_system=(
+         "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+     ),
+     replace_jinja_template=True,
+ )
+
+
+ register_template(
+     name="bailing",
+     format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
+     format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}"]),
+     format_observation=StringFormatter(slots=["<role>OBSERVATION</role>{{content}}<role>ASSISTANT</role>"]),
+     stop_words=["<|endoftext|>"],
+     efficient_eos=True,
+ )
+
+
+ register_template(
+     name="bailing_v2",
+     format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<|role_end|><role>ASSISTANT</role>"]),
+     format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}<|role_end|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|role_end|>"]),
+     format_observation=StringFormatter(
+         slots=[
+             "<role>OBSERVATION</role>\n<tool_response>\n{{content}}\n</tool_response><|role_end|><role>ASSISTANT</role>"
+         ]
+     ),
+     format_function=FunctionFormatter(slots=["{{content}}<|role_end|>"], tool_format="ling"),
+     format_tools=ToolFormatter(tool_format="ling"),
+     stop_words=["<|endoftext|>"],
+     efficient_eos=True,
+ )
+
+
+ register_template(
+     name="breeze",
+     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     efficient_eos=True,
+ )
+
+
+ register_template(
+     name="chatglm3",
+     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
+     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
+     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+     format_observation=StringFormatter(
+         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
+     ),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+ )
+
+
+ register_template(
+     name="chatml",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     stop_words=["<|im_end|>", "<|im_start|>"],
+     replace_eos=True,
+     replace_jinja_template=True,
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="chatml_de",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
+     stop_words=["<|im_end|>", "<|im_start|>"],
+     replace_eos=True,
+     replace_jinja_template=True,
+ )
+
+
+ register_template(
+     name="cohere",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
+                 "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+             )
+         ]
+     ),
+     format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="cpm4",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|im_end|>"],
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="dbrx",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     default_system=(
+         "You are DBRX, created by Databricks. You were last updated in December 2023. "
+         "You answer questions based on information available up to that point.\n"
+         "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
+         "responses to more complex and open-ended questions.\nYou assist with various tasks, "
+         "from writing to coding (using markdown for code blocks — remember to use ``` with "
+         "code, JSON, and tables).\n(You do not have real-time data access or code execution "
+         "capabilities. You avoid stereotyping and provide balanced perspectives on "
+         "controversial topics. You do not provide song lyrics, poems, or news articles and "
+         "do not divulge details of your training data.)\nThis is your system prompt, "
+         "guiding your responses. Do not reference it, just respond to the user. If you find "
+         "yourself talking about this message, stop. You should be responding appropriately "
+         "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
+         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
+     ),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+ )
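Several of the registrations above (`chatml`, `chatml_de`, `cpm4`, `dbrx`) share the ChatML layout; a toy string-level renderer (illustrative only, not the library's `StringFormatter`/`Template` machinery) shows what the `chatml` slots produce:

```python
# String-level mimic of the chatml slots registered above -- illustrative only.
def render_chatml(messages: list[dict[str, str]]) -> str:
    slots = {
        "system": "<|im_start|>system\n{{content}}<|im_end|>\n",
        "user": "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n",
        "assistant": "{{content}}<|im_end|>\n",
    }
    # Substitute each message into its role's slot, in order.
    return "".join(slots[m["role"]].replace("{{content}}", m["content"]) for m in messages)
```

Note that the user slot already ends with the assistant header, so the assistant slot only contributes the reply text and its `<|im_end|>` terminator.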
+
+
+ register_template(
+     name="deepseek",
+     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ register_template(
+     name="deepseek3",
+     format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ # copied from deepseek3 template
+ register_template(
+     name="deepseekr1",
+     format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
+     name="deepseekcoder",
+     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     default_system=(
+         "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
+         "developed by DeepSeek Company, and you only answer questions related to computer science. "
+         "For politically sensitive questions, security and privacy issues, "
+         "and other non-computer science questions, you will refuse to answer.\n"
+     ),
+ )
+
+
+ register_template(
+     name="default",
+     format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
+     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
+     format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
+     replace_jinja_template=True,
+ )
+
+
+ register_template(
+     name="dots_ocr",
+     format_user=StringFormatter(slots=["<|user|>{{content}}<|endofuser|><|assistant|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|endofassistant|>"]),
+     format_system=StringFormatter(slots=["<|system|>{{content}}<|endofsystem|>\n"]),
+     stop_words=["<|endofassistant|>"],
+     efficient_eos=True,
+     mm_plugin=get_mm_plugin(
+         name="qwen2_vl",
+         image_token="<|imgpad|>",
+         video_token="<|vidpad|>",
+         vision_bos_token="<|img|>",
+         vision_eos_token="<|endofimg|>",
+     ),
+ )
+
+
+ register_template(
+     name="empty",
+     format_assistant=StringFormatter(slots=["{{content}}"]),
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="ernie",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n\n<|im_start|>assistant\n"]),
+     default_system="<global_setting>\nthink_mode=True\n</global_setting>",
+     stop_words=["<|im_end|>"],
+ )
+
+
+ register_template(
+     name="ernie_nothink",
+     format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]),
+     format_assistant=StringFormatter(slots=["{{content}}<|end_of_sentence|>"]),
+     format_system=StringFormatter(slots=["{{content}}\n"]),
+     format_prefix=EmptyFormatter(slots=["<|begin_of_sentence|>"]),
+     stop_words=["<|end_of_sentence|>"],
+ )
+
+
+ register_template(
+     name="ernie_vl",
+     format_user=StringFormatter(slots=["User: {{content}}"]),
+     format_assistant=StringFormatter(slots=["\nAssistant: {{content}}<|end_of_sentence|>"]),
+     format_system=StringFormatter(slots=["{{content}}\n"]),
+     stop_words=["<|end_of_sentence|>"],
+     replace_eos=True,
+     replace_jinja_template=True,
+     template_class=ReasoningTemplate,
+     mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
+ )
+
+
+ register_template(
+     name="exaone",
+     format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
+     format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
+ )
+
+
+ register_template(
+     name="falcon",
+     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
+     format_assistant=StringFormatter(slots=["{{content}}\n"]),
+     efficient_eos=True,
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="falcon_h1",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|im_end|>", "<|end_of_text|>"],
+ )
+
+
+ register_template(
+     name="fewshot",
+     format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
+     efficient_eos=True,
+     replace_jinja_template=True,
+ )
+
+
+ register_template(
+     name="gemma",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<end_of_turn>"],
+     replace_eos=True,
+     template_class=Llama2Template,
+ )
+
+
+ # copied from gemma template
+ register_template(
+     name="gemma2",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<eos>", "<end_of_turn>"],
+     efficient_eos=True,
+     template_class=Llama2Template,
+ )
+
+
+ # copied from gemma template
+ register_template(
+     name="gemma3",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<end_of_turn>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
+     template_class=Llama2Template,
+ )
+
+
+ register_template(
+     name="gemma3n",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<end_of_turn>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin("gemma3n", image_token="<image_soft_token>", audio_token="<audio_soft_token>"),
+     template_class=Llama2Template,
+ )
+
+
+ register_template(
+     name="glm4",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+ )
+
+
+ # copied from glm4 template
+ register_template(
+     name="glm4_moe",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4_moe"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from glm4 template
+ register_template(
+     name="glm4v",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>", "</answer>"],
+     efficient_eos=True,
+     mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from glm4 template
+ register_template(
+     name="glm4_5v",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4_moe"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>", "</answer>"],
+     efficient_eos=True,
+     mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from glm4 template
+ register_template(
+     name="glmz1",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["\n{{content}}"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+     format_tools=ToolFormatter(tool_format="glm4"),
+     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+     stop_words=["<|user|>", "<|observation|>"],
+     efficient_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
1069
+ name="gpt_oss",
1070
+ format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
1071
+ format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
1072
+     format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
+     default_system="You are ChatGPT, a large language model trained by OpenAI.",
+     thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
+     efficient_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
+     name="granite3",
+     format_user=StringFormatter(
+         slots=[
+             "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+     format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
+ )
+
+
+ register_template(
+     name="granite3_vision",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+     default_system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+ )
+
+
+ register_template(
+     name="granite4",
+     format_user=StringFormatter(
+         slots=[
+             "<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+     format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|end_of_text|>\n"], tool_format="default"),
+     format_observation=StringFormatter(
+         slots=["<|start_of_role|>tool<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="default"),
+     stop_words=["<|end_of_text|>"],
+     default_system="You are Granite, developed by IBM. You are a helpful AI assistant.",
+ )
+
+
+ register_template(
+     name="index",
+     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
+     format_system=StringFormatter(slots=["<unk>{{content}}"]),
+     efficient_eos=True,
+ )
+
+
+
+ register_template(
+     name="hunyuan",
+     format_user=StringFormatter(slots=["{{content}}<|extra_0|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|eos|>"]),
+     format_system=StringFormatter(slots=["{{content}}<|extra_4|>"]),
+     format_prefix=EmptyFormatter(slots=["<|startoftext|>"]),
+     stop_words=["<|eos|>"],
+ )
+
+
+ register_template(
+     name="hunyuan_small",
+     format_user=StringFormatter(slots=["<|hy_User|>{{content}}<|hy_place▁holder▁no▁8|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁2|>"]),
+     format_system=StringFormatter(slots=["{{content}}<|hy_place▁holder▁no▁3|>"]),
+     format_prefix=EmptyFormatter(slots=["<|hy_begin▁of▁sentence|>"]),
+     stop_words=["<|hy_place▁holder▁no▁2|>"],
+ )
+
+
+ register_template(
+     name="intern2",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     default_system=(
+         "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+         "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
+         "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+         "- InternLM (书生·浦语) can understand and communicate fluently in the language "
+         "chosen by the user such as English and 中文."
+     ),
+     stop_words=["<|im_end|>"],
+ )
+
+
+ register_template(
+     name="intern_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     default_system=(
+         "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+     ),
+     stop_words=["<|im_end|>"],
+     mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
+ )
+
+
+ register_template(
+     name="intern_s1",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|im_end|>"],
+     mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="keye_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
+     name="kimi_vl",
+     format_user=StringFormatter(
+         slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+     format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
+     default_system="You are a helpful assistant",
+     stop_words=["<|im_end|>"],
+     thought_words=("◁think▷", "◁/think▷"),
+     mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
+     name="lfm2",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
+     format_observation=StringFormatter(
+         slots=[
+             "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
+             "<|im_start|>assistant\n"
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="lfm2"),
+     default_system="You are a helpful AI assistant.",
+     stop_words=["<|im_end|>"],
+     tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
+     replace_eos=True,
+ )
+
+
+ register_template(
+     name="lfm2_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
+     format_observation=StringFormatter(
+         slots=[
+             "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
+             "<|im_start|>assistant\n"
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="lfm2"),
+     default_system="You are a helpful multimodal assistant by Liquid AI.",
+     stop_words=["<|im_end|>"],
+     tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
+ )
+
+
+ register_template(
+     name="llama2",
+     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+     template_class=Llama2Template,
+ )
+
+
+ # copied from llama2 template
+ register_template(
+     name="llama2_zh",
+     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
+     format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
+     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
+     template_class=Llama2Template,
+ )
+
+
+ register_template(
+     name="llama3",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
+     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
+     format_observation=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="llama3"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot_id|>", "<|eom_id|>"],
+     replace_eos=True,
+ )
+
+
+ register_template(
+     name="llama4",
+     format_user=StringFormatter(
+         slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
+     format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
+     format_observation=StringFormatter(
+         slots=[
+             "<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="llama3"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot|>", "<|eom|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
+ )
+
+
+ # copied from llama3 template
+ register_template(
+     name="mllama",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
+     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
+     format_observation=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="llama3"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot_id|>", "<|eom_id|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
+ )
+
+
+ register_template(
+     name="moonlight",
+     format_user=StringFormatter(
+         slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+     format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
+     default_system="You are a helpful assistant provided by Moonshot-AI.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+ )
+
+
+ # copied from vicuna template
+ register_template(
+     name="llava",
+     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+     default_system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
+ )
+
+
+ # copied from vicuna template
+ register_template(
+     name="llava_next",
+     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+     default_system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+ )
+
+
+ # copied from llama3 template
+ register_template(
+     name="llava_next_llama3",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
+     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
+     format_observation=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="llama3"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot_id|>", "<|eom_id|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+ )
+
+
+
+ # copied from mistral template
+ register_template(
+     name="llava_next_mistral",
+     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
+     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+     template_class=Llama2Template,
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="llava_next_qwen",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are a helpful assistant.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="llava_next_yi",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+ )
+
+
+ # copied from vicuna template
+ register_template(
+     name="llava_next_video",
+     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+     default_system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+ )
+
+
+ # copied from mistral template
+ register_template(
+     name="llava_next_video_mistral",
+     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
+     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+     template_class=Llama2Template,
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="llava_next_video_yi",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+     mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="mimo",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are a helpful assistant.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="mimo_v2",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are MiMo, a helpful AI assistant engineered by Xiaomi.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     thought_words=("<think>", "</think>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from qwen2vl
+ register_template(
+     name="mimo_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are MiMo, an AI assistant developed by Xiaomi.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="minicpm_v",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+     default_system="You are a helpful assistant.",
+     mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
+ )
+
+
+ # copied from minicpm_v template
+ register_template(
+     name="minicpm_o",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+     default_system="You are a helpful assistant. You can accept audio and text input and output voice and text.",
+     mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
+ )
+
+
+ register_template(
+     name="minimax1",
+     format_user=StringFormatter(
+         slots=[
+             "<beginning_of_sentence>user name=user\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
+         ]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_sentence>\n"]),
+     format_system=StringFormatter(
+         slots=["<beginning_of_sentence>system ai_setting=assistant\n{{content}}<end_of_sentence>\n"]
+     ),
+     format_function=FunctionFormatter(slots=["{{content}}<end_of_sentence>\n"], tool_format="minimax1"),
+     format_observation=StringFormatter(
+         slots=[
+             "<beginning_of_sentence>tool name=tools\n{{content}}<end_of_sentence>\n<beginning_of_sentence>ai name=assistant\n"
+         ]
+     ),
+     format_tools=ToolFormatter(tool_format="minimax1"),
+     default_system="You are a helpful assistant.",
+     stop_words=["<end_of_sentence>"],
+ )
+
+
+ register_template(
+     name="minimax2",
+     format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]),
+     format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"),
+     format_observation=StringFormatter(slots=["]~b]tool\n<response>{{content}}</response>[e~[\n]~b]ai\n"]),
+     format_tools=ToolFormatter(tool_format="minimax2"),
+     default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.",
+     stop_words=["[e~["],
+     template_class=ReasoningTemplate,
+ )
+
+
+ # mistral tokenizer v3 tekken
+ register_template(
+     name="ministral",
+     format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     template_class=Llama2Template,
+ )
+
+
+ # mistral tokenizer v3
+ register_template(
+     name="mistral",
+     format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
+     format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     template_class=Llama2Template,
+ )
+
+
+ # mistral tokenizer v7 tekken (copied from ministral)
+ register_template(
+     name="mistral_small",
+     format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+     format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+ )
+
+
+ register_template(
+     name="ministral3",
+     format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     template_class=Llama2Template,
+     mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+ )
+
+
+ register_template(
+     name="olmo",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
+     format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
+ )
+
+
+ register_template(
+     name="openchat",
+     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ )
+
+
+ register_template(
+     name="openchat-3.6",
+     format_user=StringFormatter(
+         slots=[
+             (
+                 "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                 "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
+             )
+         ]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|eot_id|>"],
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="opencoder",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     default_system="You are OpenCoder, created by OpenCoder Team.",
+     stop_words=["<|im_end|>"],
+ )
+
+
+ register_template(
+     name="paligemma",
+     format_user=StringFormatter(slots=["{{content}}\n"]),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+     template_class=Llama2Template,
+ )
+
+
+ # copied from gemma template
+ register_template(
+     name="paligemma_chat",
+     format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+     format_observation=StringFormatter(
+         slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+     ),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<end_of_turn>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+     template_class=Llama2Template,
+ )
+
1748
+
1749
+ register_template(
1750
+ name="phi",
1751
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
1752
+ format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
1753
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
1754
+ stop_words=["<|end|>"],
1755
+ replace_eos=True,
1756
+ )
1757
+
1758
+
1759
+ register_template(
1760
+ name="phi_small",
1761
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
1762
+ format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
1763
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
1764
+ format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
1765
+ stop_words=["<|end|>"],
1766
+ replace_eos=True,
1767
+ )
1768
+
1769
+
1770
+ register_template(
1771
+ name="phi4",
1772
+ format_user=StringFormatter(
1773
+ slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
1774
+ ),
1775
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
1776
+ format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
1777
+ stop_words=["<|im_end|>"],
1778
+ replace_eos=True,
1779
+ )
1780
+
1781
+
1782
+ register_template(
1783
+ name="phi4_mini",
1784
+ format_user=StringFormatter(slots=["<|user|>{{content}}<|end|><|assistant|>"]),
1785
+ format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
1786
+ format_system=StringFormatter(slots=["<|system|>{{content}}<|end|>"]),
1787
+ format_tools=StringFormatter(slots=["<|tool|>{{content}}<|/tool|>"]),
1788
+ stop_words=["<|end|>"],
1789
+ replace_eos=True,
1790
+ )
1791
+
1792
+
1793
+ # copied from ministral template
1794
+ register_template(
1795
+ name="pixtral",
1796
+ format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
1797
+     format_system=StringFormatter(slots=["{{content}}\n\n"]),
+     format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+     format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+     format_tools=ToolFormatter(tool_format="mistral"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+     template_class=Llama2Template,
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="qwen",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="qwen3",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="qwen3_nothink",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>", "<think>", "</think>"],
+     replace_eos=True,
+ )
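The `qwen3_nothink` user slot above pre-fills an empty `<think>\n\n</think>` block so the assistant turn starts after a closed reasoning span. As a minimal string-level sketch of this convention (a hypothetical helper, not LlamaFactory's `ReasoningTemplate` implementation), a completion's reasoning span can be separated from the visible answer like this:

```python
# Sketch: separating a <think>...</think> reasoning span from the visible
# answer in a Qwen3-style completion. Hypothetical helper, not the actual
# LlamaFactory ReasoningTemplate code.
import re


def split_thought(text: str) -> tuple[str, str]:
    """Return (thought, answer); thought is "" when no reasoning block exists."""
    match = re.match(r"<think>\n(.*?)\n</think>\n\n(.*)", text, flags=re.DOTALL)
    if match:
        return match.group(1), match.group(2)
    return "", text


thought, answer = split_thought("<think>\nLet me check.\n</think>\n\nThe answer is 4.")
print(repr(thought), repr(answer))
```

With the `nothink` template the pre-filled block is empty, so the extracted thought is empty and the answer is the whole completion.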
+
+
+ # copied from chatml template
+ register_template(
+     name="qwen2_audio",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     default_system="You are a helpful assistant.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="qwen2_omni",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are a helpful assistant.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(
+         name="qwen2_omni",
+         image_token="<|IMAGE|>",
+         video_token="<|VIDEO|>",
+         audio_token="<|AUDIO|>",
+         vision_bos_token="<|vision_bos|>",
+         vision_eos_token="<|vision_eos|>",
+         audio_bos_token="<|audio_bos|>",
+         audio_eos_token="<|audio_eos|>",
+     ),
+ )
+
+
+ register_template(
+     name="qwen3_omni",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(
+         name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
+     ),
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
+     name="qwen3_omni_nothink",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(
+         name="qwen2_omni", image_token="<|image_pad|>", video_token="<|video_pad|>", audio_token="<|audio_pad|>"
+     ),
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="qwen2_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     default_system="You are a helpful assistant.",
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="qwen3_vl",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+     template_class=ReasoningTemplate,
+ )
+
+
+ # copied from qwen template
+ register_template(
+     name="qwen3_vl_nothink",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+     format_observation=StringFormatter(
+         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+     ),
+     format_tools=ToolFormatter(tool_format="qwen"),
+     stop_words=["<|im_end|>"],
+     replace_eos=True,
+     mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+ )
+
+
+ register_template(
+     name="sailor",
+     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     default_system=(
+         "You are an AI assistant named Sailor created by Sea AI Lab. "
+         "Your answer should be friendly, unbiased, faithful, informative and detailed."
+     ),
+     stop_words=["<|im_end|>"],
+ )
+
+
+ register_template(
+     name="seed_coder",
+     format_user=StringFormatter(
+         slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
+     ),
+     format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
+     default_system=(
+         "You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
+         "and you only answer questions related to computer science. For politically sensitive questions, "
+         "security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
+     ),
+ )
+
+
+ # copied from seed_coder
+ register_template(
+     name="seed_oss",
+     format_user=StringFormatter(
+         slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
+     ),
+     format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
+     format_function=FunctionFormatter(slots=[{"bos_token"}, "\n{{content}}", {"eos_token"}], tool_format="seed_oss"),
+     format_tools=ToolFormatter(tool_format="seed_oss"),
+     template_class=ReasoningTemplate,
+     thought_words=("<seed:think>", "</seed:think>"),
+ )
+
+
+ register_template(
+     name="smollm",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+ )
+
+
+ register_template(
+     name="smollm2",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+     default_system="You are a helpful AI assistant named SmolLM, trained by Hugging Face.",
+ )
+
+
+ register_template(
+     name="solar",
+     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
+     format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
+     efficient_eos=True,
+ )
+
+
+ register_template(
+     name="starchat",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+     stop_words=["<|end|>"],
+ )
+
+
+ register_template(
+     name="telechat2",
+     format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
+     format_system=StringFormatter(slots=["<_system>{{content}}"]),
+     default_system=(
+         "你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。"
+     ),
+ )
+
+
+ register_template(
+     name="vicuna",
+     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+     default_system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     replace_jinja_template=True,
+ )
+
+
+ register_template(
+     name="video_llava",
+     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+     default_system=(
+         "A chat between a curious user and an artificial intelligence assistant. "
+         "The assistant gives helpful, detailed, and polite answers to the user's questions."
+     ),
+     mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
+ )
+
+
+ register_template(
+     name="xuanyuan",
+     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
+     default_system=(
+         "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
+         "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
+         "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
+     ),
+ )
+
+
+ # copied from chatml template
+ register_template(
+     name="yi",
+     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+     stop_words=["<|im_end|>"],
+ )
+
+
+ register_template(
+     name="yi_vl",
+     format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
+     format_assistant=StringFormatter(slots=["{{content}}\n"]),
+     default_system=(
+         "This is a chat between an inquisitive human and an AI assistant. "
+         "Assume the role of the AI assistant. Read all the images carefully, "
+         "and respond to the human's questions with informative, helpful, detailed and polite answers. "
+         "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
+         "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细和礼貌的回答。\n\n"
+     ),
+     stop_words=["###"],
+     efficient_eos=True,
+     mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
+ )
+
+
+ register_template(
+     name="youtu",
+     format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+     format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>"]),
+     format_system=StringFormatter(slots=["{{content}}"]),
+     format_function=FunctionFormatter(slots=["{{content}}"], tool_format="default"),
+     format_observation=StringFormatter(slots=["<tool_response>\n{{content}}\n</tool_response><|Assistant|>"]),
+     format_tools=ToolFormatter(tool_format="default"),
+     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+     stop_words=["<|end_of_text|>"],
+     replace_eos=True,
+     template_class=ReasoningTemplate,
+ )
+
+
+ register_template(
+     name="youtu_vl",
+     format_user=StringFormatter(
+         slots=["<|begin_of_text|>user\n{{content}}<|end_of_text|>\n<|begin_of_text|>assistant\n"]
+     ),
+     format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
+     format_system=StringFormatter(slots=["<|begin_of_text|>system\n{{content}}<|end_of_text|>\n"]),
+     default_system="You are a helpful assistant.",
+     stop_words=["<|end_of_text|>"],
+     mm_plugin=get_mm_plugin(name="youtu_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+ )
+
+
+ register_template(
+     name="yuan",
+     format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
+     format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
+     stop_words=["<eod>"],
+ )
+
+
+ register_template(
+     name="zephyr",
+     format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
+     format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
+     default_system="You are Zephyr, a helpful assistant.",
+ )
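Every registration above composes formatter slots in which `{{content}}` is substituted at render time. As a minimal, self-contained sketch of that mechanism (approximating `StringFormatter`, not the actual LlamaFactory implementation), here is how a ChatML-style template such as `qwen` turns a message list into a prompt string:

```python
# Minimal sketch of slot-based template rendering. apply_slots/render are
# hypothetical helpers that approximate StringFormatter's {{content}}
# substitution; they are not the actual LlamaFactory code.

def apply_slots(slots: list[str], content: str) -> str:
    """Replace the {{content}} placeholder in each string slot."""
    return "".join(slot.replace("{{content}}", content) for slot in slots)

# Slot strings copied from the "qwen" template registration above.
FORMAT_USER = ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
FORMAT_ASSISTANT = ["{{content}}<|im_end|>\n"]
FORMAT_SYSTEM = ["<|im_start|>system\n{{content}}<|im_end|>\n"]

def render(messages: list[dict], system: str) -> str:
    """Render alternating user/assistant messages into one prompt string."""
    parts = [apply_slots(FORMAT_SYSTEM, system)]
    for msg in messages:
        slots = FORMAT_USER if msg["role"] == "user" else FORMAT_ASSISTANT
        parts.append(apply_slots(slots, msg["content"]))
    return "".join(parts)

prompt = render(
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
)
print(prompt)
```

The real templates additionally handle tool calls, multimodal tokens, and special-token slots such as `{"bos_token"}`, but the core is the same placeholder substitution.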
LlamaFactory/src/llamafactory/data/__init__.py ADDED
@@ -0,0 +1,37 @@
+ # Copyright 2025 the LlamaFactory team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .collator import (
+     KTODataCollatorWithPadding,
+     MultiModalDataCollatorForSeq2Seq,
+     PairwiseDataCollatorWithPadding,
+     SFTDataCollatorWith4DAttentionMask,
+ )
+ from .data_utils import Role, split_dataset
+ from .loader import get_dataset
+ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
+
+
+ __all__ = [
+     "TEMPLATES",
+     "KTODataCollatorWithPadding",
+     "MultiModalDataCollatorForSeq2Seq",
+     "PairwiseDataCollatorWithPadding",
+     "Role",
+     "SFTDataCollatorWith4DAttentionMask",
+     "Template",
+     "get_dataset",
+     "get_template_and_fix_tokenizer",
+     "split_dataset",
+ ]
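This `__init__.py` re-exports the package's public API and pins it with `__all__`, which controls what `from llamafactory.data import *` pulls in. As a sketch of that semantics using a throwaway in-memory module (the names here are hypothetical, not the llamafactory package itself):

```python
# Sketch of how __all__ constrains star-imports. Uses a throwaway
# in-memory module with hypothetical names, not llamafactory.data.
import types

mod = types.ModuleType("demo_data")
mod.get_dataset = lambda: "dataset"      # listed in __all__, exported
mod._private_helper = lambda: "hidden"   # not listed, hidden from *
mod.__all__ = ["get_dataset"]

# Emulate `from demo_data import *`: only names in __all__ are bound.
namespace = {name: getattr(mod, name) for name in mod.__all__}
print(sorted(namespace))
```

Explicit imports (`from llamafactory.data import get_dataset`) bypass `__all__`; the list only shapes the star-import surface and documents the intended public names.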
LlamaFactory/src/llamafactory/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (770 Bytes)
LlamaFactory/src/llamafactory/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (644 Bytes)
LlamaFactory/src/llamafactory/data/__pycache__/collator.cpython-311.pyc ADDED
Binary file (16.8 kB)
LlamaFactory/src/llamafactory/data/__pycache__/collator.cpython-312.pyc ADDED
Binary file (15.3 kB)
LlamaFactory/src/llamafactory/data/__pycache__/converter.cpython-311.pyc ADDED
Binary file (21 kB)
LlamaFactory/src/llamafactory/data/__pycache__/converter.cpython-312.pyc ADDED
Binary file (21.7 kB)
LlamaFactory/src/llamafactory/data/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (10.1 kB)
LlamaFactory/src/llamafactory/data/__pycache__/data_utils.cpython-312.pyc ADDED
Binary file (8.68 kB)
LlamaFactory/src/llamafactory/data/__pycache__/formatter.cpython-311.pyc ADDED
Binary file (10.3 kB)
LlamaFactory/src/llamafactory/data/__pycache__/formatter.cpython-312.pyc ADDED
Binary file (8.88 kB)
LlamaFactory/src/llamafactory/data/__pycache__/loader.cpython-311.pyc ADDED
Binary file (16 kB)
LlamaFactory/src/llamafactory/data/__pycache__/loader.cpython-312.pyc ADDED
Binary file (14.9 kB)